Minimum Incompatibility

You are given an integer array nums and an integer k. You are asked to distribute this array into k subsets of equal size such that no two equal elements are in the same subset.

A subset's incompatibility is the difference between the maximum and minimum elements in that subset.

Return the minimum possible sum of incompatibilities of the k subsets after distributing the array optimally, or return -1 if it is not possible.

A subset is a group of integers that appear in the array, in no particular order.

Example 1:

Input: nums = [1,2,1,4], k = 2
Output: 4
Explanation: The optimal distribution of subsets is [1,2] and [1,4].
The incompatibility is (2-1) + (4-1) = 4.
Note that [1,1] and [2,4] would result in a smaller sum, but the first subset contains 2 equal elements.
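To make the rules concrete, here is a minimal sketch that scores a proposed distribution. The class DistributionCheck and the method totalIncompatibility are hypothetical names for illustration. It returns the sum of incompatibilities, or -1 if the distribution breaks a rule (wrong number of subsets, unequal sizes, or a repeated value inside one subset). It assumes k >= 1 and does not check that the subsets use exactly the elements of nums.

```java
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class DistributionCheck {
    // Hypothetical helper: total incompatibility of a proposed distribution,
    // or -1 if it violates the subset rules of the problem.
    static int totalIncompatibility(List<List<Integer>> subsets, int k) {
        if (k <= 0 || subsets.size() != k || subsets.get(0).isEmpty()) return -1;
        int size = subsets.get(0).size();
        int total = 0;
        for (List<Integer> subset : subsets) {
            if (subset.size() != size) return -1;      // subsets must have equal size
            Set<Integer> seen = new HashSet<>();
            int max = Integer.MIN_VALUE, min = Integer.MAX_VALUE;
            for (int v : subset) {
                if (!seen.add(v)) return -1;           // two equal elements in one subset
                max = Math.max(max, v);
                min = Math.min(min, v);
            }
            total += max - min;                         // this subset's incompatibility
        }
        return total;
    }

    public static void main(String[] args) {
        // Example 1: [1,2] and [1,4] -> (2-1) + (4-1)
        System.out.println(totalIncompatibility(List.of(List.of(1, 2), List.of(1, 4)), 2)); // 4
        // [1,1] repeats a value, so this distribution is rejected
        System.out.println(totalIncompatibility(List.of(List.of(1, 1), List.of(2, 4)), 2)); // -1
        // Example 2: [1,2], [2,3], [6,8], [1,3]
        System.out.println(totalIncompatibility(
            List.of(List.of(1, 2), List.of(2, 3), List.of(6, 8), List.of(1, 3)), 4)); // 6
    }
}
```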

Example 2:

Input: nums = [6,3,8,1,3,1,2,2], k = 4
Output: 6
Explanation: The optimal distribution of subsets is [1,2], [2,3], [6,8], and [1,3].
The incompatibility is (2-1) + (3-2) + (8-6) + (3-1) = 6.

Example 3:

Input: nums = [5,3,3,6,3,3], k = 3
Output: -1
Explanation: It is impossible to distribute nums into 3 subsets where no two elements are equal in the same subset.
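Example 3 fails for a counting reason: the value 3 appears 4 times, but there are only 3 subsets, so by the pigeonhole principle two copies must share a subset. When nums.length is divisible by k, the converse also holds: if every value appears at most k times, sort nums and put the element at sorted index t into subset t mod k. Equal values are adjacent and number at most k, so they land in different subsets, and every subset receives nums.length / k elements. A minimal sketch of this check follows; Feasibility and canDistribute are hypothetical names, and the code assumes nums.length % k == 0. It could serve as a cheap early exit before running a more expensive search.

```java
import java.util.HashMap;
import java.util.Map;

public class Feasibility {
    // Returns true if nums can be split into k equal-size subsets with no repeated
    // value inside any subset. Assumes nums.length % k == 0.
    static boolean canDistribute(int[] nums, int k) {
        Map<Integer, Integer> count = new HashMap<>();
        for (int v : nums) {
            // merge returns the updated count; more than k copies cannot be spread
            // over k subsets without two of them sharing one (pigeonhole principle)
            if (count.merge(v, 1, Integer::sum) > k) return false;
        }
        return true;
    }

    public static void main(String[] args) {
        System.out.println(canDistribute(new int[]{5, 3, 3, 6, 3, 3}, 3));       // false: 3 appears 4 times
        System.out.println(canDistribute(new int[]{6, 3, 8, 1, 3, 1, 2, 2}, 4)); // true
    }
}
```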

Constraints:

1 <= k <= nums.length <= 16
nums.length is divisible by k
1 <= nums[i] <= nums.length

Solution:

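import java.util.HashMap;
import java.util.Map;

class Solution {
    public int minimumIncompatibility(int[] nums, int k) {
        // m = size of each subset
        int n = nums.length, m = n / k;
        // subsetRes maps an index bitmask b to the incompatibility of the subset it
        // selects, for every b that picks exactly m pairwise-distinct values
        Map<Integer, Integer> subsetRes = calc(nums, m);
        // dp[mask] = minimum total incompatibility for splitting the elements in mask
        // into valid subsets of size m, or null if no such split exists
        Integer[] dp = new Integer[1 << n];
        dp[0] = 0;
        for (int i = 1; i < (1 << n); i++) {
            // only masks whose size is a multiple of m can be split into whole subsets
            if (Integer.bitCount(i) % m != 0) continue;
            // try every non-empty submask j of i as one of the subsets
            for (int j = i; j > 0; j = i & (j - 1)) {
                Integer cost = subsetRes.get(j);
                if (cost == null) continue;          // j is not a valid subset
                Integer rest = dp[i ^ j];            // elements of i outside j
                if (rest == null) continue;          // the remainder cannot be split
                int candidate = cost + rest;
                if (dp[i] == null || candidate < dp[i]) dp[i] = candidate;
            }
        }
        Integer best = dp[(1 << n) - 1];
        return best == null ? -1 : best;
    }

    // Computes the incompatibility of every index bitmask of size m whose selected
    // values are pairwise distinct. Assumes 0 <= nums[i] < 32 so that each value
    // fits in one bit of the int "set".
    private Map<Integer, Integer> calc(int[] nums, int m) {
        int n = nums.length;
        Map<Integer, Integer> res = new HashMap<>();
        for (int i = 1; i < (1 << n); i++) {
            if (Integer.bitCount(i) != m) continue;
            int max = Integer.MIN_VALUE;
            int min = Integer.MAX_VALUE;
            int set = 0;                 // bit v is set once value v has been seen
            boolean noRepeat = true;
            for (int j = 0; j < n; j++) {
                if ((i & (1 << j)) != 0) {
                    int val = nums[j];
                    if ((set & (1 << val)) != 0) {   // repeated value in this subset
                        noRepeat = false;
                        break;
                    }
                    max = Math.max(max, val);
                    min = Math.min(min, val);
                    set |= (1 << val);
                }
            }
            if (noRepeat) res.put(i, max - min);
        }
        return res;
    }
}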
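The inner DP loop, for (int j = i; j > 0; j = i & (j - 1)), is the standard submask-enumeration idiom. Subtracting 1 clears the lowest set bit of j and sets every bit below it; masking with i keeps only bits that belong to i, which yields the next smaller submask. Each non-empty submask of i is visited exactly once, in decreasing order. Summed over all masks of n bits this gives at most 3^n iterations, about 43 million for n = 16, which is why the approach is practical only for small arrays. The sketch below isolates the idiom; SubmaskDemo and submasks are hypothetical names.

```java
import java.util.ArrayList;
import java.util.List;

public class SubmaskDemo {
    // Lists every non-empty submask of mask, from largest to smallest,
    // using the same update as the DP loop: sub = mask & (sub - 1).
    static List<Integer> submasks(int mask) {
        List<Integer> out = new ArrayList<>();
        for (int sub = mask; sub > 0; sub = mask & (sub - 1)) {
            out.add(sub);
        }
        return out;
    }

    public static void main(String[] args) {
        System.out.println(submasks(0b1011)); // [11, 10, 9, 8, 3, 2, 1]
    }
}
```

A mask with three set bits has 2^3 - 1 = 7 non-empty submasks, matching the seven values printed.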