Calculate the standard deviation of grouped data in a Spark DataFrame

Note: this page is a translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. If you reuse or share it, you must likewise follow CC BY-SA and attribute it to the original authors (not me). Original question: http://stackoverflow.com/questions/31789939/

Tags: scala, apache-spark, apache-spark-sql

Asked by the3rdNotch

I have user logs that I took from a CSV file and converted into a DataFrame in order to leverage the Spark SQL querying features. A single user will create numerous entries per hour, and I would like to gather some basic statistics for each user; really just the count of the user instances and the average and standard deviation of several columns. I was able to quickly get the mean and count information by using groupBy($"user") and the aggregator with Spark SQL functions for count and avg:


import org.apache.spark.sql.functions.{avg, count}

val meanData = selectedData.groupBy($"user").agg(count($"logOn"),
  avg($"transaction"), avg($"submit"), avg($"submitsPerHour"), avg($"replies"),
  avg($"repliesPerHour"), avg($"duration"))

However, I cannot seem to find an equally elegant way to calculate the standard deviation. So far I can only calculate it by mapping to a (String, Double) pair RDD and using the StatCounter().stdev utility:


val stdevduration = duration.groupByKey().mapValues(value =>
org.apache.spark.util.StatCounter(value).stdev)

However, this returns an RDD, and I would like to keep everything in a DataFrame so that further queries are possible on the returned data.

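For reference, a minimal sketch of converting that result back into a DataFrame, assuming duration is an RDD[(String, Double)] keyed by user and a Spark 1.x sqlContext is in scope:

import sqlContext.implicits._  // provides toDF on RDDs of tuples (Spark 1.x)

val stdevDF = duration
    .groupByKey()
    .mapValues(values => org.apache.spark.util.StatCounter(values).stdev)
    .toDF("user", "duration_sd")  // column names are assumed for illustration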

Answered by zero323

Spark 1.6+


You can use stddev_pop to compute the population standard deviation and stddev / stddev_samp to compute the unbiased sample standard deviation:


import org.apache.spark.sql.functions.{stddev_samp, stddev_pop}

selectedData.groupBy($"user").agg(stddev_pop($"duration"))
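Note that stddev is an alias for stddev_samp, so the two return identical results; aliasing the aggregate, e.g. stddev_pop($"duration").alias("duration_sd"), keeps the output column name readable.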

Spark 1.5 and below (the original answer):


Not so pretty, and biased (the same value returned by describe), but using the formula:


σ = sqrt(E[x²] − (E[x])²)    (the population standard deviation; formula image from Wikipedia)
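As a quick numeric check: for durations 2, 4, 6 we get E[x²] = (4 + 16 + 36)/3 = 56/3 and (E[x])² = 4² = 16, so σ = sqrt(56/3 − 16) = sqrt(8/3) ≈ 1.633, the population standard deviation of those three values.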

You can do something like this:


import org.apache.spark.sql.functions.{avg, sqrt}

selectedData
    .groupBy($"user")
    .agg(sqrt(
        avg($"duration" * $"duration") -
        avg($"duration") * avg($"duration")
    ).alias("duration_sd"))

You can of course create a function to reduce the clutter:


import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{avg, sqrt}

def mySd(col: Column): Column = {
    sqrt(avg(col * col) - avg(col) * avg(col))
}

df.groupBy($"user").agg(mySd($"duration").alias("duration_sd"))

It is also possible to use a Hive UDF:


df.registerTempTable("df")
sqlContext.sql("""SELECT user, stddev(duration)
                  FROM df
                  GROUP BY user""")
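A hedged variant (assuming the same registered "df" table and a HiveContext where the Hive stddev_pop / stddev_samp UDAFs are available) that names the estimator explicitly, since Hive's plain stddev is typically an alias for stddev_pop:

sqlContext.sql("""SELECT user,
                         stddev_pop(duration)  AS duration_sd_pop,
                         stddev_samp(duration) AS duration_sd_samp
                  FROM df
                  GROUP BY user""")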


Source of the image: https://en.wikipedia.org/wiki/Standard_deviation


Answered by JFRANCK

The snippet in the accepted answer originally contained a typo (stdev_pop instead of stddev_pop, as pointed out by MRez) and did not compile. The snippet below works and is tested.


For Spark 2.0+:


import org.apache.spark.sql.functions._

// Note: the alias must be applied to the aggregate, not to the input column,
// or the result columns get names like "avg(duration AS avg)".
val _avg_std = df.groupBy("user").agg(
    avg(col("duration")).alias("avg"),
    stddev(col("duration")).alias("stdev"),
    stddev_pop(col("duration")).alias("stdev_pop"),
    stddev_samp(col("duration")).alias("stdev_samp")
)
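As a quick sanity check, a self-contained sketch (the toy data and session setup below are assumptions, not from the original answers) showing how the sample and population estimators differ:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder()
    .master("local[*]")
    .appName("grouped-stddev-demo")
    .getOrCreate()
import spark.implicits._

// Assumed toy data: three durations for alice, two for bob.
val demo = Seq(
    ("alice", 2.0), ("alice", 4.0), ("alice", 6.0),
    ("bob",   1.0), ("bob",   3.0)
).toDF("user", "duration")

demo.groupBy("user")
    .agg(
        stddev_samp(col("duration")).alias("stdev_samp"), // divides by n - 1
        stddev_pop(col("duration")).alias("stdev_pop")    // divides by n
    )
    .show()
// Expected: alice -> stdev_samp = 2.0, stdev_pop ≈ 1.633
//           bob   -> stdev_samp ≈ 1.414, stdev_pop = 1.0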