Java中的随机加权选择

声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow 原文地址: http://stackoverflow.com/questions/6409652/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me): StackOverFlow

提示:将鼠标放在中文语句上可以显示对应的英文。显示中英文
时间:2020-08-16 05:42:41  来源:igfitidea点击:

Random weighted selection in Java

javarandomdouble

提问by yosi

I want to choose a random item from a set, but the chance of choosing any item should be proportional to the associated weight

我想从一个集合中随机选择一个项目,但选择任何项目的机会应该与相关的权重成正比

Example inputs:

示例输入:

item                weight
----                ------
sword of misery         10
shield of happy          5
potion of dying          6
triple-edged sword       1

So, if I have 4 possible items, the chance of getting any one item without weights would be 1 in 4.

所以,如果我有 4 种可能的物品,那么获得任何一件没有重量的物品的几率是四分之一。

In this case, a user should be 10 times more likely to get the sword of misery than the triple-edged sword.

在这种情况下,用户获得痛苦之剑的可能性应该是三刃剑的 10 倍。

How do I make a weighted random selection in Java?

如何在 Java 中进行加权随机选择?

回答by Arne Deutsch

You will not find a framework for this kind of problem, as the requested functionality is nothing more then a simple function. Do something like this:

您不会找到解决此类问题的框架,因为请求的功能只不过是一个简单的函数。做这样的事情:

interface Item {
    double getWeight();
}

class RandomItemChooser {
    public Item chooseOnWeight(List<Item> items) {
        double completeWeight = 0.0;
        for (Item item : items)
            completeWeight += item.getWeight();
        double r = Math.random() * completeWeight;
        double countWeight = 0.0;
        for (Item item : items) {
            countWeight += item.getWeight();
            if (countWeight >= r)
                return item;
        }
        throw new RuntimeException("Should never be shown.");
    }
}

回答by Peter Lawrey

I would use a NavigableMap

我会使用 NavigableMap

public class RandomCollection<E> {
    private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
    private final Random random;
    private double total = 0;

    public RandomCollection() {
        this(new Random());
    }

    public RandomCollection(Random random) {
        this.random = random;
    }

    public RandomCollection<E> add(double weight, E result) {
        if (weight <= 0) return this;
        total += weight;
        map.put(total, result);
        return this;
    }

    public E next() {
        double value = random.nextDouble() * total;
        return map.higherEntry(value).getValue();
    }
}

Say I have a list of animals dog, cat, horse with probabilities as 40%, 35%, 25% respectively

假设我有一个概率分别为 40%、35%、25% 的动物狗、猫、马的列表

RandomCollection<String> rc = new RandomCollection<>()
                              .add(40, "dog").add(35, "cat").add(25, "horse");

for (int i = 0; i < 10; i++) {
    System.out.println(rc.next());
} 

回答by kdkeck

There is now a class for this in Apache Commons: EnumeratedDistribution

现在在 Apache Commons 中有一个类:EnumeratedDistribution

Item selectedItem = new EnumeratedDistribution<>(itemWeights).sample();

