scala Spark - Random Number Generation

Disclaimer: this content is a translation of a popular StackOverflow question and is provided under the CC BY-SA 4.0 license. You are free to use and share it, but you must attribute it to the original authors (not me) and cite the original source: http://stackoverflow.com/questions/36455029/


Spark - Random Number Generation

Tags: scala, random, apache-spark, spark-dataframe

Asked by Brian

I have written a method that must consider a random number to simulate a Bernoulli distribution. I am using random.nextDouble to generate a number between 0 and 1, then making my decision based on that value given my probability parameter.

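A minimal sketch (my own illustration, not from the original post) of a single Bernoulli draw in plain Scala; the probability and seed values are placeholders:

import scala.util.Random

val p = 0.25                          // hypothetical success probability
val rng = new Random(91234L)          // seeded generator
val success = rng.nextDouble() <= p   // true with probability p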

My problem is that Spark is generating the same random numbers within each iteration of my for loop mapping function. I am using the DataFrame API. My code follows this format:


val myClass = new MyClass()
val M = 3
val myAppSeed = 91234
val rand = new scala.util.Random(myAppSeed)

for (m <- 1 to M) {
  val newDF = sqlContext.createDataFrame(myDF
    .map{row => RowFactory
      .create(row.getString(0),
        myClass.myMethod(row.getString(2), rand.nextDouble()))
    }, myDF.schema)
}

Here is the class:


class MyClass extends Serializable {
  val q = qProb

  def myMethod(s: String, rand: Double) = {
    if (rand <= q) // do something
    else // do something else
  }
}

I need a new random number every time myMethod is called. I also tried generating the number inside my method with java.util.Random (scala.util.Random v10 does not extend Serializable) like below, but I'm still getting the same numbers within each for loop.


val r = new java.util.Random(s.hashCode.toLong)
val rand = r.nextDouble()
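
A quick illustration (not from the original question) of why seeding from s.hashCode repeats: the same string always produces the same seed, so the first draw is identical on every call, which matches seeing the same value for a given row across loop iterations.

val s = "some-row-value"   // hypothetical row value
val d1 = new java.util.Random(s.hashCode.toLong).nextDouble()
val d2 = new java.util.Random(s.hashCode.toLong).nextDouble()
assert(d1 == d2)           // always true for the same s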

I've done some research, and it seems this has to do with Spark's deterministic nature.


Accepted answer by Pascal Soucy

The reason why the same sequence is repeated is that the random generator is created and initialized with a seed before the data is partitioned. Each partition then starts from the same random seed. Maybe not the most efficient way to do it, but the following should work:


val myClass = new MyClass()
val M = 3

for (m <- 1 to M) {
  val newDF = sqlContext.createDataFrame(myDF
    .map{ 
       val rand = scala.util.Random
       row => RowFactory
      .create(row.getString(0),
        myClass.myMethod(row.getString(2), rand.nextDouble()))
    }, myDF.schema)
}

Answered by David Griffin

Just use the SQL function rand:


import org.apache.spark.sql.functions._

//df: org.apache.spark.sql.DataFrame = [key: int]

df.select($"key", rand() as "rand").show
+---+-------------------+
|key|               rand|
+---+-------------------+
|  1| 0.8635073400704648|
|  2| 0.6870153659986652|
|  3|0.18998048357873532|
+---+-------------------+


df.select($"key", rand() as "rand").show
+---+------------------+
|key|              rand|
+---+------------------+
|  1|0.3422484248879837|
|  2|0.2301384925817671|
|  3|0.6959421970071372|
+---+------------------+
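
If reproducible values are needed, rand also takes an explicit seed (a small addition, not part of the original answer); the generated values then repeat across identical runs, given the same data partitioning:

import org.apache.spark.sql.functions.rand

df.select($"key", rand(42) as "rand").show   // 42 is an arbitrary seed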

Answered by leo9r

According to this post, the best solution is not to put the new scala.util.Random inside the map, nor completely outside (i.e. in the driver code), but in an intermediate mapPartitionsWithIndex:


import scala.util.Random
val myAppSeed = 91234
val newRDD = myRDD.mapPartitionsWithIndex { (indx, iter) =>
   val rand = new scala.util.Random(indx+myAppSeed)
   iter.map(x => (x, Array.fill(10)(rand.nextDouble)))
}
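
Applied to the original question, a sketch under the assumption that myRDD is an RDD[String] and myClass is the instance defined earlier; each partition gets its own seeded generator and every record gets a fresh draw:

import scala.util.Random

val myAppSeed = 91234
val withDraws = myRDD.mapPartitionsWithIndex { (indx, iter) =>
  val rand = new Random(indx + myAppSeed)                 // one generator per partition
  iter.map(s => myClass.myMethod(s, rand.nextDouble()))   // a new random number per record
}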

Answered by Joshua David Lickteig

Using the Spark Dataset API, perhaps for use in an accumulator:


df.withColumn("_n", substring(rand(),3,4).cast("bigint"))
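
Here substring(rand(), 3, 4) keeps the first four decimal digits of the random value and casts them to a bigint. A more direct variant (my own sketch, not part of the answer) scales the draw instead:

import org.apache.spark.sql.functions.{floor, rand}

// Sketch: a uniform integer in [0, 10000) for each row
df.withColumn("_n", floor(rand() * 10000).cast("bigint"))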