Equal Average Partition

Given an array with non negative numbers, divide the array into two parts such that the average of both the parts is equal.
Return both parts (If exist).
If there is no solution. return an empty list.

Example:

Input:
[1 7 15 29 11 9]

Output:
[9 15] [1 7 11 29]

The average of part is (15+9)/2 = 12,
average of second part elements is (1 + 7 + 11 + 29) / 4 = 12


NOTE 1: If a solution exists, you should return a list of exactly 2 lists of integers A and B which follow the following condition :NOTE 2: If multiple solutions exist, return the solution where length(A) is minimum. If there is still a tie, return the one where A is lexicographically smallest. NOTE 3: Array will contain only non negative numbers.
Solution:

public class Solution {
    public int[][] avgset(int[] A) {
        int n = A.length;
        int total = 0;
        for (int i : A) total += i;
        double avg = (double) total / (double) n;
        // dp[i][j] = the size i of left part can sum to j
        // dp[i][j] = dp[i - 1][j - A[k]] for all element k
        // dp[0][0] = true
        // dp[0][j] = false
        boolean[][] dp = new boolean[n + 1][total + 1];
        Set<Integer>[][] t = new HashSet[n + 1][total + 1];
        dp[0][0] = true;
        t[0][0] = new HashSet<>();
        for (int j = 1; j <= total; j ++) {
            dp[0][j] = false;
            t[0][j] = new HashSet<>();
        }
        // System.out.println(Arrays.toString(A));
        for (int i = 1; i <= n; i ++) {
            for (int j = 0; j <= total; j ++) {
                dp[i][j] = false;
                t[i][j] = new HashSet<>();
                for (int k = 0; k < n; k ++) {
                    if (j - A[k] >= 0 && dp[i - 1][j - A[k]] && !t[i - 1][j - A[k]].contains(k)) {
                        dp[i][j] = true;
                        if (t[i][j].isEmpty()) {
                            t[i][j].addAll(t[i - 1][j - A[k]]);
                            t[i][j].add(k);
                        } else {
                            int[] curLeft = getLeft(i, t[i][j], A);
                            Set<Integer> set = new HashSet<>(t[i - 1][j - A[k]]);
                            set.add(k);
                            int[] nextLeft = getLeft(i, set, A);
                            if (smaller(nextLeft, curLeft)) {
                                t[i][j] = set;
                            }
                        }
                    }
                }
            }
        }

        int[] left = null;
        for (int j = 0; j <= total; j ++) {
            int leftSum = j;
            int rightSum = total - j;
            for (int i = 0; i <= n; i ++) {
                if (dp[i][j]) {
                    double leftAvg = (double) leftSum / (double) i;
                    double rightAvg = (double) (total - leftSum) / (double) (n - i);
                    if (Double.compare(leftAvg, rightAvg) == 0) {
                        int[] curLeft = getLeft(i, t[i][j], A);
                        if (left == null || smaller(curLeft, left)) {
                            left = curLeft;
                        }
                    }
                }
            }
        }
        if (left == null) {
            return new int[0][];
        }
        Map<Integer, Integer> leftMap = new HashMap<>();
        for (int i : left) {
            leftMap.put(i, leftMap.getOrDefault(i, 0) + 1);
        }
        int[] right = new int[n - left.length];
        int index = 0;
        for (int i : A) {
            if (leftMap.containsKey(i)) {
                leftMap.put(i, leftMap.get(i) - 1);
                if (leftMap.get(i) == 0) leftMap.remove(i);
            } else {
                right[index++] = i;
            }
        }
        Arrays.sort(right);
        int[][] result = new int[2][];
        result[0] = left;
        result[1] = right;
        return result;
    }

    private boolean smaller(int[] a, int[] b) {
        if (a.length < b.length) {
            return true;
        } else if (a.length > b.length) {
            return false;
        } else {
            for (int i = 0; i < a.length; i ++) {
                int aVal = a[i];
                int bVal = b[i];
                if (aVal < bVal) {
                    return true;
                } else if (aVal > bVal) {
                    return false;
                }
            }
        }
        return false;
    }
    
    private int[] getLeft(int len, Set<Integer> t, int[] A) {
        int[] left = new int[len];
        Set<Integer> set = t;
        int i = 0;
        for (int index : set) {
            left[i ++] = A[index];
        }
        Arrays.sort(left);
        return left;
    }
}