Number of Ways to Wear Different Hats to Each Other

There are n people and 40 types of hats labeled from 1 to 40.

You are given a list of lists of integers, hats, where hats[i] is the list of all hats preferred by the i-th person.

Return the number of ways the n people can wear different hats: each person wears exactly one hat they prefer, and no two people wear the same hat.

Since the answer may be too large, return it modulo 10^9 + 7.

Example 1:

Input: hats = [[3,4],[4,5],[5]]
Output: 1
Explanation: There is only one way to choose hats given the conditions.
The first person chooses hat 3, the second person chooses hat 4, and the last person chooses hat 5.

Example 2:

Input: hats = [[3,5,1],[3,5]]
Output: 4
Explanation: There are 4 ways to choose hats: (3,5), (5,3), (1,3), and (1,5),
where each pair lists the hat of the first person and the hat of the second person.
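To make the counting rule concrete, here is a brute-force sketch that hands each person, in order, one preferred hat that nobody else is wearing yet, and counts the complete assignments. The class `BruteForceHats` and its `count` method are hypothetical names used only for illustration. The sketch assumes hat labels stay within 1 to 40, skips the modulo (the example answers are small), and is exponential, so it is only meant for checking small inputs like the examples here.

```java
import java.util.*;

// Hypothetical brute-force reference counter; exponential, for tiny inputs only.
class BruteForceHats {
    static long count(List<List<Integer>> hats) {
        // used[h] is true while hat h (label 1..40) is worn by someone.
        return assign(hats, 0, new boolean[41]);
    }

    // Give person i one preferred hat that is still free, then recurse.
    private static long assign(List<List<Integer>> hats, int i, boolean[] used) {
        if (i == hats.size()) return 1; // every person wears a distinct hat
        long ways = 0;
        for (int hat : hats.get(i)) {
            if (!used[hat]) {
                used[hat] = true;
                ways += assign(hats, i + 1, used);
                used[hat] = false;      // undo, then try the next hat
            }
        }
        return ways;
    }

    public static void main(String[] args) {
        // Example 2: prints 4
        System.out.println(count(List.of(List.of(3, 5, 1), List.of(3, 5))));
    }
}
```

Running it on Example 2 enumerates exactly the four pairs listed above.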

Example 3:

Input: hats = [[1,2,3,4],[1,2,3,4],[1,2,3,4],[1,2,3,4]]
Output: 24
Explanation: Each person can choose any hat labeled 1 to 4, so every permutation of (1,2,3,4) is a valid assignment.
There are 4! = 24 such permutations.
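As a general observation (not part of the original statement): when all n people prefer the same k hats, with k >= n, a valid assignment is an ordered choice of n distinct hats out of k, so the count is the number of n-permutations of k items. Example 3 is the case k = n = 4.

```latex
P(k, n) = \frac{k!}{(k-n)!}, \qquad P(4, 4) = \frac{4!}{0!} = 24
```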

Example 4:

Input: hats = [[1,2,3],[2,3,5,6],[1,3,7,9],[1,8,9],[2,5,7]]
Output: 111

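Constraints:

n == hats.length
1 <= n <= 10
1 <= hats[i].length <= 40
1 <= hats[i][j] <= 40
hats[i] contains a list of unique integers.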

Solution:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class Solution {
    private static final int MOD = 1_000_000_007;
    // memo[i][curr]: ways to finish once hats h.get(0..i-1) are decided and the
    // bitmask curr marks the people already wearing a hat; -1 = not computed yet.
    private int[][] memo;

    public int numberWays(List<List<Integer>> hats) {
        int n = hats.size();
        // Invert the input: map each hat label to the people who prefer it.
        Map<Integer, List<Integer>> map = new HashMap<>();
        for (int i = 0; i < n; i++) {
            for (int hat : hats.get(i)) {
                map.computeIfAbsent(hat, k -> new ArrayList<>()).add(i);
            }
        }
        // Only hats that at least one person likes need to be considered.
        List<Integer> h = new ArrayList<>(map.keySet());
        // Allocate a fresh memo per call so results never leak between inputs.
        memo = new int[h.size()][1 << n];
        for (int[] row : memo) Arrays.fill(row, -1);
        return helper(map, h, 0, 0, n);
    }

    // Decide hat hats.get(i): leave it unused, or give it to one eligible person.
    private int helper(Map<Integer, List<Integer>> map, List<Integer> hats, int i, int curr, int n) {
        if (Integer.bitCount(curr) == n) return 1; // everyone has a hat; the rest stay unused
        if (i == hats.size()) return 0;            // ran out of hats before everyone got one
        if (memo[i][curr] != -1) return memo[i][curr];
        int ways = helper(map, hats, i + 1, curr, n); // nobody wears hat i
        for (int p : map.get(hats.get(i))) {
            if ((curr & (1 << p)) == 0) {          // person p has no hat yet
                ways = (ways + helper(map, hats, i + 1, curr | (1 << p), n)) % MOD;
            }
        }
        return memo[i][curr] = ways;
    }
}
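The memoized recursion above iterates over hats and keeps a bitmask over people, which stays small because there are far fewer people than hat labels. The same recurrence can also be written bottom-up; the sketch below is a hypothetical iterative variant (the class name `HatsBottomUp` is illustrative), assuming hat labels 1 to 40, few enough people for a person mask to fit in an `int`, and no duplicate hats within a person's list. Here `dp[mask]` counts the ways to hand out the hats processed so far so that exactly the people in `mask` wear one.

```java
import java.util.*;

// Hypothetical bottom-up variant of the hat-by-hat bitmask DP.
class HatsBottomUp {
    static final int MOD = 1_000_000_007;

    static int numberWays(List<List<Integer>> hats) {
        int n = hats.size();
        // likedBy.get(h) holds the people who prefer hat h (labels 1..40).
        List<List<Integer>> likedBy = new ArrayList<>();
        for (int h = 0; h <= 40; h++) likedBy.add(new ArrayList<>());
        for (int p = 0; p < n; p++) {
            for (int h : hats.get(p)) likedBy.get(h).add(p);
        }
        // dp[mask]: ways for exactly the people in mask to wear distinct hats
        // chosen from the hats processed so far; the empty assignment counts once.
        long[] dp = new long[1 << n];
        dp[0] = 1;
        for (int h = 1; h <= 40; h++) {
            // Walk masks from high to low: each update reads a smaller mask that
            // has not been touched for this hat yet, so hat h is used at most once.
            for (int mask = (1 << n) - 1; mask > 0; mask--) {
                for (int p : likedBy.get(h)) {
                    if ((mask & (1 << p)) != 0) {
                        dp[mask] = (dp[mask] + dp[mask ^ (1 << p)]) % MOD;
                    }
                }
            }
        }
        return (int) dp[(1 << n) - 1];
    }

    public static void main(String[] args) {
        List<List<Integer>> hats = List.of(
            List.of(1, 2, 3), List.of(2, 3, 5, 6), List.of(1, 3, 7, 9),
            List.of(1, 8, 9), List.of(2, 5, 7));
        System.out.println(numberWays(hats)); // 111, matching Example 4
    }
}
```

The descending mask order plays the same role as the reverse loop in a 0/1 knapsack: it prevents a single hat from being counted for two people in the same assignment. The work is at most 40 × 2^n × n updates.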