Sub Matrices with sum Zero

Given a 2D matrix, find the number non-empty sub matrices, such that the sum of the elements inside the sub matrix is equal to 0. (note: elements might be negative).

Example:

Input

-8 5  7
3  7 -8
5 -8  9

Output
2

Explanation
-8 5 7
3 7 -8
5 -8 9

-8 5 7
3 7 -8
5 -8 9

Solution:

Time: O(n^4)
Space: O(n^2)

public class Solution {
    private int subArraySum(int[] arr, int target) {
        int n = arr.length;
        // dp[i][j] = sub array sum from i to j
        // dp[i][j] = dp[i][j - 1] + A[j]
        // dp[i][i] = A[i]
        // dp[0][n - 1]
        int[][] dp = new int[n][n];
        int result = 0;
        for (int len = 0; len < n; len ++) {
            for (int i = 0; i < n; i ++) {
                int j = i + len;
                if (j >= n) continue;
                if (i == j) {
                    dp[i][i] = arr[j];
                } else {
                    dp[i][j] = dp[i][j - 1] + arr[j];
                }
                if (dp[i][j] == target) {
                    result ++;
                }
            }
        }
        return result;
    }
    
    private int[][] prefixSum(int[][] A) {
        int m = A.length;
        int n = A[0].length;
        int[][] prefix = new int[m][n];
        for (int i = 0; i < m; i ++) {
            for (int j = 0; j < n; j ++) {
                if (i == 0) {
                    prefix[i][j] = A[i][j];
                } else {
                    prefix[i][j] = prefix[i - 1][j] + A[i][j];
                }
            }
        }
        return prefix;
    }
    
    public int solve(int[][] A) {
        int m = A.length;
        if (m == 0) return 0;
        int n = A[0].length;
        if (n == 0) return 0;
        int[][] prefix = prefixSum(A);
        int result = 0;
        for (int len = 0; len < m; len ++) {
            for (int i = 0; i < m; i ++) {
                int j = i + len;
                if (i == j) {
                    result += subArraySum(A[i], 0);
                } else if (j < m) {
                    int[] sub = new int[n];
                    for (int p = 0; p < n; p ++) {
                        sub[p] = prefix[j][p] - prefix[i][p] + A[i][p];
                    }
                    result += subArraySum(sub, 0);
                }
            }   
        }
        return result;
    }
}