where itemWeightsis a List<Pair<Item, Double>>, like (assuming Iteminterface in Arne's answer):

哪里itemWeights是 a List<Pair<Item, Double>>,比如(假设ItemArne 的回答中的接口):

final List<Pair<Item, Double>> itemWeights = Collections.newArrayList();
for (Item i: itemSet) {
    itemWeights.add(new Pair(i, i.getWeight()));
}

or in Java 8:

或在 Java 8 中:

itemSet.stream().map(i -> new Pair(i, i.getWeight())).collect(toList());

Note:Pairhere needs to be org.apache.commons.math3.util.Pair, not org.apache.commons.lang3.tuple.Pair.

注意:Pair这里需要是org.apache.commons.math3.util.Pair,不是org.apache.commons.lang3.tuple.Pair

回答by Olivier Grégtheitroade

Use an alias method

使用别名方法

If you're gonna roll a lot of times (as in a game), you should use an alias method.

如果您要多次滚动(如在游戏中),则应使用别名方法。

The code below is rather long implementation of such an alias method, indeed. But this is because of the initialization part. The retrieval of elements is very fast (see the nextand the applyAsIntmethods they don't loop).

下面的代码确实是这种别名方法的相当长的实现。但这是因为初始化部分。元素的检索非常快(请参阅它们不循环nextapplyAsInt方法和方法)。

Usage

用法

Set<Item> items = ... ;
ToDoubleFunction<Item> weighter = ... ;

Random random = new Random();

RandomSelector<T> selector = RandomSelector.weighted(items, weighter);
Item drop = selector.next(random);

Implementation

执行

This implementation:

这个实现:

  • uses Java 8;
  • is designed to be as fast as possible(well, at least, I tried to do so using micro-benchmarking);
  • is totally thread-safe(keep one Randomin each thread for maximum performance, use ThreadLocalRandom?);
  • fetches elements in O(1), unlike what you mostly find on the internet or on StackOverflow, where naive implementations run in O(n) or O(log(n));
  • keeps the items independant from their weight, so an item can be assigned various weights in different contexts.
  • 使用Java 8
  • 被设计得尽可能快(好吧,至少,我尝试使用微基准测试来做到这一点);
  • 是完全线程安全的Random在每个线程中保留一个以获得最大性能,使用ThreadLocalRandom?);
  • 在 O(1) 中获取元素,这与您通常在互联网或 StackOverflow 上找到的不同,在那里,幼稚的实现以 O(n) 或 O(log(n)) 运行;
  • 使项目与其重量无关,因此可以在不同的上下文中为项目分配各种权重。

Anyways, here's the code. (Note that I maintain an up to date version of this class.)

无论如何,这是代码。(请注意,我维护了此类的最新版本。)

import static java.util.Objects.requireNonNull;

import java.util.*;
import java.util.function.*;

public final class RandomSelector<T> {

  public static <T> RandomSelector<T> weighted(Set<T> elements, ToDoubleFunction<? super T> weighter)
      throws IllegalArgumentException {
    requireNonNull(elements, "elements must not be null");
    requireNonNull(weighter, "weighter must not be null");
    if (elements.isEmpty()) { throw new IllegalArgumentException("elements must not be empty"); }

    // Array is faster than anything. Use that.
    int size = elements.size();
    T[] elementArray = elements.toArray((T[]) new Object[size]);

    double totalWeight = 0d;
    double[] discreteProbabilities = new double[size];

    // Retrieve the probabilities
    for (int i = 0; i < size; i++) {
      double weight = weighter.applyAsDouble(elementArray[i]);
      if (weight < 0.0d) { throw new IllegalArgumentException("weighter may not return a negative number"); }
      discreteProbabilities[i] = weight;
      totalWeight += weight;
    }
    if (totalWeight == 0.0d) { throw new IllegalArgumentException("the total weight of elements must be greater than 0"); }

    // Normalize the probabilities
    for (int i = 0; i < size; i++) {
      discreteProbabilities[i] /= totalWeight;
    }
    return new RandomSelector<>(elementArray, new RandomWeightedSelection(discreteProbabilities));
  }

  private final T[] elements;
  private final ToIntFunction<Random> selection;

  private RandomSelector(T[] elements, ToIntFunction<Random> selection) {
    this.elements = elements;
    this.selection = selection;
  }

  public T next(Random random) {
    return elements[selection.applyAsInt(random)];
  }

  private static class RandomWeightedSelection implements ToIntFunction<Random> {
    // Alias method implementation O(1)
    // using Vose's algorithm to initialize O(n)

    private final double[] probabilities;
    private final int[] alias;

    RandomWeightedSelection(double[] probabilities) {
      int size = probabilities.length;

      double average = 1.0d / size;
      int[] small = new int[size];
      int smallSize = 0;
      int[] large = new int[size];
      int largeSize = 0;

      // Describe a column as either small (below average) or large (above average).
      for (int i = 0; i < size; i++) {
        if (probabilities[i] < average) {
          small[smallSize++] = i;
        } else {
          large[largeSize++] = i;
        }
      }

      // For each column, saturate a small probability to average with a large probability.
      while (largeSize != 0 && smallSize != 0) {
        int less = small[--smallSize];
        int more = large[--largeSize];
        probabilities[less] = probabilities[less] * size;
        alias[less] = more;
        probabilities[more] += probabilities[less] - average;
        if (probabilities[more] < average) {
          small[smallSize++] = more;
        } else {
          large[largeSize++] = more;
        }
      }

      // Flush unused columns.
      while (smallSize != 0) {
        probabilities[small[--smallSize]] = 1.0d;
      }
      while (largeSize != 0) {
        probabilities[large[--largeSize]] = 1.0d;
      }
    }

    @Override public int applyAsInt(Random random) {
      // Call random once to decide which column will be used.
      int column = random.nextInt(probabilities.length);

      // Call random a second time to decide which will be used: the column or the alias.
      if (random.nextDouble() < probabilities[column]) {
        return column;
      } else {
        return alias[column];
      }
    }
  }
}

回答by ronen

public class RandomCollection<E> {
  private final NavigableMap<Double, E> map = new TreeMap<Double, E>();
  private double total = 0;

  public void add(double weight, E result) {
    if (weight <= 0 || map.containsValue(result))
      return;
    total += weight;
    map.put(total, result);
  }

  public E next() {
    double value = ThreadLocalRandom.current().nextDouble() * total;
    return map.ceilingEntry(value).getValue();
  }
}

回答by Yuri Heiko

If you need to remove elements after choosing you can use another solution. Add all the elements into a 'LinkedList', each element must be added as many times as it weight is, then use Collections.shuffle()which, according to JavaDoc

如果您需要在选择后删除元素,您可以使用其他解决方案。Collections.shuffle()根据JavaDoc,将所有元素添加到“LinkedList”中,每个元素必须添加与其权重相同的次数,然后使用哪个

Randomly permutes the specified list using a default source of randomness. All permutations occur with approximately equal likelihood.

使用默认的随机源随机排列指定的列表。所有排列都以近似相等的可能性发生。

Finally, get and remove elements using pop()or removeFirst()

最后,使用pop()or获取和删除元素removeFirst()

Map<String, Integer> map = new HashMap<String, Integer>() {{
    put("Five", 5);
    put("Four", 4);
    put("Three", 3);
    put("Two", 2);
    put("One", 1);
}};

LinkedList<String> list = new LinkedList<>();

for (Map.Entry<String, Integer> entry : map.entrySet()) {
    for (int i = 0; i < entry.getValue(); i++) {
        list.add(entry.getKey());
    }
}

Collections.shuffle(list);

int size = list.size();
for (int i = 0; i < size; i++) {
    System.out.println(list.pop());
}

回答by Quinton Gordon

139

139

There is a straightforward algorithm for picking an item at random, where items have individual weights:

有一个简单的算法可以随机选择一个项目,其中项目具有单独的权重:

  1. calculate the sum of all the weights

  2. pick a random number that is 0 or greater and is less than the sum of the weights

  3. go through the items one at a time, subtracting their weight from your random number until you get the item where the random number is less than that item's weight

  1. 计算所有权重的总和

  2. 选择一个大于等于 0 且小于权重之和的随机数

  3. 一次检查一件物品,从你的随机数中减去它们的重量,直到你得到随机数小于该物品重量的物品