scala 如何旋转 Spark DataFrame?
声明:本页面是StackOverFlow热门问题的中英对照翻译,遵循CC BY-SA 4.0协议,如果您需要使用它,必须同样遵循CC BY-SA许可,注明原文地址和作者信息,同时你必须将它归于原作者(不是我):StackOverFlow
原文地址: http://stackoverflow.com/questions/30244910/
Warning: these are provided under cc-by-sa 4.0 license. You are free to use/share it, But you must attribute it to the original authors (not me):
StackOverFlow
How to pivot Spark DataFrame?
提问by J Calbreath
I am starting to use Spark DataFrames and I need to be able to pivot the data to create multiple columns out of 1 column with multiple rows. There is built in functionality for that in Scalding and I believe in Pandas in Python, but I can't find anything for the new Spark Dataframe.
我开始使用 Spark DataFrames,我需要能够旋转数据以从具有多行的 1 列中创建多列。在 Scalding 中有内置的功能,我相信 Python 中的 Pandas,但我找不到新的 Spark Dataframe 的任何内容。
I assume I can write custom function of some sort that will do this but I'm not even sure how to start, especially since I am a novice with Spark. I anyone knows how to do this with built in functionality or suggestions for how to write something in Scala, it is greatly appreciated.
我假设我可以编写某种自定义函数来执行此操作,但我什至不确定如何开始,尤其是因为我是 Spark 的新手。我任何人都知道如何使用内置功能或有关如何在 Scala 中编写内容的建议来做到这一点,非常感谢。
回答by zero323
As mentionedby David AndersonSpark provides pivotfunction since version 1.6. General syntax looks as follows:
正如大卫·安德森所提到的,Sparkpivot从 1.6 版开始提供功能。一般语法如下所示:
df
.groupBy(grouping_columns)
.pivot(pivot_column, [values])
.agg(aggregate_expressions)
Usage examples using nycflights13and csvformat:
用法nycflights13和csv格式的用法和样例:
Python:
蟒蛇:
from pyspark.sql.functions import avg
flights = (sqlContext
.read
.format("csv")
.options(inferSchema="true", header="true")
.load("flights.csv")
.na.drop())
flights.registerTempTable("flights")
sqlContext.cacheTable("flights")
gexprs = ("origin", "dest", "carrier")
aggexpr = avg("arr_delay")
flights.count()
## 336776
%timeit -n10 flights.groupBy(*gexprs ).pivot("hour").agg(aggexpr).count()
## 10 loops, best of 3: 1.03 s per loop
Scala:
斯卡拉:
val flights = sqlContext
.read
.format("csv")
.options(Map("inferSchema" -> "true", "header" -> "true"))
.load("flights.csv")
flights
.groupBy($"origin", $"dest", $"carrier")
.pivot("hour")
.agg(avg($"arr_delay"))
Java:
爪哇:
import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.*;
Dataset<Row> df = spark.read().format("csv")
.option("inferSchema", "true")
.option("header", "true")
.load("flights.csv");
df.groupBy(col("origin"), col("dest"), col("carrier"))
.pivot("hour")
.agg(avg(col("arr_delay")));
R / SparkR:
R/SparkR:
library(magrittr)
flights <- read.df("flights.csv", source="csv", header=TRUE, inferSchema=TRUE)
flights %>%
groupBy("origin", "dest", "carrier") %>%
pivot("hour") %>%
agg(avg(column("arr_delay")))
R / sparklyr
R / 闪闪发光
library(dplyr)
flights <- spark_read_csv(sc, "flights", "flights.csv")
avg.arr.delay <- function(gdf) {
expr <- invoke_static(
sc,
"org.apache.spark.sql.functions",
"avg",
"arr_delay"
)
gdf %>% invoke("agg", expr, list())
}
flights %>%
sdf_pivot(origin + dest + carrier ~ hour, fun.aggregate=avg.arr.delay)
SQL:
查询语句:
Note that PIVOT keyword in Spark SQL is supported starting from version 2.4.
请注意,从版本 2.4 开始支持 Spark SQL 中的 PIVOT 关键字。
CREATE TEMPORARY VIEW flights
USING csv
OPTIONS (header 'true', path 'flights.csv', inferSchema 'true') ;
SELECT * FROM (
SELECT origin, dest, carrier, arr_delay, hour FROM flights
) PIVOT (
avg(arr_delay)
FOR hour IN (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23)
);
Example data:
示例数据:
"year","month","day","dep_time","sched_dep_time","dep_delay","arr_time","sched_arr_time","arr_delay","carrier","flight","tailnum","origin","dest","air_time","distance","hour","minute","time_hour"
2013,1,1,517,515,2,830,819,11,"UA",1545,"N14228","EWR","IAH",227,1400,5,15,2013-01-01 05:00:00
2013,1,1,533,529,4,850,830,20,"UA",1714,"N24211","LGA","IAH",227,1416,5,29,2013-01-01 05:00:00
2013,1,1,542,540,2,923,850,33,"AA",1141,"N619AA","JFK","MIA",160,1089,5,40,2013-01-01 05:00:00
2013,1,1,544,545,-1,1004,1022,-18,"B6",725,"N804JB","JFK","BQN",183,1576,5,45,2013-01-01 05:00:00
2013,1,1,554,600,-6,812,837,-25,"DL",461,"N668DN","LGA","ATL",116,762,6,0,2013-01-01 06:00:00
2013,1,1,554,558,-4,740,728,12,"UA",1696,"N39463","EWR","ORD",150,719,5,58,2013-01-01 05:00:00
2013,1,1,555,600,-5,913,854,19,"B6",507,"N516JB","EWR","FLL",158,1065,6,0,2013-01-01 06:00:00
2013,1,1,557,600,-3,709,723,-14,"EV",5708,"N829AS","LGA","IAD",53,229,6,0,2013-01-01 06:00:00
2013,1,1,557,600,-3,838,846,-8,"B6",79,"N593JB","JFK","MCO",140,944,6,0,2013-01-01 06:00:00
2013,1,1,558,600,-2,753,745,8,"AA",301,"N3ALAA","LGA","ORD",138,733,6,0,2013-01-01 06:00:00
Performance considerations:
性能考虑:
Generally speaking pivoting is an expensive operation.
一般来说,旋转是一项昂贵的操作。
if you can, try to provide
valueslist, as this avoids an extra hit to compute the uniques:vs = list(range(25)) %timeit -n10 flights.groupBy(*gexprs ).pivot("hour", vs).agg(aggexpr).count() ## 10 loops, best of 3: 392 ms per loopin some cases it proved to be beneficial(likely no longer worth the effort in 2.0 or later) to
repartitionand / or pre-aggregate the datafor reshaping only, you can use
first: Pivot String column on Pyspark Dataframe
如果可以,请尝试提供
values列表,因为这样可以避免计算唯一性的额外命中:vs = list(range(25)) %timeit -n10 flights.groupBy(*gexprs ).pivot("hour", vs).agg(aggexpr).count() ## 10 loops, best of 3: 392 ms per loop在某些情况下,它被证明是有益的(在2.0 或更高版本中可能不再值得付出努力)
repartition和/或预聚合数据仅用于重塑,您可以使用
first:Pyspark Dataframe 上的 Pivot String 列
Related questions:
相关问题:
回答by J Calbreath
I overcame this by writing a for loop to dynamically create a SQL query. Say I have:
我通过编写一个 for 循环来动态创建 SQL 查询来克服这个问题。说我有:
id tag value
1 US 50
1 UK 100
1 Can 125
2 US 75
2 UK 150
2 Can 175
and I want:
而且我要:
id US UK Can
1 50 100 125
2 75 150 175
I can create a list with the value I want to pivot and then create a string containing the SQL query I need.
我可以创建一个包含我想要旋转的值的列表,然后创建一个包含我需要的 SQL 查询的字符串。
val countries = List("US", "UK", "Can")
val numCountries = countries.length - 1
var query = "select *, "
for (i <- 0 to numCountries-1) {
query += """case when tag = """" + countries(i) + """" then value else 0 end as """ + countries(i) + ", "
}
query += """case when tag = """" + countries.last + """" then value else 0 end as """ + countries.last + " from myTable"
myDataFrame.registerTempTable("myTable")
val myDF1 = sqlContext.sql(query)
I can create similar query to then do the aggregation. Not a very elegant solution but it works and is flexible for any list of values, which can also be passed in as an argument when your code is called.
我可以创建类似的查询然后进行聚合。不是一个非常优雅的解决方案,但它适用于任何值列表并且很灵活,也可以在调用代码时作为参数传入。
回答by David Anderson
A pivot operator has been added to the Spark dataframe API, and is part of Spark 1.6.
枢轴运算符已添加到 Spark 数据帧 API,并且是 Spark 1.6 的一部分。
See https://github.com/apache/spark/pull/7841for details.
有关详细信息,请参阅https://github.com/apache/spark/pull/7841。
回答by Al M
I have solved a similar problem using dataframes with the following steps:
我已经通过以下步骤使用数据帧解决了类似的问题:
Create columns for all your countries, with 'value' as the value:
为您的所有国家/地区创建列,以“值”作为值:
import org.apache.spark.sql.functions._
val countries = List("US", "UK", "Can")
val countryValue = udf{(countryToCheck: String, countryInRow: String, value: Long) =>
if(countryToCheck == countryInRow) value else 0
}
val countryFuncs = countries.map{country => (dataFrame: DataFrame) => dataFrame.withColumn(country, countryValue(lit(country), df("tag"), df("value"))) }
val dfWithCountries = Function.chain(countryFuncs)(df).drop("tag").drop("value")
Your dataframe 'dfWithCountries' will look like this:
您的数据框 'dfWithCountries' 将如下所示:
+--+--+---+---+
|id|US| UK|Can|
+--+--+---+---+
| 1|50| 0| 0|
| 1| 0|100| 0|
| 1| 0| 0|125|
| 2|75| 0| 0|
| 2| 0|150| 0|
| 2| 0| 0|175|
+--+--+---+---+
Now you can sum together all the values for your desired result:
现在,您可以将所需结果的所有值相加:
dfWithCountries.groupBy("id").sum(countries: _*).show
Result:
结果:
+--+-------+-------+--------+
|id|SUM(US)|SUM(UK)|SUM(Can)|
+--+-------+-------+--------+
| 1| 50| 100| 125|
| 2| 75| 150| 175|
+--+-------+-------+--------+
It's not a very elegant solution though. I had to create a chain of functions to add in all the columns. Also if I have lots of countries, I will expand my temporary data set to a very wide set with lots of zeroes.
不过,这不是一个非常优雅的解决方案。我必须创建一个函数链来添加到所有列中。此外,如果我有很多国家/地区,我会将我的临时数据集扩展到包含许多零的非常广泛的集。
回答by Mantas
There is simple and elegant solution.
有一个简单而优雅的解决方案。
scala> spark.sql("select * from k_tags limit 10").show()
+---------------+-------------+------+
| imsi| name| value|
+---------------+-------------+------+
|246021000000000| age| 37|
|246021000000000| gender|Female|
|246021000000000| arpu| 22|
|246021000000000| DeviceType| Phone|
|246021000000000|DataAllowance| 6GB|
+---------------+-------------+------+
scala> spark.sql("select * from k_tags limit 10").groupBy($"imsi").pivot("name").agg(min($"value")).show()
+---------------+-------------+----------+---+----+------+
| imsi|DataAllowance|DeviceType|age|arpu|gender|
+---------------+-------------+----------+---+----+------+
|246021000000000| 6GB| Phone| 37| 22|Female|
|246021000000001| 1GB| Phone| 72| 10| Male|
+---------------+-------------+----------+---+----+------+
回答by abasar
There are plenty of examples of pivot operation on dataset/dataframe, but I could not find many using SQL. Here is an example that worked for me.
在数据集/数据帧上有很多数据透视操作的例子,但我找不到很多使用 SQL 的例子。这是一个对我有用的例子。
create or replace temporary view faang
as SELECT stock.date AS `Date`,
stock.adj_close AS `Price`,
stock.symbol as `Symbol`
FROM stock
WHERE (stock.symbol rlike '^(FB|AAPL|GOOG|AMZN)$') and year(date) > 2010;
SELECT * from faang
PIVOT (max(price) for symbol in ('AAPL', 'FB', 'GOOG', 'AMZN')) order by date;
回答by Jaigates
Initially i adopted Al M's solution. Later took the same thought and rewrote this function as a transpose function.
最初我采用了 Al M 的解决方案。后来采取了同样的想法,将这个函数改写为转置函数。
This method transposes any df rows to columns of any data-format with using key and value column
此方法使用键和值列将任何 df 行转换为任何数据格式的列
for input csv
用于输入 csv
id,tag,value
1,US,50a
1,UK,100
1,Can,125
2,US,75
2,UK,150
2,Can,175
ouput
输出
+--+---+---+---+
|id| UK| US|Can|
+--+---+---+---+
| 2|150| 75|175|
| 1|100|50a|125|
+--+---+---+---+
transpose method :
转置方法:
def transpose(hc : HiveContext , df: DataFrame,compositeId: List[String], key: String, value: String) = {
val distinctCols = df.select(key).distinct.map { r => r(0) }.collect().toList
val rdd = df.map { row =>
(compositeId.collect { case id => row.getAs(id).asInstanceOf[Any] },
scala.collection.mutable.Map(row.getAs(key).asInstanceOf[Any] -> row.getAs(value).asInstanceOf[Any]))
}
val pairRdd = rdd.reduceByKey(_ ++ _)
val rowRdd = pairRdd.map(r => dynamicRow(r, distinctCols))
hc.createDataFrame(rowRdd, getSchema(df.schema, compositeId, (key, distinctCols)))
}
private def dynamicRow(r: (List[Any], scala.collection.mutable.Map[Any, Any]), colNames: List[Any]) = {
val cols = colNames.collect { case col => r._2.getOrElse(col.toString(), null) }
val array = r._1 ++ cols
Row(array: _*)
}
private def getSchema(srcSchema: StructType, idCols: List[String], distinctCols: (String, List[Any])): StructType = {
val idSchema = idCols.map { idCol => srcSchema.apply(idCol) }
val colSchema = srcSchema.apply(distinctCols._1)
val colsSchema = distinctCols._2.map { col => StructField(col.asInstanceOf[String], colSchema.dataType, colSchema.nullable) }
StructType(idSchema ++ colsSchema)
}
main snippet
主要片段
import java.util.Date
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types.StructField
...
...
def main(args: Array[String]): Unit = {
val sc = new SparkContext(conf)
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
val dfdata1 = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").option("inferSchema", "true")
.load("data.csv")
dfdata1.show()
val dfOutput = transpose(new HiveContext(sc), dfdata1, List("id"), "tag", "value")
dfOutput.show
}
回答by Kumar
Spark has been providing improvements to Pivoting the Spark DataFrame. A pivot function has been added to the Spark DataFrame API to Spark 1.6 version and it has a performance issue and that has been corrected in Spark 2.0
Spark 一直在对 Pivoting Spark DataFrame 提供改进。Spark 1.6 版本的 Spark DataFrame API 中添加了一个枢轴函数,它有一个性能问题,已在 Spark 2.0 中得到纠正
however, if you are using lower version; note that pivot is a very expensive operation hence, it is recommended to provide column data (if known) as an argument to function as shown below.
但是,如果您使用的是较低版本;请注意,pivot 是一项非常昂贵的操作,因此,建议提供列数据(如果已知)作为函数的参数,如下所示。
val countries = Seq("USA","China","Canada","Mexico")
val pivotDF = df.groupBy("Product").pivot("Country", countries).sum("Amount")
pivotDF.show()
This has been explained detailed at Pivoting and Unpivoting Spark DataFrame
这已在Pivoting and Unpivoting Spark DataFrame 中详细解释
Happy Learning !!
快乐学习!!
回答by parisni
The built-in spark pivot function is inefficient. The bellow implementation works on spark 2.4+ - the idea is to aggregate a map and extract the values as columns. The only limitation is it does not handle aggregate function in the pivoted columns, only column(s).
内置的火花枢轴功能效率低下。波纹管实现适用于 spark 2.4+ - 想法是聚合地图并将值提取为列。唯一的限制是它不处理旋转列中的聚合函数,只处理列。
On a 8M table, those functions applies on 3 secondes, versus 40 minutesin the built-in spark version:
在 8M 表上,这些函数适用于3 seconds,而在内置 spark 版本中为40 分钟:
# pass an optional list of string to avoid computation of columns
def pivot(df, group_by, key, aggFunction, levels=[]):
if not levels:
levels = [row[key] for row in df.filter(col(key).isNotNull()).groupBy(col(key)).agg(count(key)).select(key).collect()]
return df.filter(col(key).isin(*levels) == True).groupBy(group_by).agg(map_from_entries(collect_list(struct(key, expr(aggFunction)))).alias("group_map")).select([group_by] + ["group_map." + l for l in levels])
# Usage
pivot(df, "id", "key", "value")
pivot(df, "id", "key", "array(value)")
// pass an optional list of string to avoid computation of columns
def pivot(df: DataFrame, groupBy: Column, key: Column, aggFunct: String, _levels: List[String] = Nil): DataFrame = {
val levels =
if (_levels.isEmpty) df.filter(key.isNotNull).select(key).distinct().collect().map(row => row.getString(0)).toList
else _levels
df
.filter(key.isInCollection(levels))
.groupBy(groupBy)
.agg(map_from_entries(collect_list(struct(key, expr(aggFunct)))).alias("group_map"))
.select(groupBy.toString, levels.map(f => "group_map." + f): _*)
}
// Usage:
pivot(df, col("id"), col("key"), "value")
pivot(df, col("id"), col("key"), "array(value)")

