Apply a function to all cells in a Spark DataFrame
Note: this page reproduces a popular StackOverflow question and its answers under the CC BY-SA 4.0 license. If you reuse or share it, you must do so under the same license and attribute the original authors (not me): StackOverflow
Original question: http://stackoverflow.com/questions/54489344/
Asked by Steven
I'm trying to convert some Pandas code to Spark for scaling. myfunc is a wrapper around a complex API that takes a string and returns a new string (meaning I can't use vectorized functions).
def myfunc(ds):
    for attribute, value in ds.items():
        value = api_function(attribute, value)
        ds[attribute] = value
    return ds

df = df.apply(myfunc, axis='columns')
myfunc takes a DataSeries, breaks it up into individual cells, calls the API for each cell, and builds a new DataSeries with the same column names. This effectively modifies all cells in the DataFrame.
I'm new to Spark and I want to translate this logic using pyspark. I've converted my pandas DataFrame to Spark:
spark = SparkSession.builder.appName('My app').getOrCreate()
spark_schema = StructType([StructField(c, StringType(), True) for c in df.columns])
spark_df = spark.createDataFrame(df, schema=spark_schema)
This is where I get lost. Do I need a UDF or a pandas_udf? How do I iterate across all cells and return a new string for each using myfunc? spark_df.foreach() doesn't return anything and it doesn't have a map() function.
I can modify myfunc from DataSeries -> DataSeries to string -> string if necessary.
Answered by Jason
Option 1: Use a UDF on One Column at a Time
The simplest approach would be to rewrite your function to take a string as an argument (so that it is string -> string) and use a UDF. There's a nice example here. This works on one column at a time. So, if your DataFrame has a reasonable number of columns, you can apply the UDF to each column one at a time:
from pyspark.sql.functions import col
new_df = df.select(udf(col("col1")), udf(col("col2")), ...)
Example
df = sc.parallelize([[1, 4], [2,5], [3,6]]).toDF(["col1", "col2"])
df.show()
+----+----+
|col1|col2|
+----+----+
| 1| 4|
| 2| 5|
| 3| 6|
+----+----+
def plus1_udf(x):
    return x + 1

plus1 = spark.udf.register("plus1", plus1_udf)
new_df = df.select(plus1(col("col1")), plus1(col("col2")))
new_df.show()
+-----------+-----------+
|plus1(col1)|plus1(col2)|
+-----------+-----------+
| 2| 5|
| 3| 6|
| 4| 7|
+-----------+-----------+
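If the DataFrame has many columns, the same idea can be written as a list comprehension over df.columns. A hedged sketch building on the plus1 UDF registered above, using alias() to keep the original column names:

from pyspark.sql.functions import col

# apply the registered UDF to every column, preserving the original column names
new_df = df.select([plus1(col(c)).alias(c) for c in df.columns])
new_df.show()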
Option 2: Map the entire DataFrame at once
map is available for Scala DataFrames, but, at the moment, not in PySpark. The lower-level RDD API does have a map function in PySpark. So, if you have too many columns to transform one at a time, you could operate on every single cell in the DataFrame like this:
def map_fn(row):
    return [api_function(column, x) for (column, x) in row.asDict().items()]

column_names = df.columns
new_df = df.rdd.map(map_fn).toDF(column_names)
Example
df = sc.parallelize([[1, 4], [2,5], [3,6]]).toDF(["col1", "col2"])
def map_fn(row):
    return [value + 1 for (_, value) in row.asDict().items()]
columns = df.columns
new_df = df.rdd.map(map_fn).toDF(columns)
new_df.show()
+----+----+
|col1|col2|
+----+----+
| 2| 5|
| 3| 6|
| 4| 7|
+----+----+
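For the original api_function, which takes the column name and the cell value, a hedged variant of the same idea (an assumption, not part of the original answer) could iterate over the captured column_names so the output order is guaranteed to match the schema:

column_names = df.columns

def map_fn(row):
    # loop over column_names (rather than row.asDict()) so the output order matches the schema
    return [api_function(c, row[c]) for c in column_names]

new_df = df.rdd.map(map_fn).toDF(column_names)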
Context
The documentation of foreach only gives the example of printing, but we can verify by looking at the code that it indeed does not return anything.
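As a quick check (a sketch, not from the original answer), foreach applies its function purely for side effects and returns None:

result = df.foreach(lambda row: None)
print(result)  # prints: None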
You can read about pandas_udf in this post, but it seems that it is most suited to vectorized functions, which, as you pointed out, you can't use because of api_function.
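For contrast, here is a minimal sketch of a scalar pandas_udf, assuming Spark 3.x and a transformation that can be expressed on a whole pandas Series at once (which api_function, being a per-string API call, cannot):

import pandas as pd
from pyspark.sql.functions import pandas_udf, col
from pyspark.sql.types import StringType

@pandas_udf(StringType())
def upper_udf(s: pd.Series) -> pd.Series:
    # vectorized: transforms a whole pandas Series per batch, not one cell at a time
    return s.astype(str).str.upper()

new_df = df.select(upper_udf(col("col1")))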
Answered by Steven
The solution is:
udf_func = udf(func, StringType())
for col_name in spark_df.columns:
    spark_df = spark_df.withColumn(col_name, udf_func(lit(col_name), col_name))
return spark_df.toPandas()
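Filled out with the imports and a hypothetical cell-level wrapper func around api_function (these names are assumptions; the original snippet does not show them), the same approach looks like this:

from pyspark.sql.functions import udf, lit
from pyspark.sql.types import StringType

def func(attribute, value):
    # both arguments arrive as strings; returns the transformed string
    return api_function(attribute, value)

udf_func = udf(func, StringType())

for col_name in spark_df.columns:
    # lit(col_name) passes the column *name* as a constant; the bare col_name
    # resolves to that column's *values*
    spark_df = spark_df.withColumn(col_name, udf_func(lit(col_name), col_name))

result = spark_df.toPandas()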
There are 3 key insights that helped me figure this out:
- If you use withColumn with the name of an existing column (col_name), Spark "overwrites"/shadows the original column. This essentially gives the appearance of editing the column directly as if it were mutable.
- By creating a loop across the original columns and reusing the same DataFrame variable spark_df, I use the same principle to simulate a mutable DataFrame, creating a chain of column-wise transformations, each time "overwriting" a column (per #1 - see below).
- Spark UDFs expect all parameters to be Column types, which means Spark attempts to resolve column values for each parameter. Because api_function's first parameter is a literal value that will be the same for all rows in the vector, you must use the lit() function. Simply passing col_name to the UDF will attempt to extract the column values for that column. As far as I could tell, passing col_name is equivalent to passing col(col_name).
Assuming 3 columns 'a', 'b' and 'c', unrolling this concept would look like this:
spark_df = spark_df.withColumn('a', udf_func(lit('a'), 'a')) \
                   .withColumn('b', udf_func(lit('b'), 'b')) \
                   .withColumn('c', udf_func(lit('c'), 'c'))

