SQL: How to get other columns when using Spark DataFrame groupby?

Disclaimer: this page is a Chinese-English translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. You are free to use and share it, but you must attribute it to the original authors (not me). Original question: http://stackoverflow.com/questions/34409875/


How to get other columns when using Spark DataFrame groupby?

sql, apache-spark, dataframe, apache-spark-sql

Asked by Psychevic

When I use DataFrame groupBy like this:

df.groupBy(df("age")).agg(Map("id"->"count"))

I only get a DataFrame with the columns "age" and "count(id)", but df has many other columns, such as "name".

In short, I want to get the same result as I would in MySQL:

"select name,age,count(id) from df group by age"


What should I do when using groupBy in Spark?

Answered by zero323

Long story short, in general you have to join the aggregated results with the original table. Spark SQL follows the same pre-SQL:1999 convention as most of the major databases (PostgreSQL, Oracle, MS SQL Server), which doesn't allow additional columns in aggregation queries.
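For illustration (not part of the original answer), here is a minimal Scala sketch of that join approach, assuming df has the age, id, and name columns from the question; the id_count alias and val names are arbitrary:

import org.apache.spark.sql.functions.{col, count}

// Aggregate per age, then join the counts back onto the original DataFrame,
// so every original column (for example "name") is available again.
val counts = df.groupBy("age").agg(count(col("id")).as("id_count"))
val withCounts = df.join(counts, Seq("age"))
withCounts.show()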

Since results of aggregations like count are not well defined, and behavior tends to vary in systems that support this type of query, you can just include additional columns using an arbitrary aggregate such as first or last.
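As a rough sketch (again not from the original answer), pulling name through with first could look like this; which name is returned per age group is not guaranteed:

import org.apache.spark.sql.functions.{count, first}

// Keep an arbitrary "name" per age group alongside the count.
val result = df.groupBy(df("age")).agg(count(df("id")).as("id_count"), first(df("name")).as("name"))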

In some cases you can replace agg with a select using window functions and a subsequent where, but depending on the context it can be quite expensive.
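A hedged sketch of the window-function variant, assuming the same df; a where step would only be needed if you also want to filter on the aggregated value (for example, keeping only the rows with the maximum count):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.count

// Compute count(id) per age as an extra column while keeping every original column.
val byAge = Window.partitionBy("age")
val result = df.withColumn("id_count", count(df("id")).over(byAge))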

Answered by Swetha Kannan

One way to get all columns after doing a groupBy is to use the join function.

feature_group = ['name', 'age']
data_counts = df.groupBy(feature_group).count().alias("counts")
data_joined = df.join(data_counts, feature_group)

data_joined will now have all columns including the count values.


Answered by Thirupathi Chavati

Maybe this solution will be helpful.

from pyspark.sql import SQLContext
from pyspark import SparkContext, SparkConf
from pyspark.sql import functions as F
from pyspark.sql import Window

# A local context so the example runs standalone (the app name is arbitrary).
sc = SparkContext(conf=SparkConf().setMaster("local").setAppName("groupby-keep-columns"))
sqlContext = SQLContext(sc)

name_list = [(101, 'abc', 24), (102, 'cde', 24), (103, 'efg', 22), (104, 'ghi', 21),
             (105, 'ijk', 20), (106, 'klm', 19), (107, 'mno', 18), (108, 'pqr', 18),
             (109, 'rst', 26), (110, 'tuv', 27), (111, 'pqr', 18), (112, 'rst', 28), (113, 'tuv', 29)]

# Count ids per age with a window function, so all original columns are kept.
age_w = Window.partitionBy("age")
name_age_df = sqlContext.createDataFrame(name_list, ['id', 'name', 'age'])

name_age_count_df = name_age_df.withColumn("count", F.count("id").over(age_w)).orderBy("count")
name_age_count_df.show()

Output:


+---+----+---+-----+
| id|name|age|count|
+---+----+---+-----+
|109| rst| 26|    1|
|113| tuv| 29|    1|
|110| tuv| 27|    1|
|106| klm| 19|    1|
|103| efg| 22|    1|
|104| ghi| 21|    1|
|105| ijk| 20|    1|
|112| rst| 28|    1|
|101| abc| 24|    2|
|102| cde| 24|    2|
|107| mno| 18|    3|
|111| pqr| 18|    3|
|108| pqr| 18|    3|
+---+----+---+-----+

Answered by Rubber Duck

Aggregate functions reduce the values of rows for the specified columns within the group. If you wish to retain other row values you need to implement reduction logic that specifies which row each value comes from, for instance keeping all values of the first row with the maximum amount of money within each age group. To this end you can use a UDAF (user defined aggregate function) to reduce rows within the group.

import org.apache.spark.sql._
import org.apache.spark.sql.functions._


object AggregateKeepingRowJob {

  def main (args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext
    sc.setLogLevel("ERROR")

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      (1L, "Moe",  "Slap",  2.0, 18),
      (2L, "Larry",  "Spank",  3.0, 15),
      (3L, "Curly",  "Twist", 5.0, 15),
      (4L, "Laurel", "Whimper", 3.0, 15),
      (5L, "Hardy", "Laugh", 6.0, 15),
      (6L, "Charley",  "Ignore",   5.0, 5)
    ).toDF("id", "name", "requisite", "money", "age")

    rawDf.show(false)
    rawDf.printSchema

    val maxAmtUdaf = new KeepRowWithMaxAmt

    val aggDf = rawDf
      .groupBy("age")
      .agg(
        count("id"),
        max(col("money")),
        maxAmtUdaf(
          col("id"),
          col("name"),
          col("requisite"),
          col("money"),
          col("age")).as("KeepRowWithMaxAmt")
      )

    aggDf.printSchema
    aggDf.show(false)

  }


}

The UDAF:


import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class KeepRowWithMaxAmt extends UserDefinedAggregateFunction {
  // These are the input fields for the aggregate function: the full row to keep.
  override def inputSchema: org.apache.spark.sql.types.StructType =
    StructType(
      StructField("id", LongType) ::
      StructField("name", StringType) ::
      StructField("requisite", StringType) ::
      StructField("money", DoubleType) ::
      StructField("age", IntegerType) :: Nil
    )

  // These are the internal fields kept while computing the aggregate: the best row seen so far.
  override def bufferSchema: StructType = StructType(
    StructField("id", LongType) ::
    StructField("name", StringType) ::
    StructField("requisite", StringType) ::
    StructField("money", DoubleType) ::
    StructField("age", IntegerType) :: Nil
  )

  // This is the output type of the aggregation function.
  override def dataType: DataType =
    StructType(Array(
      StructField("id", LongType),
      StructField("name", StringType),
      StructField("requisite", StringType),
      StructField("money", DoubleType),
      StructField("age", IntegerType)
    ))

  override def deterministic: Boolean = true

  // This is the initial value for the buffer schema.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = ""
    buffer(2) = ""
    buffer(3) = 0.0
    buffer(4) = 0
  }

  // This is how to update the buffer given an input row: keep the row with the larger money value.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val money = buffer.getAs[Double](3)
    val candidateMoney = input.getAs[Double](3)

    if (money < candidateMoney) {
      buffer(0) = input.getAs[Long](0)
      buffer(1) = input.getAs[String](1)
      buffer(2) = input.getAs[String](2)
      buffer(3) = input.getAs[Double](3)
      buffer(4) = input.getAs[Int](4)
    }
  }

  // This is how to merge two buffers: keep whichever holds the larger money value.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    if (buffer1.getAs[Double](3) < buffer2.getAs[Double](3)) {
      buffer1(0) = buffer2.getAs[Long](0)
      buffer1(1) = buffer2.getAs[String](1)
      buffer1(2) = buffer2.getAs[String](2)
      buffer1(3) = buffer2.getAs[Double](3)
      buffer1(4) = buffer2.getAs[Int](4)
    }
  }

  // This is where the final value is output, given the final value of the buffer.
  override def evaluate(buffer: Row): Any = {
    buffer
  }
}

Answered by Mohamed Hosni

Here is an example that I came across in spark-workshop.

val populationDF = spark.read
                .option("inferSchema", "true")
                .option("header", "true")
                .format("csv").load("file:///databricks/driver/population.csv")
                .select('name, regexp_replace(col("population"), "\\s", "").cast("integer").as("population"))

val maxPopulationDF = populationDF.agg(max('population).as("populationmax"))


To get other columns, I do a simple join between the original DF and the aggregated one


populationDF.join(maxPopulationDF,populationDF.col("population") === maxPopulationDF.col("populationmax")).select('name, 'populationmax).show()

Answered by Rubber Duck

You need to remember that aggregate functions reduce the rows, and therefore you need to specify which row's name you want with a reducing function. If you want to retain all rows of a group (warning! this can cause explosions or skewed partitions) you can collect them as a list. You can then use a UDF (user defined function) to reduce them by your criteria, in my example money, and then expand the columns from the single reduced row with another UDF. For the purpose of this answer I assume you wish to retain the name of the person who has the most money.

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StringType

import scala.collection.mutable


object TestJob3 {

  def main(args: Array[String]): Unit = {

    val sparkSession = SparkSession
      .builder()
      .appName(this.getClass.getName.replace("$", ""))
      .master("local")
      .getOrCreate()

    val sc = sparkSession.sparkContext

    import sparkSession.sqlContext.implicits._

    val rawDf = Seq(
      (1, "Moe", "Slap", 2.0, 18),
      (2, "Larry", "Spank", 3.0, 15),
      (3, "Curly", "Twist", 5.0, 15),
      (4, "Laurel", "Whimper", 3.0, 9),
      (5, "Hardy", "Laugh", 6.0, 18),
      (6, "Charley", "Ignore", 5.0, 5)
    ).toDF("id", "name", "requisite", "money", "age")

    rawDf.show(false)
    rawDf.printSchema

    val rawSchema = rawDf.schema

    val fUdf = udf(reduceByMoney, rawSchema)

    val nameUdf = udf(extractName, StringType)

    // Collect all rows of each age group into a list, reduce that list to a single
    // row with the UDF, then pull the name back out as a regular column.
    val aggDf = rawDf
      .groupBy("age")
      .agg(
        count(struct("*")).as("count"),
        max(col("money")),
        collect_list(struct("*")).as("horizontal")
      )
      .withColumn("short", fUdf($"horizontal"))
      .withColumn("name", nameUdf($"short"))
      .drop("horizontal")

    aggDf.printSchema

    aggDf.show(false)
  }

  // Reduce the collected rows of a group to the single row with the largest "money" value.
  def reduceByMoney = (x: Any) => {

    val d = x.asInstanceOf[mutable.WrappedArray[GenericRowWithSchema]]

    val red = d.reduce((r1, r2) => {

      val money1 = r1.getAs[Double]("money")
      val money2 = r2.getAs[Double]("money")

      val r3 = money1 match {
        case a if a >= money2 =>
          r1
        case _ =>
          r2
      }

      r3
    })

    red
  }

  // Pull the "name" field out of the reduced row.
  def extractName = (x: Any) => {

    val d = x.asInstanceOf[GenericRowWithSchema]

    d.getAs[String]("name")
  }
}

Here is the output:

+---+-----+----------+----------------------------+-------+
|age|count|max(money)|short                       |name   |
+---+-----+----------+----------------------------+-------+
|5  |1    |5.0       |[6, Charley, Ignore, 5.0, 5]|Charley|
|15 |2    |5.0       |[3, Curly, Twist, 5.0, 15]  |Curly  |
|9  |1    |3.0       |[4, Laurel, Whimper, 3.0, 9]|Laurel |
|18 |2    |6.0       |[5, Hardy, Laugh, 6.0, 18]  |Hardy  |
+---+-----+----------+----------------------------+-------+

Answered by Sree Eedupuganti

You can do it like this:

Sample data:


name    age id
abc     24  1001
cde     24  1002
efg     22  1003
ghi     21  1004
ijk     20  1005
klm     19  1006
mno     18  1007
pqr     18  1008
rst     26  1009
tuv     27  1010
pqr     18  1012
rst     28  1013
tuv     29  1011
df.select("name","age","id").groupBy("name","age").count().show();

Output:


    +----+---+-----+
    |name|age|count|
    +----+---+-----+
    | efg| 22|    1|
    | tuv| 29|    1|
    | rst| 28|    1|
    | klm| 19|    1|
    | pqr| 18|    2|
    | cde| 24|    1|
    | tuv| 27|    1|
    | ijk| 20|    1|
    | abc| 24|    1|
    | mno| 18|    1|
    | ghi| 21|    1|
    | rst| 26|    1|
    +----+---+-----+