INVERSIONS

Given an array A, count the number of inversions in the array.

Formally, two elements A[i] and A[j] form an inversion if i < j and A[i] > A[j].

Example:

A : [2, 4, 1, 3, 5]
Output : 3

as the 3 inversions are (2, 1), (4, 1), (4, 3).

Method1:

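Use a TreeMap: scan the array from right to left, keeping a count of each value seen so far. For each element, add up the counts of all strictly smaller values already in the map; those elements lie to its right, so each one is an inversion.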

Solution1:

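Time: O(n^2) worst case (for each element, the smaller entries of the headMap view are summed one by one: O(log n + k) per element, where k is the number of distinct smaller values seen)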
Space: O(n)

public class Solution {
    /**
     * Counts pairs (i, j) with i < j and A[i] > A[j].
     *
     * <p>Walks the list from the last element to the first while maintaining a sorted
     * multiset of the values already visited (value -> number of occurrences). Every
     * visited value that is strictly smaller than the current one sits to its right,
     * so each such occurrence is one inversion. Worst case O(n^2), because the smaller
     * entries are summed one at a time.
     *
     * @param A the values to examine
     * @return the number of inversions, accumulated in an {@code int}
     */
    public int countInversions(ArrayList<Integer> A) {
        TreeMap<Integer, Integer> seen = new TreeMap<>();
        int inversions = 0;
        int idx = A.size();
        while (--idx >= 0) {
            int current = A.get(idx);
            // headMap is exclusive, so equal values are not counted as inversions.
            for (int occurrences : seen.headMap(current).values()) {
                inversions += occurrences;
            }
            seen.merge(current, 1, Integer::sum);
        }
        return inversions;
    }
}

Method2:

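Use a Binary Indexed Tree (Fenwick tree): compress the values to ranks, then scan from right to left. For each element, query how many already-seen elements (those to its right) have a strictly smaller rank, then record the element's own rank in the tree.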

Solution2:

Time: O(nlogn)
Space: O(n)

public class Solution {
    static class Bit {
        public int[] arr;
       
        public Bit(int n) {
            this.arr = new int[n + 1];
        }
       
        public void update(int index, int v) {
            index += 1;
            // add v to index and left trees that contain it
            while (index < this.arr.length) {
                this.arr[index] += v;
                index += index & (-index);
            }
        }
       
        public int getSum(int index) {
            // sum up index and sub trees smaller than it
            index += 1;
            int sum = 0;
            while (index > 0) {
                sum += this.arr[index];
                index -= index & (-index);
            }
            return sum;
        }
    }
   
    public int countInversions(ArrayList<Integer> A) {
        ArrayList<Integer> list = new ArrayList<>(A);
        Collections.sort(list);
        for (int i = 0; i < A.size(); i ++) {
            int val = A.get(i);
            int newVal = Collections.binarySearch(list, val) + 1;
            A.set(i, newVal);
        }
        Bit bit = new Bit(A.size());
        int sum = 0;
        for (int i = A.size() - 1; i >= 0; i --) {
            sum += bit.getSum(A.get(i) - 1);
            bit.update(A.get(i), 1);
            // System.out.println(Arrays.toString(bit.arr));
        }
        return sum;
    }
}