scala Spark 在 groupBy/aggregate 中合并/组合数组
声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow
原文地址: http://stackoverflow.com/questions/39496517/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me):
StackOverFlow
Spark merge/combine arrays in groupBy/aggregate
提问by clay
The following Spark code correctly demonstrates what I want to do and generates the correct output with a tiny demo data set.
以下 Spark 代码正确演示了我想要做的事情,并使用一个很小的演示数据集生成正确的输出。
When I run this same general type of code on a large volume of production data, I am having runtime problems. The Spark job runs on my cluster for ~12 hours and fails out.
当我在大量生产数据上运行相同类型的代码时,我遇到了运行时问题。Spark 作业在我的集群上运行了大约 12 个小时,但失败了。
Just glancing at the code below, it seems inefficient to explode every row, just to merge it back down. In the given test data set, the fourth row with three values in array_value_1 and three values in array_value_2, that will explode to 3*3 or nine exploded rows.
看看下面的代码,炸开每一行似乎效率很低,只是把它合并回来。在给定的测试数据集中,array_value_1 中的三个值和 array_value_2 中的三个值的第四行将分解为 3*3 或九个分解行。
So, in a larger data set, a row with five such array columns, and ten values in each column, would explode out to 10^5 exploded rows?
那么,在一个更大的数据集中,一行有五个这样的数组列,每列有十个值,会爆炸成 10^5 个爆炸行吗?
Looking at the provided Spark functions, there are no out of the box functions that would do what I want. I could supply a user-defined-function. Are there any speed drawbacks to that?
查看提供的 Spark 函数,没有现成的函数可以满足我的需求。我可以提供一个用户定义的函数。是否有任何速度缺点?
val sparkSession = SparkSession.builder.
master("local")
.appName("merge list test")
.getOrCreate()
val schema = StructType(
StructField("category", IntegerType) ::
StructField("array_value_1", ArrayType(StringType)) ::
StructField("array_value_2", ArrayType(StringType)) ::
Nil)
val rows = List(
Row(1, List("a", "b"), List("u", "v")),
Row(1, List("b", "c"), List("v", "w")),
Row(2, List("c", "d"), List("w")),
Row(2, List("c", "d", "e"), List("x", "y", "z"))
)
val df = sparkSession.createDataFrame(rows.asJava, schema)
val dfExploded = df.
withColumn("scalar_1", explode(col("array_value_1"))).
withColumn("scalar_2", explode(col("array_value_2")))
// This will output 19. 2*2 + 2*2 + 2*1 + 3*3 = 19
logger.info(s"dfExploded.count()=${dfExploded.count()}")
val dfOutput = dfExploded.groupBy("category").agg(
collect_set("scalar_1").alias("combined_values_2"),
collect_set("scalar_2").alias("combined_values_2"))
dfOutput.show()
回答by zero323
It could be inefficient to explodebut fundamentally the operation you try to implement is simply expensive. Effectively it is just another groupByKeyand there is not much you can do here to make it better. Since you use Spark > 2.0 you could collect_listdirectly and flatten:
这可能效率低下,explode但从根本上说,您尝试实施的操作非常昂贵。实际上,它只是另一个groupByKey,在这里您无法做太多事情来使它变得更好。由于您使用 Spark > 2.0,您可以collect_list直接和展平:
import org.apache.spark.sql.functions.{collect_list, udf}
val flatten_distinct = udf(
(xs: Seq[Seq[String]]) => xs.flatten.distinct)
df
.groupBy("category")
.agg(
flatten_distinct(collect_list("array_value_1")),
flatten_distinct(collect_list("array_value_2"))
)
In Spark >= 2.4 you can replace udf with composition of built-in functions:
在 Spark >= 2.4 中,您可以用内置函数的组合替换 udf:
import org.apache.spark.sql.functions.{array_distinct, flatten}
val flatten_distinct = (array_distinct _) compose (flatten _)
It is also possible to use custom Aggregatorbut I doubt any of these will make a huge difference.
也可以使用自定义,Aggregator但我怀疑其中任何一个都会产生巨大的差异。
If sets are relatively large and you expect significant number of duplicates you could try to use aggregateByKeywith mutable sets:
如果集合相对较大并且您希望有大量重复项,则可以尝试使用aggregateByKey可变集合:
import scala.collection.mutable.{Set => MSet}
val rdd = df
.select($"category", struct($"array_value_1", $"array_value_2"))
.as[(Int, (Seq[String], Seq[String]))]
.rdd
val agg = rdd
.aggregateByKey((MSet[String](), MSet[String]()))(
{case ((accX, accY), (xs, ys)) => (accX ++= xs, accY ++ ys)},
{case ((accX1, accY1), (accX2, accY2)) => (accX1 ++= accX2, accY1 ++ accY2)}
)
.mapValues { case (xs, ys) => (xs.toArray, ys.toArray) }
.toDF

