Solution: Random Pick with Weight
Let's solve the Random Pick with Weight problem using the Modified Binary Search pattern.
Statement
You’re given an array of positive integers, weights, where weights[i] is the weight of index i.
Write a function, Pick Index(), which performs weighted random selection to return an index from the weights array. The larger the value of weights[i], the higher the chances of index i being picked.
Suppose that the array consists of the weights . In this case, the probabilities of picking the indexes will be as follows:
- Index 0:
- Index 1:
- Index 2:
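To make the proportionality concrete, here is a minimal sketch that computes the expected pick probability of each index as weights[i] divided by the sum of all weights. The class name PickProbabilities, the helper probabilities(), and the sample weights {1, 2, 3} are assumptions chosen for illustration; they are not values from the problem statement.

```java
import java.util.Arrays;

class PickProbabilities {
    // Expected probability of picking index i: weights[i] / sum(weights)
    static double[] probabilities(int[] weights) {
        long total = 0;
        for (int w : weights) {
            total += w;
        }
        double[] result = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            result[i] = (double) weights[i] / total;
        }
        return result;
    }

    public static void main(String[] args) {
        // Hypothetical sample weights, used only for illustration
        int[] weights = {1, 2, 3};
        System.out.println(Arrays.toString(probabilities(weights)));
    }
}
```

With these sample weights, index 2 is expected to be picked half of the time because its weight is half of the total.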
Constraints:
- weights.length
- weights[i]
- Pick Index() will be called at most times.
Note: Since we’re randomly choosing from the options, there is no guarantee that in any specific run of the program, any of the elements will be selected with the exact expected frequency.
Solution
So far, you’ve probably brainstormed some approaches and have an idea of how to solve this problem. Let’s explore some of these approaches and figure out which one to follow based on considerations such as time complexity and any implementation constraints.
Naive approach
To correctly add bias to the randomized pick, we need to generate a probability line where the length of each segment is determined by the corresponding weight in the input array. Positions with greater weights have longer segments on the probability line, and positions with smaller weights have shorter segments. We can represent this probability line as a list of running sums of weights.
Let’s implement a linear search approach as our naive solution:
- Starting off with an example of an input list of weights containing , the list of the running sums of weights is .
The segment length can be used as the likelihood metric of each index. The likelihood metric (based on the running sums) for each weight is shown in the diagram to the right.
- Next, a random number between 1 and the total sum of weights (both inclusive) is generated to make a biased selection.
- We then locate the position of this number on the probability line, which corresponds to an index in the input array.
- By scanning the running sums list, we find the first sum that is greater than or equal to the random number and return the corresponding index.
Both the time and the space complexity of this naive approach are O(n), where n is the number of weights.
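The linear search described above might be sketched as follows. This is an illustration under stated assumptions rather than the lesson's reference code: the class name NaiveWeightedPicker and the helper scan() are hypothetical, and the random number is assumed to be drawn from 1 to the total weight inclusive, so the scan looks for the first running sum that is greater than or equal to it.

```java
import java.util.Random;

class NaiveWeightedPicker {
    private final int[] runningSums;
    private final Random random = new Random();

    NaiveWeightedPicker(int[] weights) {
        // Build the probability line as a list of running sums
        runningSums = new int[weights.length];
        int sum = 0;
        for (int i = 0; i < weights.length; i++) {
            sum += weights[i];
            runningSums[i] = sum;
        }
    }

    // Linear scan: return the first index whose running sum is >= target
    static int scan(int[] runningSums, int target) {
        for (int i = 0; i < runningSums.length; i++) {
            if (target <= runningSums[i]) {
                return i;
            }
        }
        throw new IllegalArgumentException("target exceeds the total weight");
    }

    int pickIndex() {
        int total = runningSums[runningSums.length - 1];
        // Assumption: target is drawn uniformly from [1, total]
        int target = random.nextInt(total) + 1;
        return scan(runningSums, target);
    }
}
```

Each call to pickIndex() may walk the whole list, which is what the optimized approach avoids.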
Optimized approach using modified binary search
The algorithm combines running sums with binary search to select an index efficiently based on its weight. In a list of running sums, each element is the sum of all elements of the original list up to and including the current position. First, we transform the input weights into a list of running sums. To select an index, we generate a random number from 1 to the highest running sum, and then use binary search to find the first running sum that is greater than or equal to this number. The position of that running sum is the chosen index. Because each index owns a range of values whose size equals its weight, the probability of selecting an index is directly proportional to its weight.
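As a small illustration of the running sums transformation, the sketch below builds the list for a hypothetical input. The class name RunningSums and the sample weights {1, 2, 3} are assumptions made for this example only.

```java
import java.util.Arrays;

class RunningSums {
    // Each entry is the sum of all weights up to and including that position
    static int[] runningSums(int[] weights) {
        int[] sums = new int[weights.length];
        int total = 0;
        for (int i = 0; i < weights.length; i++) {
            total += weights[i];
            sums[i] = total;
        }
        return sums;
    }

    public static void main(String[] args) {
        // Hypothetical sample weights, used only for illustration
        int[] weights = {1, 2, 3};
        System.out.println(Arrays.toString(runningSums(weights))); // prints [1, 3, 6]
    }
}
```

For this sample, a random number of 1 maps to index 0, the numbers 2 and 3 map to index 1, and the numbers 4 through 6 map to index 2, so each index owns as many values as its weight.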
Here’s how the algorithm proceeds:
- In Init(), we generate the list of running sums from the given list of weights so that we don’t have to compute it again every time we call Pick Index().
- The Pick Index() method returns an index at random, taking into account the weights provided:
  - Generate a random number, target, between 1 and the largest value in the list of running sums of weights (both inclusive).
  - Use binary search to find the index of the first running sum that is greater than or equal to target. Initialize the low index to 0 and the high index to the length of the list of running sums of weights. While the low index is less than the high index:
    - Calculate the mid index as low + (high - low) / 2, using integer division.
    - If the running sum at the mid index is less than target, update the low index to mid + 1.
    - Otherwise, update the high index to mid.
  - At the end of the binary search, the low pointer will point to the index of the first running sum that is greater than or equal to target. Return the index found as the chosen index.
Let’s look at the code for this solution below:
import java.util.*;

class RandomPickWithWeight {
    private final List<Integer> runningSums;
    private final int totalSum;
    // One generator is reused across calls instead of being created on every pick
    private final Random random = new Random();

    public RandomPickWithWeight(int[] weights) {
        // Build the list of running sums once, in the constructor
        runningSums = new ArrayList<>();
        int runningSum = 0;
        for (int w : weights) {
            runningSum += w;
            runningSums.add(runningSum);
        }
        totalSum = runningSum;
    }

    public int pickIndex() {
        // Random target in the range [1, totalSum]
        int target = random.nextInt(totalSum) + 1;
        int low = 0;
        int high = runningSums.size();
        // Binary search for the first running sum that is >= target
        while (low < high) {
            int mid = low + (high - low) / 2;
            if (target > runningSums.get(mid)) {
                low = mid + 1;
            } else {
                high = mid;
            }
        }
        return low;
    }

    // Driver code
    public static void main(String[] args) {
        int counter = 900;
        int[][] weights = {
            {1, 2, 3, 4, 5},
            {1, 12, 23, 34, 45, 56, 67, 78, 89, 90},
            {10, 20, 30, 40, 50},
            {1, 10, 23, 32, 41, 56, 62, 75, 87, 90},
            {12, 20, 35, 42, 55},
            {10, 10, 10, 10, 10},
            {10, 10, 20, 20, 20, 30},
            {1, 2, 3},
            {10, 20, 30, 40},
            {5, 10, 15, 20, 25, 30}
        };
        HashMap<Integer, Integer> map = new HashMap<>();
        for (int i = 0; i < weights.length; i++) {
            System.out.println((i + 1) + ".\tList of weights: " + Arrays.toString(weights[i])
                    + ", pickIndex() called " + counter + " times" + "\n");
            // Start every index with zero occurrences
            for (int l = 0; l < weights[i].length; l++) {
                map.put(l, 0);
            }
            RandomPickWithWeight sol = new RandomPickWithWeight(weights[i]);
            for (int j = 0; j < counter; j++) {
                int index = sol.pickIndex();
                map.put(index, map.get(index) + 1);
            }
            System.out.println(new String(new char[100]).replace('\0', '-'));
            System.out.println("\t" + String.format("%-10s%-5s%-10s%-5s%-15s%-5s%-20s%-5s%-15s",
                    "Indexes", "|", "Weights", "|", "Occurrences", "|",
                    "Actual Frequency", "|", "Expected Frequency"));
            System.out.println(new String(new char[100]).replace('\0', '-'));
            // Compare the observed frequency of each index with its expected frequency
            for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
                int key = entry.getKey();
                int value = entry.getValue();
                System.out.println("\t" + String.format("%-10s%-5s%-10s%-5s%-15s%-5s%-20s%-5s%-15s",
                        key, "|", weights[i][key], "|", value, "|",
                        String.format("%.2f", ((double) value / counter) * 100) + "%", "|",
                        String.format("%.2f", ((double) weights[i][key] / sum(weights[i])) * 100) + "%"));
            }
            map.clear();
            System.out.println(new String(new char[100]).replace('\0', '-'));
        }
    }

    private static int sum(int[] arr) {
        int total = 0;
        for (int num : arr) {
            total += num;
        }
        return total;
    }
}
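Because the driver above relies on random draws, its output differs from run to run. As a complementary, deterministic check, the sketch below runs the same style of binary search for every possible target from 1 to the total weight and counts how many targets land on each index. The class name ExhaustiveTargetCheck and both helper names are hypothetical and introduced only for this illustration.

```java
import java.util.Arrays;

class ExhaustiveTargetCheck {
    // Binary search for the first index whose running sum is >= target
    static int findIndex(int[] runningSums, int target) {
        int low = 0;
        int high = runningSums.length;
        while (low < high) {
            int mid = low + (high - low) / 2;
            if (target > runningSums[mid]) {
                low = mid + 1;
            } else {
                high = mid;
            }
        }
        return low;
    }

    // Count how many targets in [1, total] map to each index
    static int[] countTargetsPerIndex(int[] weights) {
        int[] sums = new int[weights.length];
        int total = 0;
        for (int i = 0; i < weights.length; i++) {
            total += weights[i];
            sums[i] = total;
        }
        int[] counts = new int[weights.length];
        for (int target = 1; target <= total; target++) {
            counts[findIndex(sums, target)]++;
        }
        return counts;
    }

    public static void main(String[] args) {
        int[] weights = {10, 20, 30, 40};
        // If the search is correct, the counts equal the weights themselves
        System.out.println(Arrays.equals(countTargetsPerIndex(weights), weights));
    }
}
```

If every target is equally likely, an index that receives exactly weights[i] of the possible targets is selected with probability proportional to its weight.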
Solution summary
To recap, the solution to this problem can be divided into the following three parts:
- Generate a list of running sums from the given list of weights.
- Generate a random number between 1 and the largest number in the running sums list.
- Use binary search to find the index of the first running sum that is greater than or equal to the random number.
Time complexity
Constructor: Since the list of running sums is calculated in the constructor, the time complexity for the constructor is O(n), where n is the size of the array of weights.
Pick Index(): Since we’re performing a binary search on a list of length n, the time complexity is O(log n).
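The same logarithmic lookup can also be expressed with the standard library's binary search instead of a hand-written loop. The sketch below is an alternative formulation, not the lesson's implementation; the class name LibrarySearchPicker and the helper findIndex() are hypothetical. It assumes the running sums are stored in an int array and are strictly increasing, which holds because all weights are positive.

```java
import java.util.Arrays;

class LibrarySearchPicker {
    // Returns the first index whose running sum is >= target
    static int findIndex(int[] runningSums, int target) {
        int pos = Arrays.binarySearch(runningSums, target);
        // Found: pos is the index whose running sum equals target.
        // Not found: pos == -(insertionPoint) - 1, where insertionPoint is
        // the index of the first running sum greater than target.
        return pos >= 0 ? pos : -pos - 1;
    }

    public static void main(String[] args) {
        // Hypothetical running sums for the sample weights {1, 2, 3}
        int[] runningSums = {1, 3, 6};
        for (int target = 1; target <= 6; target++) {
            System.out.println(target + " -> " + findIndex(runningSums, target));
        }
    }
}
```

Either form performs O(log n) comparisons per pick; the hand-written loop simply makes the boundary handling explicit.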
Space complexity
Constructor: The list of running sums takes O(n) space during its construction.
Pick Index(): This function takes O(1) space, since only a constant amount of extra space is used.