Random Pick with Weight

Given an array w of positive integers, where w[i] describes the weight of index i (0-indexed), write a function pickIndex which randomly picks an index in proportion to its weight.

For example, given the input w = [2, 8], pickIndex should return index 1 with probability 8 / (2 + 8) = 80%, i.e. about 8 times out of 10, because index 1 is the second element of the array and has weight w[1] = 8. It should return index 0 the remaining 20% of the time.
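The rule "in proportion to its weight" means index i is picked with probability w[i] / (w[0] + ... + w[n-1]). The sketch below computes those target probabilities; the class WeightProbabilities and its method are hypothetical helpers for illustration, not part of the original solution.

```java
// Hypothetical helper (not from the original solution): computes the
// probability with which pickIndex should return each index.
class WeightProbabilities {
    // Returns p[i] = w[i] / (w[0] + ... + w[n-1]).
    static double[] probabilities(int[] w) {
        long total = 0; // long avoids overflow when summing many large weights
        for (int x : w) {
            total += x;
        }
        double[] p = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            p[i] = (double) w[i] / total;
        }
        return p;
    }
}
```

For w = [2, 8] this gives [0.2, 0.8], and for w = [1, 3] it gives [0.25, 0.75].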


Example 1:

Input
["Solution","pickIndex"]
[[[1]],[]]
Output
[null,0]

Explanation
Solution solution = new Solution([1]);
solution.pickIndex(); // return 0. Since there is only one element in the array, the only option is to return index 0.

Example 2:

Input
["Solution","pickIndex","pickIndex","pickIndex","pickIndex","pickIndex"]
[[[1,3]],[],[],[],[],[]]
Output
[null,1,1,1,1,0]

Explanation
Solution solution = new Solution([1, 3]);
solution.pickIndex(); // return 1. It returns the second element (index 1), which has probability 3/4.
solution.pickIndex(); // return 1
solution.pickIndex(); // return 1
solution.pickIndex(); // return 1
solution.pickIndex(); // return 0. It returns the first element (index 0), which has probability 1/4.
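To see where the probabilities in both examples come from, the sketch below enumerates every possible draw t from 1 to the total weight and records which index the prefix-sum rule used in the Solution class assigns to it (the first index whose prefix sum is >= t). The class TargetMapping is hypothetical, and it uses a linear scan instead of binary search purely for readability.

```java
// Hypothetical illustration (not from the original solution): shows which
// index each possible draw maps to under the prefix-sum rule.
class TargetMapping {
    // prefix[i] = w[0] + ... + w[i]
    static int[] prefixSums(int[] w) {
        int[] prefix = new int[w.length];
        prefix[0] = w[0];
        for (int i = 1; i < w.length; i++) {
            prefix[i] = prefix[i - 1] + w[i];
        }
        return prefix;
    }

    // result[t - 1] is the index chosen when the random draw equals t,
    // for every t in 1..total. A linear scan finds the first prefix sum >= t.
    static int[] indexForEachTarget(int[] w) {
        int[] prefix = prefixSums(w);
        int total = prefix[prefix.length - 1];
        int[] result = new int[total];
        for (int t = 1; t <= total; t++) {
            int i = 0;
            while (prefix[i] < t) {
                i++;
            }
            result[t - 1] = i;
        }
        return result;
    }
}
```

For w = [1] the single draw maps to index 0, matching Example 1. For w = [1, 3] the four equally likely draws map to [0, 1, 1, 1], so index 0 is chosen with probability 1/4 and index 1 with probability 3/4, matching Example 2.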

Since this is a randomization problem, multiple answers are allowed, so all of the following outputs can be considered correct:
[null,1,1,1,1,0]
[null,1,1,1,1,1]
[null,1,1,1,0,0]
[null,1,1,1,0,1]
[null,1,0,1,0,0]
...
and so on.
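Because many outputs are acceptable, correctness is really a statement about long-run frequencies rather than any particular sequence. The sketch below (a hypothetical FrequencyCheck class) draws many picks using the same prefix-sum and binary-search idea as the solution and counts how often each index comes up. It assumes a seeded java.util.Random so runs are repeatable; the original solution uses ThreadLocalRandom instead.

```java
import java.util.Random;

// Hypothetical frequency check (not from the original solution).
class FrequencyCheck {
    // Counts how often each index is picked over `trials` draws.
    static int[] countPicks(int[] w, int trials, long seed) {
        // prefix[i] = w[0] + ... + w[i]
        int[] prefix = new int[w.length];
        prefix[0] = w[0];
        for (int i = 1; i < w.length; i++) {
            prefix[i] = prefix[i - 1] + w[i];
        }
        int total = prefix[w.length - 1];

        Random rand = new Random(seed); // seeded so the counts are reproducible
        int[] counts = new int[w.length];
        for (int k = 0; k < trials; k++) {
            int target = rand.nextInt(total) + 1; // uniform in 1..total
            // Lower-bound binary search: first index with prefix[index] >= target.
            int lo = 0;
            int hi = w.length - 1;
            while (lo < hi) {
                int mid = lo + (hi - lo) / 2;
                if (prefix[mid] < target) {
                    lo = mid + 1;
                } else {
                    hi = mid;
                }
            }
            counts[lo]++;
        }
        return counts;
    }
}
```

With w = [1, 3] and 100,000 trials, index 1 should account for roughly 75% of the picks, even though any short run such as [1, 1, 1, 0, 0] remains a valid output.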


Constraints:


Solution:

import java.util.concurrent.ThreadLocalRandom;

class Solution {
    // weight[i] holds the prefix sum w[0] + ... + w[i]. Because every weight
    // is positive, this array is strictly increasing.
    private int[] weight;

    public Solution(int[] w) {
        weight = new int[w.length];
        weight[0] = w[0];
        for (int i = 1; i < w.length; i++) {
            weight[i] = weight[i - 1] + w[i];
        }
    }
    
    public int pickIndex() {
        // Draw uniformly from 1..total. Index i "owns" the w[i] integers in
        // the range (weight[i - 1], weight[i]], so it is hit with probability w[i] / total.
        int number = ThreadLocalRandom.current().nextInt(1, weight[weight.length - 1] + 1);
        // The first prefix sum >= number identifies the picked index.
        return firstGreaterOrEqual(weight, number);
    }
    
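    // Binary search for the smallest index whose value is >= target.
    // pickIndex guarantees target <= arr[arr.length - 1], so the result is in range.
    private int firstGreaterOrEqual(int[] arr, int target) {
        int left = 0;
        int right = arr.length - 1;
        while (left <= right) {
            int mid = left + (right - left) / 2; // avoids overflow of left + right
            if (arr[mid] < target) {
                left = mid + 1;
            } else if (arr[mid] > target) {
                right = mid - 1;
            } else {
                // Exact match; values are distinct, so this is the first one.
                return mid;
            }
        }
        // No exact match: left is the insertion point, i.e. the first index
        // whose value is greater than target.
        return left;
    }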
}

/**
 * Your Solution object will be instantiated and called as such:
 * Solution obj = new Solution(w);
 * int param_1 = obj.pickIndex();
 */
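A possible variation, sketched here as a hypothetical SolutionWithArraysBinarySearch class, replaces the hand-written search with java.util.Arrays.binarySearch. That method returns the key's index if it is present and otherwise -(insertion point) - 1, where the insertion point is the index of the first element greater than the key. Because the prefix sums are strictly increasing, decoding that return value yields the same index as firstGreaterOrEqual.

```java
import java.util.Arrays;
import java.util.concurrent.ThreadLocalRandom;

// Hypothetical variant of the Solution class above, using Arrays.binarySearch
// in place of the hand-written firstGreaterOrEqual helper.
class SolutionWithArraysBinarySearch {
    private final int[] prefix; // prefix[i] = w[0] + ... + w[i], strictly increasing

    SolutionWithArraysBinarySearch(int[] w) {
        prefix = new int[w.length];
        prefix[0] = w[0];
        for (int i = 1; i < w.length; i++) {
            prefix[i] = prefix[i - 1] + w[i];
        }
    }

    // Maps a draw in 1..total to the first index whose prefix sum is >= target.
    int indexFor(int target) {
        int idx = Arrays.binarySearch(prefix, target);
        // Found: idx is the exact position. Not found: idx == -(insertionPoint) - 1,
        // and the insertion point is the first index whose prefix sum exceeds target.
        return idx >= 0 ? idx : -idx - 1;
    }

    int pickIndex() {
        int total = prefix[prefix.length - 1];
        return indexFor(ThreadLocalRandom.current().nextInt(1, total + 1));
    }
}
```

For w = [2, 8, 5] the prefix sums are [2, 10, 15], so indexFor(2) is 0, indexFor(3) is 1, and indexFor(11) is 2. Splitting out indexFor also makes the deterministic part of the algorithm easy to test without involving randomness.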