SQL: How to select the first row of each group?
Disclaimer: this page is a translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. You are free to use/share it, but you must do so under the same license, link the original address, and attribute it to the original authors (not me): StackOverflow
Original address: http://stackoverflow.com/questions/33878370/
How to select the first row of each group?
Asked by Rami
I have a DataFrame generated as follows:
df.groupBy($"Hour", $"Category")
.agg(sum($"value") as "TotalValue")
.sort($"Hour".asc, $"TotalValue".desc))
The results look like:
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
| 0| cat26| 30.9|
| 0| cat13| 22.1|
| 0| cat95| 19.6|
| 0| cat105| 1.3|
| 1| cat67| 28.5|
| 1| cat4| 26.8|
| 1| cat13| 12.6|
| 1| cat23| 5.3|
| 2| cat56| 39.6|
| 2| cat40| 29.7|
| 2| cat187| 27.9|
| 2| cat68| 9.8|
| 3| cat8| 35.6|
| ...| ....| ....|
+----+--------+----------+
As you can see, the DataFrame is ordered by Hour in ascending order, then by TotalValue in descending order.
I would like to select the top row of each group, i.e.
- from the group of Hour==0 select (0,cat26,30.9)
- from the group of Hour==1 select (1,cat67,28.5)
- from the group of Hour==2 select (2,cat56,39.6)
- and so on
So the desired output would be:
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
| 0| cat26| 30.9|
| 1| cat67| 28.5|
| 2| cat56| 39.6|
| 3| cat8| 35.6|
| ...| ...| ...|
+----+--------+----------+
It might be handy to be able to select the top N rows of each group as well.
Any help is highly appreciated.
Answered by zero323
Window functions:
Something like this should do the trick:
import org.apache.spark.sql.functions.{row_number, max, broadcast, first, struct}
import org.apache.spark.sql.expressions.Window
val df = sc.parallelize(Seq(
(0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
(1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
(2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
(3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")
val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)
val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
This method will be inefficient in case of significant data skew.
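The question also asks for the top N rows per group; as a small, hedged extension of the window snippet above (n = 2 is just an illustrative value), keeping the row number and filtering on it should work:
// Top N per group: filter on the row number instead of keeping only rn == 1.
val n = 2
val dfTopN = df.withColumn("rn", row_number.over(w)).where($"rn" <= n).drop("rn")
dfTopN.show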
Plain SQL aggregation followed by join:
Alternatively, you can join with an aggregated data frame:
val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))
val dfTopByJoin = df.join(broadcast(dfMax),
($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
.drop("max_hour")
.drop("max_value")
dfTopByJoin.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
It will keep duplicate values (if there is more than one category per hour with the same total value). You can remove these as follows:
dfTopByJoin
  .groupBy($"hour")
  .agg(
    first("category").alias("category"),
    first("TotalValue").alias("TotalValue"))
Using ordering over structs:
A neat, although not very well tested, trick which doesn't require joins or window functions (it works because max over a struct compares the struct's fields from left to right, which is why TotalValue comes first):
val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
.groupBy($"hour")
.agg(max("vs").alias("vs"))
.select($"Hour", $"vs.Category", $"vs.TotalValue")
dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
With the Dataset API (Spark 1.6+, 2.0+):
Spark 1.6:
case class Record(Hour: Integer, Category: String, TotalValue: Double)
df.as[Record]
.groupBy($"hour")
.reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
.show
// +---+--------------+
// | _1| _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+
Spark 2.0 or later:
df.as[Record]
.groupByKey(_.Hour)
.reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
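reduceGroups returns a Dataset of (key, value) pairs; to recover a plain Dataset[Record], one can map away the key (a small sketch extending the snippet above, assuming spark.implicits._ is in scope):
// Keep only the winning Record per key; the grouping key column is dropped.
df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .map(_._2)
  .show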
The last two methods can leverage map-side combine and don't require a full shuffle, so most of the time they should exhibit better performance than window functions and joins. These can also be used with Structured Streaming in complete output mode.
Don't use:
df.orderBy(...).groupBy(...).agg(first(...), ...)
It may seem to work (especially in local mode) but it is unreliable (see SPARK-16207, credits to Tzach Zohar for linking the relevant JIRA issue, and SPARK-30335).
The same note applies to
df.orderBy(...).dropDuplicates(...)
which internally uses an equivalent execution plan.
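One hedged way to inspect that claim yourself is to print the physical plans of both patterns and compare them; the exact output differs across Spark versions:
import org.apache.spark.sql.functions.first

// Both queries sort first and then aggregate/deduplicate; the sort order is not guaranteed to survive.
df.orderBy($"TotalValue".desc).groupBy($"Hour").agg(first($"Category")).explain()
df.orderBy($"TotalValue".desc).dropDuplicates("Hour").explain()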
Answered by Antonín Hoskovec
For Spark 2.0.2 with grouping by multiple columns:
import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window
val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)
val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
Answered by Ramesh Maharjan
This is exactly the same as zero323's answer, but written as SQL queries.
Assuming that the dataframe is created and registered as
df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0 |cat26 |30.9 |
//|0 |cat13 |22.1 |
//|0 |cat95 |19.6 |
//|0 |cat105 |1.3 |
//|1 |cat67 |28.5 |
//|1 |cat4 |26.8 |
//|1 |cat13 |12.6 |
//|1 |cat23 |5.3 |
//|2 |cat56 |39.6 |
//|2 |cat40 |29.7 |
//|2 |cat187 |27.9 |
//|2 |cat68 |9.8 |
//|3 |cat8 |35.6 |
//+----+--------+----------+
Window function:
sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1 |cat67 |28.5 |
//|3 |cat8 |35.6 |
//|2 |cat56 |39.6 |
//|0 |cat26 |30.9 |
//+----+--------+----------+
Plain SQL aggregation followed by join:
sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
"(select Hour, Category, TotalValue from table tmp1 " +
"join " +
"(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
"on " +
"tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
"group by tmp3.Hour")
.show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1 |cat67 |28.5 |
//|3 |cat8 |35.6 |
//|2 |cat56 |39.6 |
//|0 |cat26 |30.9 |
//+----+--------+----------+
Using ordering over structs:
sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1 |cat67 |28.5 |
//|3 |cat8 |35.6 |
//|2 |cat56 |39.6 |
//|0 |cat26 |30.9 |
//+----+--------+----------+
The Datasets way and the don't-dos are the same as in the original answer.
Answered by Rubber Duck
The pattern is: group by keys => do something to each group, e.g. reduce => return to a dataframe.
I thought the Dataframe abstraction was a bit cumbersome in this case, so I used RDD functionality:
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

// reduceFunction is a user-supplied (Row, Row) => Row that picks the row to keep for each group
val rdd: RDD[Row] = originalDf
  .rdd
  .groupBy(row => row.getAs[String]("grouping_row"))
  .map(iterableTuple => {
    iterableTuple._2.reduce(reduceFunction)
  })

val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)
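For the Hour/Category/TotalValue data from the question, a concrete grouping key and reduceFunction (my own illustrative assumption, not part of the original answer; df is the example frame built in the accepted answer) might look like this:
// Keep the row with the largest TotalValue within each Hour group.
val reduceFunction = (a: Row, b: Row) =>
  if (a.getAs[Double]("TotalValue") >= b.getAs[Double]("TotalValue")) a else b

val topPerHour = sqlContext.createDataFrame(
  df.rdd
    .groupBy(row => row.getAs[Int]("Hour"))
    .map { case (_, rows) => rows.reduce(reduceFunction) },
  df.schema)

topPerHour.show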
Answered by elghoto
The solution below does only one groupBy and extracts the rows of your dataframe that contain the maxValue in one shot. No need for further joins or windows.
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.DataFrame

// df is the dataframe with Day, Category, TotalValue
implicit val dfEnc = RowEncoder(df.schema)

val res: DataFrame = df
  .groupByKey((r) => r.getInt(0))
  .mapGroups[Row]((day: Int, rows: Iterator[Row]) => rows.maxBy((r) => r.getDouble(2)))
Answered by randal25
A nice way of doing this with the DataFrame API is using argmax logic, like so:
import org.apache.spark.sql.functions.{max, struct}

val df = Seq(
(0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
(1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
(2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
(3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")
df.groupBy($"Hour")
.agg(max(struct($"TotalValue", $"Category")).as("argmax"))
.select($"Hour", $"argmax.*").show
+----+----------+--------+
|Hour|TotalValue|Category|
+----+----------+--------+
| 1| 28.5| cat67|
| 3| 35.6| cat8|
| 2| 39.6| cat56|
| 0| 30.9| cat26|
+----+----------+--------+
Answered by Shubham Agrawal
Here you can do it like this:
import org.apache.spark.sql.functions.first

val data = df.groupBy("Hour")
  .agg(first("Hour").as("_1"), first("Category").as("Category"), first("TotalValue").as("TotalValue"))
  .drop("Hour")

data.withColumnRenamed("_1", "Hour").show
Answered by Vasile Surdu
We can use the rank() window function (where you would choose rank = 1); rank just adds a number for every row of a group (in this case it would be the hour).
Here's an example (from https://github.com/jaceklaskowski/mastering-apache-spark-book/blob/master/spark-sql-functions.adoc#rank):
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.rank

val dataset = spark.range(9).withColumn("bucket", 'id % 3)
val byBucket = Window.partitionBy('bucket).orderBy('id)

dataset.withColumn("rank", rank over byBucket).show
+---+------+----+
| id|bucket|rank|
+---+------+----+
| 0| 0| 1|
| 3| 0| 2|
| 6| 0| 3|
| 1| 1| 1|
| 4| 1| 2|
| 7| 1| 3|
| 2| 2| 1|
| 5| 2| 2|
| 8| 2| 3|
+---+------+----+
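Applied to the Hour/Category/TotalValue frame from the question (assuming df is that frame, and reusing the imports from the snippet above), the same idea would be a sketch like:
val byHour = Window.partitionBy($"Hour").orderBy($"TotalValue".desc)

// Note: rank keeps ties, so an hour can return more than one row when TotalValue ties.
df.withColumn("rank", rank().over(byHour))
  .where($"rank" === 1)
  .drop("rank")
  .show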