Scala: How to define and use a User-Defined Aggregate Function in Spark SQL?

Disclaimer: this page is based on a popular StackOverflow question and is provided under the CC BY-SA 4.0 license. If you use or share it, you must do so under the same license and attribute it to the original authors (not me): StackOverflow. Original question: http://stackoverflow.com/questions/32100973/

Date: 2020-10-22 07:29:27  Source: igfitidea

How to define and use a User-Defined Aggregate Function in Spark SQL?

scala, apache-spark, apache-spark-sql, aggregate-functions, user-defined-functions

Asked by Rory Byrne

I know how to write a UDF in Spark SQL:

def belowThreshold(power: Int): Boolean = {
  power < -40
}

sqlContext.udf.register("belowThreshold", belowThreshold _)

Can I do something similar to define an aggregate function? How is this done?

For context, I want to run the following SQL query:

val aggDF = sqlContext.sql("""SELECT span, belowThreshold(opticalReceivePower), timestamp
                                    FROM ifDF
                                    WHERE opticalReceivePower IS NOT null
                                    GROUP BY span, timestamp
                                    ORDER BY span""")

It should return something like

Row(span1, false, T0)

Row(span1, false, T0)

I want the aggregate function to tell me if there are any values for opticalReceivePower in the groups defined by span and timestamp which are below the threshold. Do I need to write my UDAF differently from the UDF I pasted above?

Answered by zero323

Supported methods

Spark >= 3.0

Scala UserDefinedAggregateFunction is being deprecated (SPARK-30423 Deprecate UserDefinedAggregateFunction) in favor of registered Aggregator.
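
For illustration, an Aggregator can be registered for SQL use through functions.udaf in Spark 3.0+. The sketch below is not part of the original answer; the object name, the spark session value, the hard-coded -40 threshold, and the assumption that ifDF is available as a temporary view are all illustrative:

import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator

// Minimal sketch (assumed Spark 3.x); the -40 threshold is hard-coded.
object BelowThresholdAgg extends Aggregator[Int, Boolean, Boolean] {
  def zero: Boolean = false
  def reduce(acc: Boolean, power: Int): Boolean = acc || power < -40
  def merge(b1: Boolean, b2: Boolean): Boolean = b1 || b2
  def finish(acc: Boolean): Boolean = acc
  def bufferEncoder: Encoder[Boolean] = Encoders.scalaBoolean
  def outputEncoder: Encoder[Boolean] = Encoders.scalaBoolean
}

val spark: SparkSession = SparkSession.builder.getOrCreate()

// Register the Aggregator as an untyped UDAF callable from SQL and the DataFrame API.
spark.udf.register("belowThreshold", functions.udaf(BelowThresholdAgg))
spark.sql("SELECT span, belowThreshold(opticalReceivePower) FROM ifDF WHERE opticalReceivePower IS NOT NULL GROUP BY span")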

Spark >= 2.3

Vectorized udf (Python only):

from pyspark.sql.functions import pandas_udf
from pyspark.sql.functions import PandasUDFType

from pyspark.sql.types import *
import pandas as pd

df = sc.parallelize([
    ("a", 0), ("a", 1), ("b", 30), ("b", -50)
]).toDF(["group", "power"])

def below_threshold(threshold, group="group", power="power"):
    @pandas_udf("struct<group: string, below_threshold: boolean>", PandasUDFType.GROUPED_MAP)
    def below_threshold_(df):
        df = pd.DataFrame(
           df.groupby(group).apply(lambda x: (x[power] < threshold).any()))
        df.reset_index(inplace=True, drop=False)
        return df

    return below_threshold_

Example usage:

df.groupBy("group").apply(below_threshold(-40)).show()

## +-----+---------------+
## |group|below_threshold|
## +-----+---------------+
## |    b|           true|
## |    a|          false|
## +-----+---------------+

See also Applying UDFs on GroupedData in PySpark (with a functioning Python example)

Spark >= 2.0 (optionally 1.6, but with a slightly different API):

It is possible to use Aggregators on typed Datasets:

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}

class BelowThreshold[I](f: I => Boolean) extends Aggregator[I, Boolean, Boolean]
    with Serializable {
  // Initial ("zero") value of the aggregation buffer
  def zero = false
  // Fold a single input record into the buffer
  def reduce(acc: Boolean, x: I) = acc | f(x)
  // Combine buffers computed on different partitions
  def merge(acc1: Boolean, acc2: Boolean) = acc1 | acc2
  // Transform the final buffer into the output value
  def finish(acc: Boolean) = acc

  def bufferEncoder: Encoder[Boolean] = Encoders.scalaBoolean
  def outputEncoder: Encoder[Boolean] = Encoders.scalaBoolean
}

val belowThreshold = new BelowThreshold[(String, Int)](_._2 < -40).toColumn
df.as[(String, Int)].groupByKey(_._1).agg(belowThreshold)
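
For completeness, a minimal self-contained run of the Aggregator above could look like the following; the toy data mirrors the sample DataFrame used further down in this answer, and an existing SparkSession named spark is assumed:

import spark.implicits._

val df = Seq(("a", 0), ("a", 1), ("b", 30), ("b", -50)).toDF("group", "power")

// Expected result: Array((a,false), (b,true)) (group order may vary)
df.as[(String, Int)].groupByKey(_._1).agg(belowThreshold).collect()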

Spark >= 1.5:

In Spark 1.5 you can create a UDAF like this, although it is most likely overkill:

import org.apache.spark.sql.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row

object belowThreshold extends UserDefinedAggregateFunction {
    // Schema you get as an input
    def inputSchema = new StructType().add("power", IntegerType)
    // Schema of the row which is used for aggregation
    def bufferSchema = new StructType().add("ind", BooleanType)
    // Returned type
    def dataType = BooleanType
    // Self-explaining 
    def deterministic = true
    // zero value
    def initialize(buffer: MutableAggregationBuffer) = buffer.update(0, false)
    // Similar to seqOp in aggregate
    def update(buffer: MutableAggregationBuffer, input: Row) = {
        if (!input.isNullAt(0))
          buffer.update(0, buffer.getBoolean(0) | input.getInt(0) < -40)
    }
    // Similar to combOp in aggregate
    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
      buffer1.update(0, buffer1.getBoolean(0) | buffer2.getBoolean(0))    
    }
    // Called on exit to get return value
    def evaluate(buffer: Row) = buffer.getBoolean(0)
}

Example usage:

df
  .groupBy($"group")
  .agg(belowThreshold($"power").alias("belowThreshold"))
  .show

// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// |    a|         false|
// |    b|          true|
// +-----+--------------+
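
To tie this back to the SQL query in the question: a UserDefinedAggregateFunction can also be registered under a name and called from Spark SQL. A rough sketch, assuming ifDF has been registered as a temporary table with the columns from the question:

sqlContext.udf.register("belowThreshold", belowThreshold)

val aggDF = sqlContext.sql("""SELECT span, belowThreshold(opticalReceivePower), timestamp
                              FROM ifDF
                              WHERE opticalReceivePower IS NOT NULL
                              GROUP BY span, timestamp
                              ORDER BY span""")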

Spark 1.4 workaround:

I am not sure if I correctly understand your requirements but as far as I can tell plain old aggregation should be enough here:

import org.apache.spark.sql.functions.sum
import org.apache.spark.sql.types.IntegerType

val df = sc.parallelize(Seq(
    ("a", 0), ("a", 1), ("b", 30), ("b", -50))).toDF("group", "power")

df
  .withColumn("belowThreshold", ($"power".lt(-40)).cast(IntegerType))
  .groupBy($"group")
  .agg(sum($"belowThreshold").notEqual(0).alias("belowThreshold"))
  .show

// +-----+--------------+
// |group|belowThreshold|
// +-----+--------------+
// |    a|         false|
// |    b|          true|
// +-----+--------------+

Spark <= 1.4:

As far as I know, at this moment (Spark 1.4.1), there is no support for UDAFs other than the Hive ones. It should be possible with Spark 1.5 (see SPARK-3947).
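
For reference, the Hive route goes through a HiveContext and plain HiveQL; the JAR path and class name below are placeholders, not a real library:

// Spark <= 1.4 sketch: register and call an existing Hive UDAF (placeholder JAR/class names).
hiveContext.sql("ADD JAR /path/to/my-udafs.jar")
hiveContext.sql("CREATE TEMPORARY FUNCTION below_threshold AS 'com.example.hive.BelowThresholdUDAF'")
hiveContext.sql("SELECT span, below_threshold(opticalReceivePower) FROM ifDF GROUP BY span")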

Unsupported / internal methods

Internally Spark uses a number of classes including ImperativeAggregates and DeclarativeAggregates.

These are intended for internal usage and may change without further notice, so it is probably not something you want to use in your production code, but just for completeness, BelowThreshold with DeclarativeAggregate could be implemented like this (tested with Spark 2.2-SNAPSHOT):

import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

case class BelowThreshold(child: Expression, threshold: Expression) 
    extends  DeclarativeAggregate  {
  override def children: Seq[Expression] = Seq(child, threshold)

  override def nullable: Boolean = false
  override def dataType: DataType = BooleanType

  private lazy val belowThreshold = AttributeReference(
    "belowThreshold", BooleanType, nullable = false
  )()

  // Used to derive schema
  override lazy val aggBufferAttributes = belowThreshold :: Nil

  override lazy val initialValues = Seq(
    Literal(false)
  )

  override lazy val updateExpressions = Seq(Or(
    belowThreshold,
    If(IsNull(child), Literal(false), LessThan(child, threshold))
  ))

  override lazy val mergeExpressions = Seq(
    Or(belowThreshold.left, belowThreshold.right)
  )

  override lazy val evaluateExpression = belowThreshold
  override def defaultResult: Option[Literal] = Option(Literal(false))
} 

It should be further wrapped with an equivalent of withAggregateFunction.
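
A rough sketch of such a wrapper, mirroring Spark's private withAggregateFunction helper, could look like this; it relies on internal Catalyst APIs and may break between versions:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.lit

// Wrap the internal aggregate function into a regular Column.
def belowThreshold(col: Column, threshold: Column = lit(-40)): Column =
  new Column(BelowThreshold(col.expr, threshold.expr).toAggregateExpression())

// Hypothetical usage: df.groupBy($"group").agg(belowThreshold($"power"))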
