Making A Large Island

In a 2D grid of 0s and 1s, we change at most one 0 to a 1.

Afterwards, what is the size of the largest island? (An island is a 4-directionally connected group of 1s.)

Example 1:

Input: [[1, 0], [0, 1]]
Output: 3
Explanation: Change one 0 to 1 and connect two 1s, then we get an island with area = 3.

Example 2:

Input: [[1, 1], [1, 0]]
Output: 4
Explanation: Change the 0 to 1 and make the island bigger, only one island with area = 4.
Example 3:

Input: [[1, 1], [1, 1]]
Output: 4
Explanation: Can't change any 0 to 1, only one island with area = 4.
 

Notes:


Solution:

class Solution {
    static class UF {
        int[] parent;
        int[] size;
        
        public UF(int n) {
            parent = new int[n];
            size = new int[n];
            for (int i = 0; i < n; i ++) {
                parent[i] = i;
                size[i] = 1;
            }
        }
        
        public int union(int x, int y) {
            int rootX = find(x);
            int rootY = find(y);
            if (rootX == rootY) return 0;
            if (size[rootX] >= size[rootY]) {
                parent[rootY] = rootX;
                size[rootX] += size[rootY];
                return size[rootX];
            } else {
                parent[rootX] = rootY;
                size[rootY] += size[rootX];
                return size[rootY];
            }
        }
        
        private int find(int x) {
            while (parent[x] != x) {
                parent[x] = parent[parent[x]];
                x = parent[x];
            }
            return x;
        }
    }
    
    int[] dx = new int[]{-1, 0, 0, 1};
    int[] dy = new int[]{0, -1, 1, 0};
    
    public int largestIsland(int[][] grid) {
        int m = grid.length, n = grid[0].length;
        UF islands = new UF(m * n);
        int res = 1;
        for (int i = 0; i < m; i ++) {
            for (int j = 0; j < n; j ++) {
                if (grid[i][j] == 1) {
                    int curr = i * n + j;
                    for (int k = 0; k < 4; k ++) {
                        int nx = i + dx[k];
                        int ny = j + dy[k];
                        if (nx >= 0 && nx < m && ny >= 0 && ny < n && grid[nx][ny] == 1) {
                            int next = nx * n + ny;
                            res = Math.max(res, islands.union(curr, next));
                        }
                    }
                }
            }
        }
        for (int i = 0; i < m; i ++) {
            for (int j = 0; j < n; j ++) {
                if (grid[i][j] == 0) {
                    int curr = i * n + j;
                    UF copy = new UF(m * n);
                    copy.parent = islands.parent.clone();
                    copy.size = islands.size.clone();
                    for (int k = 0; k < 4; k ++) {
                        int nx = i + dx[k];
                        int ny = j + dy[k];
                        if (nx >= 0 && nx < m && ny >= 0 && ny < n && grid[nx][ny] == 1) {
                            int next = nx * n + ny;
                            res = Math.max(res, copy.union(curr, next));
                        }
                    }
                }
            }
        }
        return res;
    }
}