Random weighted selection in Java
Source: http://stackoverflow.com/questions/6409652/
Note: this content comes from Stack Overflow and is provided under the CC BY-SA 4.0 license. You are free to use and share it, but you must attribute it to the original authors on Stack Overflow.
Asked by yosi
I want to choose a random item from a set, but the chance of choosing any item should be proportional to its associated weight.
Example inputs:
item                 weight
----                 ------
sword of misery      10
shield of happy      5
potion of dying      6
triple-edged sword   1
So, if I have 4 possible items, the chance of getting any one item without weights would be 1 in 4.
In this case, a user should be 10 times more likely to get the sword of misery than the triple-edged sword.
How do I make a weighted random selection in Java?
Answered by Arne Deutsch
You will not find a framework for this kind of problem, as the requested functionality is nothing more than a simple function. Do something like this:
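import java.util.List;

interface Item {
    double getWeight();
}

class RandomItemChooser {
    public Item chooseOnWeight(List<Item> items) {
        // First pass: total weight of all items.
        double completeWeight = 0.0;
        for (Item item : items)
            completeWeight += item.getWeight();
        // Random point in [0, completeWeight).
        double r = Math.random() * completeWeight;
        // Second pass: return the first item whose cumulative weight reaches r.
        double countWeight = 0.0;
        for (Item item : items) {
            countWeight += item.getWeight();
            if (countWeight >= r)
                return item;
        }
        // Reached only for invalid input, e.g. an empty list.
        throw new RuntimeException("Should never be shown.");
    }
}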
Answered by Peter Lawrey
I would use a NavigableMap:
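import java.util.NavigableMap;
import java.util.Random;
import java.util.TreeMap;

public class RandomCollection<E> {
    // Each key is the cumulative weight up to and including its entry.
    private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
    private final Random random;
    private double total = 0;

    public RandomCollection() {
        this(new Random());
    }

    public RandomCollection(Random random) {
        this.random = random;
    }

    public RandomCollection<E> add(double weight, E result) {
        // Entries with a non-positive weight are ignored.
        if (weight <= 0) return this;
        total += weight;
        map.put(total, result);
        return this;
    }

    public E next() {
        // Pick a point in [0, total) and take the first key strictly above it.
        double value = random.nextDouble() * total;
        return map.higherEntry(value).getValue();
    }
}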
Say I have a list of animals (dog, cat, horse) with probabilities of 40%, 35%, and 25% respectively:
RandomCollection<String> rc = new RandomCollection<String>()
        .add(40, "dog").add(35, "cat").add(25, "horse");
for (int i = 0; i < 10; i++) {
    System.out.println(rc.next());
}
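To see why higherEntry yields the right proportions, here is a minimal sketch that builds the same cumulative keys by hand and looks up a few fixed values instead of random ones. The class name CumulativeLookupDemo and its helper methods are made up for this illustration; they are not part of the answer above.

```java
import java.util.NavigableMap;
import java.util.TreeMap;

// Hypothetical demo class: mirrors how the cumulative-weight map is filled and queried.
final class CumulativeLookupDemo {

    /** Builds the cumulative map for dog=40, cat=35, horse=25. */
    static NavigableMap<Double, String> build() {
        NavigableMap<Double, String> map = new TreeMap<>();
        double total = 0;
        total += 40; map.put(total, "dog");    // key 40.0 covers values in [0, 40)
        total += 35; map.put(total, "cat");    // key 75.0 covers values in [40, 75)
        total += 25; map.put(total, "horse");  // key 100.0 covers values in [75, 100)
        return map;
    }

    /** Returns the entry whose range contains value; value must be in [0, 100). */
    static String lookup(NavigableMap<Double, String> map, double value) {
        // higherEntry returns the entry with the smallest key strictly greater than value.
        return map.higherEntry(value).getValue();
    }

    public static void main(String[] args) {
        NavigableMap<Double, String> map = build();
        for (double v : new double[] {0.0, 39.9, 40.0, 74.9, 75.0, 99.9}) {
            System.out.println(v + " -> " + lookup(map, v));
        }
    }
}
```

Each entry owns a range whose width equals its weight, so a uniformly random value in [0, total) lands in a range with probability weight / total. The lookup itself is O(log n) because of the TreeMap.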
Answered by kdkeck
There is now a class for this in Apache Commons Math: EnumeratedDistribution
Item selectedItem = new EnumeratedDistribution<>(itemWeights).sample();
where itemWeights is a List<Pair<Item, Double>>, like this (assuming the Item interface from Arne's answer):
final List<Pair<Item, Double>> itemWeights = new ArrayList<>();
for (Item i : itemSet) {
    itemWeights.add(new Pair<>(i, i.getWeight()));
}
or in Java 8:
itemSet.stream().map(i -> new Pair<>(i, i.getWeight())).collect(toList());
Note: Pair here needs to be org.apache.commons.math3.util.Pair, not org.apache.commons.lang3.tuple.Pair.
Answered by Olivier Grégoire
Use an alias method
If you're going to roll many times (as in a game), you should use an alias method.
The code below is a rather long implementation of such an alias method, but that is because of the initialization part. The retrieval of elements is very fast: see the next and applyAsInt methods, which don't loop.
Usage
Set<Item> items = ... ;
ToDoubleFunction<Item> weighter = ... ;
Random random = new Random();
RandomSelector<Item> selector = RandomSelector.weighted(items, weighter);
Item drop = selector.next(random);
Implementation
This implementation:
- uses Java 8;
- is designed to be as fast as possible (well, at least, I tried to do so using micro-benchmarking);
- is totally thread-safe (keep one Random in each thread for maximum performance, or use ThreadLocalRandom?);
- fetches elements in O(1), unlike what you mostly find on the internet or on StackOverflow, where naive implementations run in O(n) or O(log(n));
- keeps the items independent from their weight, so an item can be assigned various weights in different contexts.
Anyways, here's the code. (Note that I maintain an up to date version of this class.)
import static java.util.Objects.requireNonNull;

import java.util.*;
import java.util.function.*;

public final class RandomSelector<T> {

    public static <T> RandomSelector<T> weighted(Set<T> elements, ToDoubleFunction<? super T> weighter)
            throws IllegalArgumentException {
        requireNonNull(elements, "elements must not be null");
        requireNonNull(weighter, "weighter must not be null");
        if (elements.isEmpty()) { throw new IllegalArgumentException("elements must not be empty"); }

        // Array is faster than anything. Use that.
        int size = elements.size();
        @SuppressWarnings("unchecked")
        T[] elementArray = elements.toArray((T[]) new Object[size]);

        double totalWeight = 0d;
        double[] discreteProbabilities = new double[size];

        // Retrieve the probabilities
        for (int i = 0; i < size; i++) {
            double weight = weighter.applyAsDouble(elementArray[i]);
            if (weight < 0.0d) { throw new IllegalArgumentException("weighter may not return a negative number"); }
            discreteProbabilities[i] = weight;
            totalWeight += weight;
        }
        if (totalWeight == 0.0d) { throw new IllegalArgumentException("the total weight of elements must be greater than 0"); }

        // Normalize the probabilities
        for (int i = 0; i < size; i++) {
            discreteProbabilities[i] /= totalWeight;
        }
        return new RandomSelector<>(elementArray, new RandomWeightedSelection(discreteProbabilities));
    }

    private final T[] elements;
    private final ToIntFunction<Random> selection;

    private RandomSelector(T[] elements, ToIntFunction<Random> selection) {
        this.elements = elements;
        this.selection = selection;
    }

    public T next(Random random) {
        return elements[selection.applyAsInt(random)];
    }

    private static class RandomWeightedSelection implements ToIntFunction<Random> {
        // Alias method implementation O(1)
        // using Vose's algorithm to initialize O(n)
        private final double[] probabilities;
        private final int[] alias;

        RandomWeightedSelection(double[] probabilities) {
            int size = probabilities.length;
            // Both final fields must be assigned; the table is built in place in the given array.
            this.probabilities = probabilities;
            this.alias = new int[size];

            double average = 1.0d / size;
            int[] small = new int[size];
            int smallSize = 0;
            int[] large = new int[size];
            int largeSize = 0;

            // Describe a column as either small (below average) or large (above average).
            for (int i = 0; i < size; i++) {
                if (probabilities[i] < average) {
                    small[smallSize++] = i;
                } else {
                    large[largeSize++] = i;
                }
            }

            // For each column, saturate a small probability to average with a large probability.
            while (largeSize != 0 && smallSize != 0) {
                int less = small[--smallSize];
                int more = large[--largeSize];
                // Keep the unscaled value: the large column must be reduced in the same units as average.
                double lessProbability = probabilities[less];
                probabilities[less] = lessProbability * size;
                alias[less] = more;
                probabilities[more] += lessProbability - average;
                if (probabilities[more] < average) {
                    small[smallSize++] = more;
                } else {
                    large[largeSize++] = more;
                }
            }

            // Flush unused columns.
            while (smallSize != 0) {
                probabilities[small[--smallSize]] = 1.0d;
            }
            while (largeSize != 0) {
                probabilities[large[--largeSize]] = 1.0d;
            }
        }

        @Override public int applyAsInt(Random random) {
            // Call random once to decide which column will be used.
            int column = random.nextInt(probabilities.length);

            // Call random a second time to decide which will be used: the column or the alias.
            if (random.nextDouble() < probabilities[column]) {
                return column;
            } else {
                return alias[column];
            }
        }
    }
}
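Answered by ronen
import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.concurrent.ThreadLocalRandom;

public class RandomCollection<E> {
    // Each key is the cumulative weight up to and including its entry.
    private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
    private double total = 0;

    public void add(double weight, E result) {
        // Ignore non-positive weights and values that were already added.
        if (weight <= 0 || map.containsValue(result))
            return;
        total += weight;
        map.put(total, result);
    }

    public E next() {
        double value = ThreadLocalRandom.current().nextDouble() * total;
        return map.ceilingEntry(value).getValue();
    }
}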
Answered by Yuri Heiko
If you need to remove elements after choosing them, you can use another solution. Add all the elements to a LinkedList, adding each element as many times as its weight, then use Collections.shuffle() which, according to the JavaDoc:
Randomly permutes the specified list using a default source of randomness. All permutations occur with approximately equal likelihood.
Finally, get and remove elements using pop() or removeFirst().
// Assumes the usual java.util imports (Map, HashMap, LinkedList, Collections).
Map<String, Integer> map = new HashMap<String, Integer>() {{
    put("Five", 5);
    put("Four", 4);
    put("Three", 3);
    put("Two", 2);
    put("One", 1);
}};
// Add each key as many times as its weight.
LinkedList<String> list = new LinkedList<>();
for (Map.Entry<String, Integer> entry : map.entrySet()) {
    for (int i = 0; i < entry.getValue(); i++) {
        list.add(entry.getKey());
    }
}
Collections.shuffle(list);
// Draw and remove every element in shuffled order.
int size = list.size();
for (int i = 0; i < size; i++) {
    System.out.println(list.pop());
}
Answered by Quinton Gordon
There is a straightforward algorithm for picking an item at random, where items have individual weights:
1. Calculate the sum of all the weights.
2. Pick a random number that is 0 or greater and is less than the sum of the weights.
3. Go through the items one at a time, subtracting each item's weight from your random number, until you reach the item where the random number is less than that item's weight. That item is the one selected.
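The steps above could be sketched in Java roughly as follows. This is only an illustration under some assumptions: the class name WeightedPick, the pick method, and the choice of a Map from item to weight are all made up here, and the fallback at the end is an added guard against floating-point rounding rather than part of the algorithm as described.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Random;

// Hypothetical helper illustrating the sum / pick / subtract algorithm.
final class WeightedPick {

    /** Picks a key with probability proportional to its (non-negative) weight. */
    static <T> T pick(Map<T, Double> weights, Random random) {
        // Step 1: sum all the weights.
        double total = 0.0;
        for (double w : weights.values()) {
            if (w < 0.0) throw new IllegalArgumentException("weights must not be negative");
            total += w;
        }
        if (!(total > 0.0)) throw new IllegalArgumentException("total weight must be positive");

        // Step 2: random number in [0, total).
        double r = random.nextDouble() * total;

        // Step 3: subtract weights until r is smaller than the current item's weight.
        T lastPositive = null;
        for (Map.Entry<T, Double> entry : weights.entrySet()) {
            double w = entry.getValue();
            if (r < w) return entry.getKey();
            r -= w;
            if (w > 0.0) lastPositive = entry.getKey();
        }
        // Only reachable through floating-point rounding; fall back to the last weighted item.
        return lastPositive;
    }

    public static void main(String[] args) {
        // Items and weights taken from the question.
        Map<String, Double> loot = new LinkedHashMap<>();
        loot.put("sword of misery", 10.0);
        loot.put("shield of happy", 5.0);
        loot.put("potion of dying", 6.0);
        loot.put("triple-edged sword", 1.0);

        Random random = new Random();
        for (int i = 0; i < 5; i++) {
            System.out.println(pick(loot, random));
        }
    }
}
```

Each pick is O(n) in the number of items, which is usually fine for small loot tables. An item with weight 0 can never be returned, because r < 0 is never true.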