Python: collect_list preserving order based on another variable

Note: this content is taken from StackOverflow and is provided under the CC BY-SA 4.0 license. You are free to use and share it, but you must do so under the same license and attribute it to the original authors (not me). Original question: http://stackoverflow.com/questions/46580253/


collect_list by preserving order based on another variable

python apache-spark pyspark

Asked by Ravi

I am trying to create a new column of lists in Pyspark using a groupby aggregation on an existing set of columns. An example input data frame is provided below:

------------------------
id | date        | value
------------------------
1  |2014-01-03   | 10 
1  |2014-01-04   | 5
1  |2014-01-05   | 15
1  |2014-01-06   | 20
2  |2014-02-10   | 100   
2  |2014-03-11   | 500
2  |2014-04-15   | 1500

The expected output is:

id | value_list
------------------------
1  | [10, 5, 15, 20]
2  | [100, 500, 1500]

The values within a list are sorted by the date.

I tried using collect_list as follows:

from pyspark.sql import functions as F
ordered_df = input_df.orderBy(['id','date'],ascending = True)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))

But collect_list doesn't guarantee order even if I sort the input data frame by date before aggregation.

Could someone help with how to do the aggregation while preserving the order based on a second (date) variable?

Answered by TMichel

from pyspark.sql import functions as F
from pyspark.sql import Window

w = Window.partitionBy('id').orderBy('date')

sorted_list_df = input_df.withColumn(
            'sorted_list', F.collect_list('value').over(w)
        )\
        .groupBy('id')\
        .agg(F.max('sorted_list').alias('sorted_list'))

Window examples provided by users often don't really explain what is going on, so let me dissect it for you.

As you know, using collect_list together with groupBy will result in an unordered list of values. This is because, depending on how your data is partitioned, Spark will append values to your list as soon as it finds a row in the group. The order then depends on how Spark plans your aggregation over the executors.

A Window function allows you to control that situation, grouping rows by a certain value so you can perform an operation over each of the resultant groups:

w = Window.partitionBy('id').orderBy('date')
  • partitionBy - you want groups/partitions of rows with the same id
  • orderBy - you want each row in the group to be sorted by date

Once you have defined the scope of your Window - "rows with the same id, sorted by date" - you can use it to perform an operation over it, in this case a collect_list:

F.collect_list('value').over(w)

At this point you have created a new column sorted_list with an ordered list of values, sorted by date, but you still have duplicated rows per id. To trim out the duplicated rows, you want to groupBy id and keep the max value for each group:

.groupBy('id')\
.agg(F.max('sorted_list').alias('sorted_list'))
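
For reference, here is the same pipeline run end to end against the question's input_df; the expected result in the comments is taken from the question (a sketch, assuming the same input data):

from pyspark.sql import functions as F
from pyspark.sql import Window

w = Window.partitionBy('id').orderBy('date')

(input_df
    .withColumn('sorted_list', F.collect_list('value').over(w))
    .groupBy('id')
    .agg(F.max('sorted_list').alias('sorted_list'))
    .show(truncate=False))
# Expected output (from the question):
# +---+----------------+
# |id |sorted_list     |
# +---+----------------+
# |1  |[10, 5, 15, 20] |
# |2  |[100, 500, 1500]|
# +---+----------------+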

Answered by mtoto

If you collect both dates and values as a list, you can sort the resulting column according to date using a udf, and then keep only the values in the result.

import operator
import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, IntegerType

# create list column
grouped_df = input_df.groupby("id") \
               .agg(F.collect_list(F.struct("date", "value")) \
               .alias("list_col"))

# define udf that sorts the (date, value) structs by date and returns the values
def sorter(l):
  res = sorted(l, key=operator.itemgetter(0))
  return [item[1] for item in res]

# declare the return type so the result is an array column rather than a string
sort_udf = F.udf(sorter, ArrayType(IntegerType()))

# test
grouped_df.select("id", sort_udf("list_col") \
  .alias("sorted_list")) \
  .show(truncate = False)
+---+----------------+
|id |sorted_list     |
+---+----------------+
|1  |[10, 5, 15, 20] |
|2  |[100, 500, 1500]|
+---+----------------+

Answered by Artavazd Balayan

The question was for PySpark, but it might be helpful to have it for Scala Spark as well.

Let's prepare a test dataframe:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{ Window, UserDefinedFunction}

import java.sql.Date
import java.time.LocalDate

val spark: SparkSession = ...

// Our test data set
val data: Seq[(Int, Date, Int)] = Seq(
  (1, Date.valueOf(LocalDate.parse("2014-01-03")), 10),
  (1, Date.valueOf(LocalDate.parse("2014-01-04")), 5),
  (1, Date.valueOf(LocalDate.parse("2014-01-05")), 15),
  (1, Date.valueOf(LocalDate.parse("2014-01-06")), 20),
  (2, Date.valueOf(LocalDate.parse("2014-02-10")), 100),
  (2, Date.valueOf(LocalDate.parse("2014-02-11")), 500),
  (2, Date.valueOf(LocalDate.parse("2014-02-15")), 1500)
)

// Create dataframe
val df: DataFrame = spark.createDataFrame(data)
  .toDF("id", "date", "value")
df.show()
//+---+----------+-----+
//| id|      date|value|
//+---+----------+-----+
//|  1|2014-01-03|   10|
//|  1|2014-01-04|    5|
//|  1|2014-01-05|   15|
//|  1|2014-01-06|   20|
//|  2|2014-02-10|  100|
//|  2|2014-02-11|  500|
//|  2|2014-02-15| 1500|
//+---+----------+-----+

Use UDF

// Group by id and aggregate date and value to new column date_value
val grouped = df.groupBy(col("id"))
  .agg(collect_list(struct("date", "value")) as "date_value")
grouped.show()
grouped.printSchema()
// +---+--------------------+
// | id|          date_value|
// +---+--------------------+
// |  1|[[2014-01-03,10],...|
// |  2|[[2014-02-10,100]...|
// +---+--------------------+

// udf to extract data from Row, sort by needed column (date) and return value
val sortUdf: UserDefinedFunction = udf((rows: Seq[Row]) => {
  rows.map { case Row(date: Date, value: Int) => (date, value) }
    .sortBy { case (date, value) => date }
    .map { case (date, value) => value }
})

// Select id and value_list
val r1 = grouped.select(col("id"), sortUdf(col("date_value")).alias("value_list"))
r1.show()
// +---+----------------+
// | id|      value_list|
// +---+----------------+
// |  1| [10, 5, 15, 20]|
// |  2|[100, 500, 1500]|
// +---+----------------+

Use Window

val window = Window.partitionBy(col("id")).orderBy(col("date"))
val sortedDf = df.withColumn("values_sorted_by_date", collect_list("value").over(window))
sortedDf.show()
//+---+----------+-----+---------------------+
//| id|      date|value|values_sorted_by_date|
//+---+----------+-----+---------------------+
//|  1|2014-01-03|   10|                 [10]|
//|  1|2014-01-04|    5|              [10, 5]|
//|  1|2014-01-05|   15|          [10, 5, 15]|
//|  1|2014-01-06|   20|      [10, 5, 15, 20]|
//|  2|2014-02-10|  100|                [100]|
//|  2|2014-02-11|  500|           [100, 500]|
//|  2|2014-02-15| 1500|     [100, 500, 1500]|
//+---+----------+-----+---------------------+

val r2 = sortedDf.groupBy(col("id"))
  .agg(max("values_sorted_by_date").as("value_list")) 
r2.show()
//+---+----------------+
//| id|      value_list|
//+---+----------------+
//|  1| [10, 5, 15, 20]|
//|  2|[100, 500, 1500]|
//+---+----------------+

Answered by ShadyStego

To make sure the sort is done for each id, we can use sortWithinPartitions:

from pyspark.sql import functions as F
ordered_df = (
    input_df
        .repartition(input_df.id)
        .sortWithinPartitions(['date'])
)
grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))

Answered by nvarelas

I tried TMichel's approach and it didn't work for me. When I performed the max aggregation I wasn't getting back the highest value of the list. So what worked for me is the following:

from pyspark.sql import functions as f
from pyspark.sql import Window

def max_n_values(df, key, col_name, number):
    '''
    Returns the max n values of a spark dataframe
    partitioned by the key and ranked by the col_name
    '''
    w2 = Window.partitionBy(key).orderBy(f.col(col_name).desc())
    output = df.select('*',
                       f.row_number().over(w2).alias('rank')).filter(
                           f.col('rank') <= number).drop('rank')
    return output

def col_list(df, key, col_to_collect, name, score):
    w = Window.partitionBy(key).orderBy(f.col(score).desc())

    list_df = df.withColumn(name, f.collect_set(col_to_collect).over(w))
    size_df = list_df.withColumn('size', f.size(name))
    output = max_n_values(df=size_df,
                               key=key,
                               col_name='size',
                               number=1)
    return output
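
The helpers above are defined but never called in the answer; a hypothetical invocation against the question's input_df (column names taken from the question, parameter values assumed) might look like this:

# collect 'value' per 'id', ranked by 'date' through the window inside col_list
result = col_list(input_df, key='id', col_to_collect='value',
                  name='value_list', score='date')
result.show(truncate=False)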

Answered by jxc

As of Spark 2.4, the collect_list (ArrayType) created in @mtoto's answer can be post-processed by using Spark SQL's builtin functions transform and array_sort (no need for a udf):

from pyspark.sql.functions import collect_list, expr, struct

df.groupby('id') \
  .agg(collect_list(struct('date','value')).alias('value_list')) \
  .withColumn('value_list', expr('transform(array_sort(value_list), x -> x.value)')) \
  .show()
+---+----------------+
| id|      value_list|
+---+----------------+
|  1| [10, 5, 15, 20]|
|  2|[100, 500, 1500]|
+---+----------------+ 

Note: if descending order is required, change array_sort(value_list) to sort_array(value_list, False)

Caveat: array_sort() and sort_array() won't work if items (in collect_list) must be sorted by multiple fields (columns) in a mixed order, e.g. orderBy('col1', desc('col2')).
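
For completeness: on Spark 3.0+ array_sort also accepts a comparator lambda, so the mixed-order case can be handled without a udf. A minimal sketch only, assuming a DataFrame df with hypothetical columns col1, col2 and value, where the list should be ordered by col1 ascending and col2 descending:

from pyspark.sql.functions import collect_list, expr, struct

mixed_df = (
    df.groupby('id')
      .agg(collect_list(struct('col1', 'col2', 'value')).alias('value_list'))
      # sort the structs by col1 asc, col2 desc, then keep only the value field
      .withColumn('value_list', expr("""
          transform(
              array_sort(value_list, (a, b) ->
                  CASE WHEN a.col1 < b.col1 THEN -1
                       WHEN a.col1 > b.col1 THEN 1
                       WHEN a.col2 > b.col2 THEN -1
                       WHEN a.col2 < b.col2 THEN 1
                       ELSE 0 END),
              x -> x.value)
      """))
)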

Answered by KARTHICK JOTHIMANI

You can use the sort_array function. If you collect both dates and values as a list, you can sort the resulting column using sort_array and keep only the columns you require.

import pyspark.sql.functions as F
from pyspark.sql.functions import col

grouped_df = input_df.groupby("id") \
    .agg(F.sort_array(F.collect_list(F.struct("date", "value"))) \
    .alias("collected_list")) \
    .withColumn("sorted_list", col("collected_list.value")) \
    .drop("collected_list")

grouped_df.show(truncate=False)

+---+----------------+
|id |sorted_list     |
+---+----------------+
|1  |[10, 5, 15, 20] |
|2  |[100, 500, 1500]|
+---+----------------+

Answered by kubote

Complementing what ShadyStego said, I've been testing the usage of sortWithinPartitions and GroupBy on Spark, and found that it performs considerably better than Window functions or a UDF. Still, there is an issue with a misordering once per partition when using this method, but it can be easily solved. I show it here: Spark (pySpark) groupBy misordering first element on collect_list.

This method is especially useful on large DataFrames, but a large number of partitions may be needed if you are short on driver memory.
