Sorted Permutation Rank with Repeats

Given a string, find the rank of the string amongst its permutations sorted lexicographically. 
Note that characters may be repeated. In that case, the rank is taken among the unique (distinct) permutations only.
Look at the example for more details.

Example:

Input : 'aba'
Output : 2

The distinct permutations of the letters 'a', 'a', and 'b', in sorted order:
aab
aba
baa
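For small inputs, the definition can be checked directly by brute force. The sketch below is not the intended solution (it is exponential in the string length). It generates every arrangement, lets a TreeSet drop duplicates and keep them sorted, and counts how many distinct permutations come before the input. The class and method names are illustrative only.

```java
import java.util.TreeSet;

class BruteForceRank {
    // Adds every arrangement of `rest` (appended to `prefix`) to `out`;
    // the TreeSet discards duplicates and keeps the strings sorted.
    private static void permute(String prefix, String rest, TreeSet<String> out) {
        if (rest.isEmpty()) {
            out.add(prefix);
            return;
        }
        for (int i = 0; i < rest.length(); i++) {
            permute(prefix + rest.charAt(i),
                    rest.substring(0, i) + rest.substring(i + 1), out);
        }
    }

    // 1-based rank of s among its distinct permutations in lexicographic order.
    static int bruteForceRank(String s) {
        TreeSet<String> perms = new TreeSet<>();
        permute("", s, perms);
        // headSet(s) holds the permutations strictly smaller than s
        return perms.headSet(s).size() + 1;
    }

    public static void main(String[] args) {
        System.out.println(bruteForceRank("aba")); // prints 2
    }
}
```

This is handy as a reference when testing a faster implementation on short random strings.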

The answer might not fit in an integer, so return it modulo 1000003.

NOTE: 1000003 is a prime number
NOTE: Assume the length of the string is less than 1000003.
Approach:

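This is the follow-up to Sorted Permutation Rank (no repeated characters). There, n! counted the permutations; now, because characters repeat, the duplicate arrangements must be divided out. The number of distinct permutations of n characters is:
n! / (a! * b! * c! ...)
where a, b, c, ... are the number of occurrences of each distinct character. To get the rank, scan the string from left to right: at position i, every distinct remaining character smaller than A[i] contributes the number of distinct arrangements of the remaining characters with that smaller character fixed at position i, i.e. (n - i - 1)! divided by the factorials of the updated counts. The rank is 1 plus the sum of all these contributions.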
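As a hand check, here is the example 'aba' (n = 3) worked through position by position. At each position, each smaller remaining character d is fixed in place, and the remaining characters are counted with n!/(a! * b! ...) applied to (n - i - 1) characters with d's count reduced by one:

```latex
\begin{aligned}
i=0:&\ \text{char } a,\ \text{remaining } \{a{:}2,\ b{:}1\},\ \text{no smaller char} &&\Rightarrow 0\\
i=1:&\ \text{char } b,\ \text{remaining } \{a{:}1,\ b{:}1\},\ \text{fix } a:\ \tfrac{1!}{0!\,1!} = 1 &&\Rightarrow 1\\
i=2:&\ \text{char } a,\ \text{remaining } \{a{:}1\},\ \text{no smaller char} &&\Rightarrow 0\\
\text{rank} &= 1 + 0 + 1 + 0 = 2
\end{aligned}
```

This matches the expected output of 2.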
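Also note that because the result is taken modulo MOD, we cannot divide directly. Since MOD is prime, division is replaced by multiplication with the modular inverse (Fermat's little theorem): (1/A) % MOD = A ^ (MOD - 2) % MOD. The inverse always exists here: every count is less than MOD because the string is shorter than MOD, so no factorial in the denominator is divisible by MOD.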
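The sketch below shows the modular-inverse trick in isolation: it computes n! / (a! * b! * ...) modulo the prime 1000003 by multiplying n! with the Fermat inverse of the denominator. The class and method names are illustrative, and counting characters in a `Character.MAX_VALUE + 1` array is just one simple choice.

```java
class DistinctPermutations {
    static final int MOD = 1000003; // prime, as stated in the problem

    // x^y % MOD by binary (fast) exponentiation
    static long modPow(long x, long y) {
        long res = 1;
        x %= MOD;
        while (y > 0) {
            if ((y & 1) == 1) res = res * x % MOD;
            x = x * x % MOD;
            y >>= 1;
        }
        return res;
    }

    // Fermat's little theorem: a^(MOD-2) is the inverse of a when MOD does not divide a
    static long modInverse(long a) {
        return modPow(a, MOD - 2);
    }

    // Number of distinct permutations of s, n! / (a! * b! * ...), modulo MOD
    static long countDistinct(String s) {
        int n = s.length();
        long[] fact = new long[n + 1];
        fact[0] = 1;
        for (int i = 1; i <= n; i++) fact[i] = fact[i - 1] * i % MOD;

        int[] counts = new int[Character.MAX_VALUE + 1];
        for (char ch : s.toCharArray()) counts[ch]++;

        // Product of the factorials of all counts (absent characters contribute 0! = 1)
        long denom = 1;
        for (int c : counts) denom = denom * fact[c] % MOD;
        return fact[n] * modInverse(denom) % MOD;
    }

    public static void main(String[] args) {
        System.out.println(countDistinct("aab"));    // prints 3
        System.out.println(modInverse(2) * 2 % MOD); // prints 1
    }
}
```

For example, "mississippi" gives 11! / (4! * 4! * 2! * 1!) = 34650, which is already below MOD, so the modular result equals the exact count.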

Solution:

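Time: O(n * k * (k + log MOD)), where k is the number of distinct characters; since k <= n, this is O(n^3) in the worst case.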
Space: O(n)

import java.util.Map;
import java.util.TreeMap;

public class Solution {
    private static final int MOD = 1000003;

    // fact[i] = i! % MOD for 0 <= i <= A
    private long[] fact(int A) {
        long[] fact = new long[A + 1];
        fact[0] = 1;
        for (int i = 1; i <= A; i++) {
            // long arithmetic: fact[i - 1] * i can exceed the int range
            fact[i] = (fact[i - 1] * i) % MOD;
        }
        return fact;
    }

    // x^y % MOD by binary (fast) exponentiation
    private long modExp(long x, long y) {
        long res = 1;
        x %= MOD;
        while (y > 0) {
            if ((y & 1) == 1) {
                res = (res * x) % MOD;
            }
            x = (x * x) % MOD;
            y >>= 1;
        }
        return res;
    }

    public int findRank(String A) {
        char[] arr = A.toCharArray();
        long[] fact = fact(arr.length);

        // Counts of the characters not yet placed, kept in sorted order
        TreeMap<Integer, Integer> map = new TreeMap<>();
        for (char ch : arr) {
            int c = ch;
            map.put(c, map.getOrDefault(c, 0) + 1);
        }

        long rank = 1;
        for (int i = 0; i < arr.length; i++) {
            int c = arr[i];
            // Each distinct remaining character smaller than arr[i] could sit at position i
            for (int key : map.headMap(c).keySet()) {
                // Product of the factorials of the remaining counts, with key's count reduced by 1
                long redundancy = 1;
                for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
                    int count = entry.getValue();
                    if (entry.getKey() == key) count--;
                    redundancy = (redundancy * fact[count]) % MOD;
                }
                // Divide by redundancy via its modular inverse: (1/A) % MOD = A ^ (MOD - 2) % MOD
                rank = (rank + fact[arr.length - i - 1] * modExp(redundancy, MOD - 2)) % MOD;
            }

            // arr[i] is now placed: remove it from the remaining multiset
            map.put(c, map.get(c) - 1);
            if (map.get(c) == 0) map.remove(c);
        }
        return (int) rank;
    }
}