Python: Retrieve top n in each group of a DataFrame in pyspark

Disclaimer: this page is a translation of a popular StackOverflow question and its answers, provided under the CC BY-SA 4.0 license. If you use or share it, you must do so under the same license and attribute it to the original authors (not me). Original question: http://stackoverflow.com/questions/38397796/


Retrieve top n in each group of a DataFrame in pyspark

python apache-spark dataframe pyspark apache-spark-sql

Asked by KAs

There's a DataFrame in pyspark with data as below:


user_id object_id score
user_1  object_1  3
user_1  object_1  1
user_1  object_2  2
user_2  object_1  5
user_2  object_2  2
user_2  object_2  6

What I expect is to return 2 records in each group with the same user_id, and those records need to have the highest scores. Consequently, the result should look like the following:

user_id object_id score
user_1  object_1  3
user_1  object_2  2
user_2  object_2  6
user_2  object_1  5

I'm really new to pyspark, could anyone give me a code snippet or portal to the related documentation of this problem? Great thanks!


Answered by mtoto

I believe you need to use window functions to attain the rank of each row based on user_id and score, and subsequently filter your results to only keep the first two values.

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

df.select('*', rank().over(window).alias('rank')) \
  .filter(col('rank') <= 2) \
  .show()
#+-------+---------+-----+----+
#|user_id|object_id|score|rank|
#+-------+---------+-----+----+
#| user_1| object_1|    3|   1|
#| user_1| object_2|    2|   2|
#| user_2| object_2|    6|   1|
#| user_2| object_1|    5|   2|
#+-------+---------+-----+----+

In general, the official programming guide is a good place to start learning Spark.

Data


rdd = sc.parallelize([("user_1",  "object_1",  3), 
                      ("user_1",  "object_2",  2), 
                      ("user_2",  "object_1",  5), 
                      ("user_2",  "object_2",  2), 
                      ("user_2",  "object_2",  6)])
df = sqlContext.createDataFrame(rdd, ["user_id", "object_id", "score"])
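The snippet above uses the older sc/sqlContext entry points. As a minimal sketch, assuming Spark 2.x or later (the builder call and app name below are illustrative; in the pyspark shell a spark session already exists), the same data and query can be written as:

from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

# Illustrative entry point; reuse the existing `spark` session if you have one
spark = SparkSession.builder.appName("top-n-per-group").getOrCreate()

# Same sample data as in the question
df = spark.createDataFrame(
    [("user_1", "object_1", 3), ("user_1", "object_1", 1), ("user_1", "object_2", 2),
     ("user_2", "object_1", 5), ("user_2", "object_2", 2), ("user_2", "object_2", 6)],
    ["user_id", "object_id", "score"])

# Rank rows within each user_id by descending score and keep the top 2
window = Window.partitionBy("user_id").orderBy(col("score").desc())
df.select("*", rank().over(window).alias("rank")).filter(col("rank") <= 2).show()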

Answered by Martin Tapp

Top-n selection is more accurate if you use row_number instead of rank when ties ("rank equality") occur:

from pyspark.sql.functions import col, row_number

# `window` is the Window spec defined in the previous answer
n = 5
df.select(col('*'), row_number().over(window).alias('row_number')) \
  .where(col('row_number') <= n) \
  .limit(20) \
  .toPandas()

Note the limit(20).toPandas() trick, instead of show(), for nicer formatting in Jupyter notebooks.
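To see the difference, here is a small hedged sketch (the tied DataFrame is made up for illustration, and it assumes the spark session from the earlier snippet): with tied scores, rank() gives both rows rank 1, so a rank-based filter can return more than n rows per group, while row_number() breaks the tie and returns exactly n.

from pyspark.sql.window import Window
from pyspark.sql.functions import rank, row_number, col

# Made-up data with a tie inside user_1
tied = spark.createDataFrame(
    [("user_1", "object_1", 3), ("user_1", "object_2", 3), ("user_1", "object_3", 1)],
    ["user_id", "object_id", "score"])

w = Window.partitionBy("user_id").orderBy(col("score").desc())
tied.select("*",
            rank().over(w).alias("rank"),
            row_number().over(w).alias("row_number")).show()
# object_1 and object_2 both get rank 1, so `rank <= 1` keeps two rows,
# while `row_number <= 1` keeps exactly one of them.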

Answered by Abu Shoeb

I know the question was asked for pyspark, and I was looking for a similar answer in Scala, i.e.

Retrieve top n values in each group of a DataFrame in Scala


Here is the Scala version of @mtoto's answer.

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, rank}

val window = Window.partitionBy("user_id").orderBy(col("score").desc)
val rankByScore = rank().over(window)
df1.select(col("*"), rankByScore.alias("rank")).filter(col("rank") <= 2).show()
// you can change the value 2 to any number you want; here 2 keeps the top 2 rows per group

More examples can be found here.


Answered by Dean

To find the Nth highest value in a PySpark SQL query, using the ROW_NUMBER() function:

SELECT * FROM (
    SELECT e.*, 
    ROW_NUMBER() OVER (ORDER BY col_name DESC) rn 
    FROM Employee e
)
WHERE rn = N

N is the nth highest value required from the column.

Output:


+-----------+
|col_name   |
+-----------+
|1183395    |
+-----------+

The query will return the Nth highest value.
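The query above ranks over the whole table. To answer the original per-group question in SQL, a PARTITION BY clause is needed. A hedged sketch, assuming Spark 2.x and that the df from the first answer is registered as a temporary view (the view name scores is made up):

# Register the DataFrame so it can be queried with Spark SQL
df.createOrReplaceTempView("scores")

spark.sql("""
    SELECT user_id, object_id, score
    FROM (
        SELECT s.*,
               ROW_NUMBER() OVER (PARTITION BY user_id ORDER BY score DESC) AS rn
        FROM scores s
    ) ranked
    WHERE rn <= 2
""").show()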

Answered by Vivek

With Python 3 and Spark 2.4:

from pyspark.sql import Window
import pyspark.sql.functions as f

def get_topN(df, group_by_columns, order_by_column, n=1):
    # Partition by the grouping columns; order_by_column must be a Column expression
    window_group_by_columns = Window.partitionBy(group_by_columns)
    ordered_df = df.select(df.columns + [
        f.row_number().over(window_group_by_columns.orderBy(order_by_column.desc())).alias('row_rank')])
    # Keep the first n rows of each group and drop the helper column
    topN_df = ordered_df.filter(f"row_rank <= {n}").drop("row_rank")
    return topN_df

top_n_df = get_topN(your_dataframe, ["group_by_column"], f.col("order_by_column"), 1)
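For the question's data, a possible call (reusing the df built earlier in this thread; the result variable name is arbitrary) might look like:

top2_per_user = get_topN(df, ["user_id"], f.col("score"), n=2)
top2_per_user.show()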