PySpark and broadcast join example
Note: this page is a translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. You are free to use and share it, but you must follow the same CC BY-SA license, cite the original URL and author information, and attribute it to the original authors (not me): StackOverflow
Original URL: http://stackoverflow.com/questions/34053302/
Asked by user3803714
I am using Spark 1.3.
# Read from a text file, parse it, and do some basic filtering to get data1
data1.registerTempTable('data1')
# Read from a text file, parse it, and do some basic filtering to get data2
data2.registerTempTable('data2')
# Perform the join
data_joined = data1.join(data2, data1.id == data2.id)
My data is quite skewed: data2 (a few KB) is far smaller than data1 (tens of GB), and the join performance is quite bad. I have been reading about broadcast joins, but I am not sure how to do the same with the Python API.
Answered by zero323
Spark 1.3 doesn't support broadcast joins using DataFrame. In Spark >= 1.5.0 you can use the broadcast function to apply broadcast joins:
from pyspark.sql.functions import broadcast
data1.join(broadcast(data2), data1.id == data2.id)
For older versions, the only option is to convert to an RDD and apply the same logic as in other languages. Roughly something like this:
from pyspark.sql import Row
from pyspark.sql.types import StructType

# Create a dictionary where keys are join keys
# and values are lists of rows
# (sc is assumed to be the active SparkContext)
data2_bd = sc.broadcast(
    data2.rdd.map(lambda r: (r.id, r)).groupByKey().collectAsMap())

# Define a new row with fields from both DFs
output_row = Row(*data1.columns + data2.columns)

# And an output schema
output_schema = StructType(data1.schema.fields + data2.schema.fields)

# Given row x, extract a list of corresponding rows from broadcast
# and output a list of merged rows
def gen_rows(x):
    return [output_row(*x + y) for y in data2_bd.value.get(x.id, [])]

# flatMap and create a new data frame
joined = data1.rdd.flatMap(lambda row: gen_rows(row)).toDF(output_schema)
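To make the logic of the RDD version easier to follow, here is a minimal plain-Python sketch of the same idea, with no Spark involved. The function name broadcast_hash_join and the sample rows are made up for illustration; the dictionary built from the small side plays the role of the broadcast variable, and the loop over the large side plays the role of flatMap.

```python
from collections import defaultdict


def broadcast_hash_join(large_rows, small_rows, key):
    """Inner-join two lists of dicts on `key` using an in-memory lookup table."""
    # Build the lookup table from the small side: join key -> list of rows.
    # In the Spark version this is the dictionary that gets broadcast.
    lookup = defaultdict(list)
    for row in small_rows:
        lookup[row[key]].append(row)

    joined = []
    # Stream over the large side once; each row only needs a dictionary lookup.
    for left in large_rows:
        for right in lookup.get(left[key], []):
            # Rows from the large side with no match produce nothing.
            joined.append((left, right))
    return joined


# Hypothetical sample data standing in for data1 (large) and data2 (small)
data1 = [
    {"id": 1, "v": "a"},
    {"id": 1, "v": "b"},
    {"id": 2, "v": "c"},
    {"id": 3, "v": "d"},
]
data2 = [{"id": 1, "name": "x"}, {"id": 2, "name": "y"}]

result = broadcast_hash_join(data1, data2, "id")
```

Because every row of the large side is matched locally against the small lookup table, the large side never has to be repartitioned by key, which is why this approach helps when the key distribution is skewed.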
Answered by y durga prasad
This code works with the spark-2.0.2-bin-hadoop2.7 build:
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast

spark = SparkSession.builder.appName("Python Spark SQL basic example").config("spark.some.config.option", "some-value").getOrCreate()

# Raw strings keep the backslash in the Windows paths from being read as a tab escape
df2 = spark.read.csv(r"D:\trans_mar.txt", sep="^")
df1 = spark.read.csv(r"D:\trans_feb.txt", sep="^")

# Broadcast the second DataFrame and join on column _c77
print(df1.join(broadcast(df2), df2._c77 == df1._c77).take(10))
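One detail worth watching with Windows paths such as D:\trans_mar.txt: in an ordinary Python string literal, a backslash followed by t is read as a tab character, so the path silently changes. A small sketch of the difference (the variable names are only for illustration):

```python
plain = "D:\trans_mar.txt"    # "\t" here is interpreted as a tab character
raw = r"D:\trans_mar.txt"     # raw string: the backslash is kept literally
forward = "D:/trans_mar.txt"  # forward slashes sidestep escaping entirely

has_tab = "\t" in plain
print(repr(plain))
print(repr(raw))
```

Using a raw string or forward slashes is the usual way to avoid this; whether a given reader accepts forward slashes on Windows is worth checking in your own environment.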