Maximum Non Negative Product in a Matrix

You are given a rows x cols matrix grid. Initially, you are located at the top-left corner (0, 0), and in each step, you can only move right or down in the matrix.

Among all possible paths starting from the top-left corner (0, 0) and ending in the bottom-right corner (rows - 1, cols - 1), find the path with the maximum non-negative product. The product of a path is the product of all integers in the grid cells visited along the path.

Return the maximum non-negative product modulo 10^9 + 7. If the maximum product is negative, return -1.

Notice that the modulo is performed after getting the maximum product.

 

Example 1:

Input: grid = [[-1,-2,-3],
               [-2,-3,-3],
               [-3,-3,-2]]
Output: -1
Explanation: It's not possible to get non-negative product in the path from (0, 0) to (2, 2), so return -1.

Example 2:

Input: grid = [[1,-2,1],
               [1,-2,1],
               [3,-4,1]]
Output: 8
Explanation: Maximum non-negative product is in bold (1 * 1 * -2 * -4 * 1 = 8).

Example 3:

Input: grid = [[1, 3],
               [0,-4]]
Output: 0
Explanation: Maximum non-negative product is in bold (1 * 0 * -4 = 0).

Example 4:

Input: grid = [[ 1, 4,4,0],
               [-2, 0,0,1],
               [ 1,-1,1,1]]
Output: 2
Explanation: Maximum non-negative product is in bold (1 * -2 * 1 * -1 * 1 * 1 = 2).

 

Constraints:


Solution:

class Solution {
    /** The result is reported modulo this prime (10^9 + 7). */
    private static final int MOD = 1_000_000_007;

    /**
     * Finds the largest non-negative product obtainable along a path that starts
     * at the top-left cell, ends at the bottom-right cell, and only ever moves
     * right or down.
     *
     * <p>Two tables are maintained per cell: the largest strictly positive product
     * and the smallest strictly negative product of any path ending there, with
     * {@code 0} meaning "no such path". Multiplying by a negative cell swaps the
     * roles of the two tables. Zero cells are tracked separately: every cell lies
     * on at least one top-left to bottom-right path, so a single zero anywhere
     * guarantees that a product of {@code 0} is achievable.
     *
     * <p>NOTE(review): products are accumulated in {@code long} without overflow
     * checks; this assumes the path products fit in 64 bits - confirm against the
     * input limits of the caller.
     *
     * @param grid non-empty rectangular matrix of integers
     * @return the maximum non-negative path product modulo 10^9 + 7, or {@code -1}
     *         if every path product is negative
     */
    public int maxProductPath(int[][] grid) {
        int m = grid.length;
        int n = grid[0].length;
        // pos[i][j] > 0: best positive product ending at (i, j); 0 if none.
        // neg[i][j] < 0: most negative product ending at (i, j); 0 if none.
        long[][] pos = new long[m][n];
        long[][] neg = new long[m][n];
        boolean zero = false;

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                int curr = grid[i][j];
                if (curr == 0) {
                    // Any path through a zero cell has product 0; both tables stay 0.
                    zero = true;
                    continue;
                }
                if (i == 0 && j == 0) {
                    // Seed the start cell: the only path is the cell itself.
                    if (curr > 0) {
                        pos[0][0] = curr;
                    } else {
                        neg[0][0] = curr;
                    }
                    continue;
                }

                // Extreme products among the (up to two) predecessors. The 0
                // sentinel is harmless here because multiplying it by curr
                // yields 0 again, i.e. "no such path".
                long hi = 0;
                long lo = 0;
                if (i > 0) {
                    hi = Math.max(hi, pos[i - 1][j]);
                    lo = Math.min(lo, neg[i - 1][j]);
                }
                if (j > 0) {
                    hi = Math.max(hi, pos[i][j - 1]);
                    lo = Math.min(lo, neg[i][j - 1]);
                }

                if (curr > 0) {
                    pos[i][j] = curr * hi;
                    neg[i][j] = curr * lo;
                } else {
                    // A negative factor flips signs, so the extremes swap roles.
                    pos[i][j] = curr * lo;
                    neg[i][j] = curr * hi;
                }
            }
        }

        long best = pos[m - 1][n - 1];
        // Decide on the sign of the true product BEFORE reducing it: a positive
        // product that happens to be a multiple of MOD has residue 0 and must
        // not be mistaken for "no positive path".
        if (best > 0) {
            return (int) (best % MOD);
        }
        return zero ? 0 : -1;
    }
}