Maximum Side Length of a Square with Sum Less than or Equal to Threshold

Given an m x n matrix mat and an integer threshold, return the maximum side length of a square whose sum is less than or equal to threshold, or return 0 if there is no such square.

 

Example 1:

Input: mat = [[1,1,3,2,4,3,2],[1,1,3,2,4,3,2],[1,1,3,2,4,3,2]], threshold = 4
Output: 2
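Explanation: The maximum side length of a square with sum less than or equal to 4 is 2; for example, the top-left 2 x 2 square of ones has sum 4, while every 3 x 3 square has sum greater than 4.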

Example 2:

Input: mat = [[2,2,2,2,2],[2,2,2,2,2],[2,2,2,2,2],[2,2,2,2,2],[2,2,2,2,2]], threshold = 1
Output: 0

Example 3:

Input: mat = [[1,1,1,1],[1,0,0,0],[1,0,0,0],[1,0,0,0]], threshold = 6
Output: 3

Example 4:

Input: mat = [[18,70],[61,1],[25,85],[14,40],[11,96],[97,96],[63,45]], threshold = 40184
Output: 2

 

Constraints:


Solution:

/**
 * Finds the largest square sub-matrix whose element sum does not exceed a threshold.
 */
class Solution {
    /**
     * Returns the maximum side length of a square sub-matrix of {@code mat} whose element sum is
     * less than or equal to {@code threshold}.
     *
     * <p>A padded 2-D prefix-sum table makes every square sum an O(1) lookup. Side lengths are
     * tried from the largest possible ({@code min(m, n)}) downwards, and the first length for
     * which some square qualifies is returned. The search does not assume the entries are
     * non-negative. Sums are accumulated in {@code long} so large entries cannot silently
     * overflow.
     *
     * @param mat matrix of integers; every row is assumed to have the same length as row 0
     * @param threshold inclusive upper bound on the square's sum
     * @return the largest qualifying side length, or 0 if no square qualifies (including when
     *     {@code mat} has no rows or its rows are empty)
     */
    public int maxSideLength(int[][] mat, int threshold) {
        // Guard: an empty matrix or empty rows cannot contain any square.
        if (mat.length == 0 || mat[0].length == 0) {
            return 0;
        }
        int m = mat.length;
        int n = mat[0].length;
        long[][] prefix = buildPrefixSums(mat, m, n);
        // No square can be larger than the shorter dimension, so no transpose is needed.
        for (int len = Math.min(m, n); len > 0; len--) {
            if (hasSquareWithin(prefix, m, n, len, threshold)) {
                return len;
            }
        }
        return 0;
    }

    /** Builds an (m+1) x (n+1) table where prefix[i][j] is the sum of mat[0..i-1][0..j-1]. */
    private static long[][] buildPrefixSums(int[][] mat, int m, int n) {
        long[][] prefix = new long[m + 1][n + 1];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                // The int entry is widened to long because the other operands are long.
                prefix[i + 1][j + 1] =
                        mat[i][j] + prefix[i][j + 1] + prefix[i + 1][j] - prefix[i][j];
            }
        }
        return prefix;
    }

    /** Reports whether any len x len square has a sum less than or equal to threshold. */
    private static boolean hasSquareWithin(
            long[][] prefix, int m, int n, int len, int threshold) {
        // (i, j) is the exclusive bottom-right corner of the square in prefix coordinates.
        for (int i = len; i <= m; i++) {
            for (int j = len; j <= n; j++) {
                long sum = prefix[i][j] - prefix[i - len][j] - prefix[i][j - len]
                        + prefix[i - len][j - len];
                if (sum <= threshold) {
                    return true;
                }
            }
        }
        return false;
    }
}