Matrix Median

Given an N x M matrix in which each row is sorted, find the overall median of the matrix. Assume N*M is odd.

For example,

Matrix=
[1, 3, 5]
[2, 6, 9]
[3, 6, 9]

All elements in sorted order: A = [1, 2, 3, 3, 5, 6, 6, 9, 9]

The median is 5, so we return 5.

Note: No extra memory is allowed.

Approach:

For an array, define a function F(x) = the number of elements less than or equal to x.
We want to find the smallest number x such that F(x) >= (N * M + 1) / 2.
In the example,
[1, 2, 3, 3, 5, 6, 6, 9, 9]
we want the smallest x with F(x) >= 5. We can see that F(3) = 4, F(4) = 4 and F(5) = 5, hence the median is 5.
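To make the definition concrete, here is a small illustration (not part of the original solution) that evaluates F directly on the flattened, sorted example array and scans x upward until F(x) reaches the target. The class and method names are hypothetical, and the linear scan is only for demonstration; it is far too slow for the real value range.

```java
// Illustration only: computes F(x) on the flattened, sorted example array
// and scans x upward to find the smallest x with F(x) >= (N * M + 1) / 2.
class CountFunctionDemo {
    // F(x): number of elements less than or equal to x (linear scan).
    static int countLessOrEqual(int[] values, int x) {
        int count = 0;
        for (int v : values) {
            if (v <= x) {
                count++;
            }
        }
        return count;
    }

    // Smallest x in [lo, hi] with F(x) >= target; assumes such an x exists.
    static int smallestWithCountAtLeast(int[] values, int target, int lo, int hi) {
        for (int x = lo; x <= hi; x++) {
            if (countLessOrEqual(values, x) >= target) {
                return x;
            }
        }
        throw new IllegalArgumentException("no x in range satisfies F(x) >= target");
    }

    public static void main(String[] args) {
        int[] a = {1, 2, 3, 3, 5, 6, 6, 9, 9};
        int target = (a.length + 1) / 2; // 5
        for (int x = 2; x <= 5; x++) {
            System.out.println("F(" + x + ") = " + countLessOrEqual(a, x));
        }
        System.out.println("median = " + smallestWithCountAtLeast(a, target, 0, 9));
    }
}
```

Running main prints F(2) = 2, F(3) = 4, F(4) = 4, F(5) = 5 and then median = 5. The binary search below replaces the linear scan over x.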

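Therefore we can binary search over the value range 0 to Integer.MAX_VALUE and find the desired x. (This range assumes all entries are non-negative. Alternatively, we can take the minimum of the first column and the maximum of the last column as the bounds; since each row is sorted this costs only O(N), and it doesn't change the complexity much.)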
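The min/max alternative can be sketched as follows. This is a hypothetical variant, not the solution given below: it assumes the matrix arrives as a non-empty int[][] with equal-length, row-wise sorted rows and an odd element count, and it bounds the search by the matrix's own minimum and maximum, which also makes it work for negative entries. The midpoint is computed in long arithmetic so that a range like Integer.MIN_VALUE to Integer.MAX_VALUE cannot overflow.

```java
// Sketch: binary search on values, bounded by the matrix's own min and max.
// Assumes a non-empty int[][] with equal-length, row-wise sorted rows and an
// odd total element count. Handles negative values as well.
class BoundedMatrixMedian {
    static int median(int[][] matrix) {
        int rows = matrix.length;
        int cols = matrix[0].length;
        // Each row is sorted, so the global min is in the first column
        // and the global max is in the last column: O(rows) to find both.
        int lo = matrix[0][0];
        int hi = matrix[0][cols - 1];
        for (int[] row : matrix) {
            lo = Math.min(lo, row[0]);
            hi = Math.max(hi, row[cols - 1]);
        }
        int target = (rows * cols + 1) / 2;
        // Invariant: the answer lies in [lo, hi], since F(max) = rows * cols >= target.
        while (lo < hi) {
            // Floor of the average, computed in long to avoid int overflow.
            int mid = (int) (((long) lo + hi) >> 1);
            if (countLessOrEqual(matrix, mid) >= target) {
                hi = mid;
            } else {
                lo = mid + 1;
            }
        }
        return lo;
    }

    // F(x): counts elements <= x, using one upper-bound binary search per row.
    static int countLessOrEqual(int[][] matrix, int x) {
        int count = 0;
        for (int[] row : matrix) {
            int left = 0;
            int right = row.length; // half-open search range [left, right)
            while (left < right) {
                int m = left + (right - left) / 2;
                if (row[m] <= x) {
                    left = m + 1;
                } else {
                    right = m;
                }
            }
            count += left; // first index with row[left] > x == count of elements <= x
        }
        return count;
    }
}
```

With this variant the number of outer iterations is about log2(max - min) rather than a fixed ~32, and each iteration still costs O(N log M).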

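Note that the x found this way is guaranteed to be an element of the matrix. Because x is the smallest value with F(x) >= (N * M + 1) / 2, we have F(x - 1) < F(x); otherwise x - 1 would also satisfy the condition and we would have returned it instead. For integers, F(x) - F(x - 1) equals the number of occurrences of x, so F(x - 1) != F(x) holds only when x is in the matrix.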

Again, for example, in
[1, 2, 3, 3, 5, 6, 6, 9, 9]
F(2) = 2, F(3) = 4, F(4) = 4, F(5) = 5.
We see F(2) != F(3) and F(4) != F(5) because 3 and 5 are in the array, while F(3) = F(4) because 4 is not. In general, if x is not in the array, F(x - 1) = F(x).
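The claim can be checked mechanically on the example. The following illustration (hypothetical names, linear scans only for clarity) tests, for every x in a small range, that F jumps at x exactly when x occurs in the array.

```java
// Illustration on the example array: checks that F(x - 1) != F(x)
// exactly when x occurs in the array, for every x in a small range.
class StepDemo {
    // F(x): number of elements <= x (linear scan is enough for a demo).
    static int countLessOrEqual(int[] values, int x) {
        int count = 0;
        for (int v : values) {
            if (v <= x) {
                count++;
            }
        }
        return count;
    }

    static boolean contains(int[] values, int x) {
        for (int v : values) {
            if (v == x) {
                return true;
            }
        }
        return false;
    }

    // True if, for every x in [lo, hi], F jumps at x iff x is present.
    static boolean stepsMatchMembership(int[] values, int lo, int hi) {
        for (int x = lo; x <= hi; x++) {
            boolean jumps = countLessOrEqual(values, x - 1) != countLessOrEqual(values, x);
            if (jumps != contains(values, x)) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        int[] a = {1, 2, 3, 3, 5, 6, 6, 9, 9};
        for (int x = 2; x <= 5; x++) {
            System.out.println("F(" + (x - 1) + ") = " + countLessOrEqual(a, x - 1)
                    + ", F(" + x + ") = " + countLessOrEqual(a, x)
                    + ", present: " + contains(a, x));
        }
        System.out.println(stepsMatchMembership(a, 0, 10));
    }
}
```

On the example array the final line prints true.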

Solution:

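Time: O(32 * m log n), where m is the number of rows and n the number of columns (the value range 0 to Integer.MAX_VALUE takes at most 32 binary-search steps, each doing one binary search per row)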
Space: O(1)

import java.util.ArrayList;

public class Solution {
    public int findMedian(ArrayList<ArrayList<Integer>> A) {
        int m = A.size();
        int n = A.get(0).size();
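        // The median is the smallest x with F(x) >= (m * n + 1) / 2.
        int targetSize = (m * n + 1) / 2;
        // Binary search over values, not indices; assumes all entries are non-negative.
        int left = 0;
        int right = Integer.MAX_VALUE;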
        while (left <= right) {
            int current = left + (right - left) / 2;
            int smallerSize = findSmallerSize(A, current, m, n);
            if (smallerSize >= targetSize) {
                right = current - 1;
            } else {
                left = current + 1;
            }
        }
        // left is now the smallest value with F(left) >= targetSize, i.e. the median
        return left;
    }
    
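    // Returns F(target): the number of elements <= target in the whole matrix.
    private int findSmallerSize(ArrayList<ArrayList<Integer>> A, int target, int m, int n) {
        int smallerSize = 0;
        // For each row, binary search for the first index whose value is > target.
        // That index equals the number of elements <= target in the row, e.g.
        // row [3, 3, 5]: target 2 -> index 0, target 3 -> index 2, target 5 -> index 3.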
        for (int i = 0; i < m; i ++) {
            ArrayList<Integer> arr = A.get(i);
            int left = 0;
            int right = n - 1;
            while (left <= right) {
                int mid = left + (right - left) / 2;
                if (arr.get(mid) <= target) {
                    left = mid + 1;
                } else {
                    right = mid - 1;
                }
            }
            // left is the first index with arr.get(left) > target
            smallerSize += left;
        }
        return smallerSize;
    }
}
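findMedian takes its input as ArrayList<ArrayList<Integer>>, so trying it on the example needs a conversion step. A minimal sketch of such a helper follows; the class name MatrixInput and the method toNestedList are hypothetical and not part of the original solution.

```java
import java.util.ArrayList;

// Hypothetical helper (not part of the original solution): converts an int[][]
// into the ArrayList<ArrayList<Integer>> shape that findMedian expects.
class MatrixInput {
    static ArrayList<ArrayList<Integer>> toNestedList(int[][] rows) {
        ArrayList<ArrayList<Integer>> result = new ArrayList<>();
        for (int[] row : rows) {
            ArrayList<Integer> list = new ArrayList<>(row.length);
            for (int value : row) {
                list.add(value); // autoboxed to Integer
            }
            result.add(list);
        }
        return result;
    }
}
```

For example, new Solution().findMedian(MatrixInput.toNestedList(new int[][] {{1, 3, 5}, {2, 6, 9}, {3, 6, 9}})) should return 5 for the example matrix.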