Python Pyspark Dataframe group by with filtering
Disclaimer: this page is a translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. If you use or share it, you must do so under the same license and attribute it to the original authors (not me): StackOverflow
Original URL: http://stackoverflow.com/questions/42826502/
Pyspark Dataframe group by filtering
Asked by Lijju Mathew
I have a data frame as below:
cust_id  req  req_met
-------  ---  -------
1        r1   1
1        r2   0
1        r2   1
2        r1   1
3        r1   1
3        r2   1
4        r1   0
5        r1   1
5        r2   0
5        r1   1
I have to look at each customer, see how many requirements they have, and check whether each requirement has been met at least once. There can be multiple records for the same customer and requirement, some met and some not. In the above case my output should be:
cust_id
-------
1
2
3
What I have done is:
# say initial dataframe is df
import pyspark.sql.functions as fn

df1 = df.groupby('cust_id')\
        .agg(fn.countDistinct('req').alias('num_of_req'),
             fn.sum('req_met').alias('sum_req_met'))
df2 = df1.filter(df1.num_of_req == df1.sum_req_met)
But in a few cases it is not getting correct results; the sketch after this question shows one such case.
How can this be done?
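To see where this goes wrong, here is a minimal sketch (assuming a running SparkSession named spark, which the question does not show) that builds the sample data and runs the aggregation above:

import pyspark.sql.functions as fn

# Hypothetical reconstruction of the sample data from the table above.
df = spark.createDataFrame(
    [(1, 'r1', 1), (1, 'r2', 0), (1, 'r2', 1), (2, 'r1', 1), (3, 'r1', 1),
     (3, 'r2', 1), (4, 'r1', 0), (5, 'r1', 1), (5, 'r2', 0), (5, 'r1', 1)],
    ['cust_id', 'req', 'req_met'])
df1 = df.groupby('cust_id').agg(fn.countDistinct('req').alias('num_of_req'),
                                fn.sum('req_met').alias('sum_req_met'))
# cust_id 5 ends up with num_of_req = 2 and sum_req_met = 2 because its met
# r1 row appears twice, so the equality filter wrongly keeps it.
df1.filter(df1.num_of_req == df1.sum_req_met).show()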
Answered by titipata
First, I'll just prepare a toy dataset from the one given above:
from pyspark.sql.functions import col
from pyspark.sql.types import IntegerType
import pyspark.sql.functions as fn

df = spark.createDataFrame([[1, 'r1', 1],
                            [1, 'r2', 0],
                            [1, 'r2', 1],
                            [2, 'r1', 1],
                            [3, 'r1', 1],
                            [3, 'r2', 1],
                            [4, 'r1', 0],
                            [5, 'r1', 1],
                            [5, 'r2', 0],
                            [5, 'r1', 1]], schema=['cust_id', 'req', 'req_met'])
df = df.withColumn('req_met', col("req_met").cast(IntegerType()))
df = df.withColumn('cust_id', col("cust_id").cast(IntegerType()))
I do the same thing by grouping by cust_id and req, then summing req_met. After that, I create a function to floor those sums to just 0 or 1:
from pyspark.sql.functions import udf

def floor_req(r):
    if r >= 1:
        return 1
    else:
        return 0

udf_floor_req = udf(floor_req, IntegerType())
gr = df.groupby(['cust_id', 'req'])
df_grouped = gr.agg(fn.sum(col('req_met')).alias('sum_req_met'))
df_grouped_floor = df_grouped.withColumn('sum_req_met', udf_floor_req('sum_req_met'))
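A minimal alternative sketch, assuming the same df and fn as above: because req_met is already a 0/1 flag, taking the per-group maximum floors the result to 0 or 1 without a Python UDF.

# Assumption: df and fn are defined as in the snippet above. The max of a
# 0/1 flag per (cust_id, req) is 1 iff the requirement was met at least once.
df_grouped_floor = df.groupby(['cust_id', 'req'])\
                     .agg(fn.max('req_met').alias('sum_req_met'))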
Now, we can check whether each customer has met all requirements by counting the distinct number of requirements and the total number of requirements met.
df_req = df_grouped_floor.groupby('cust_id').agg(fn.sum('sum_req_met').alias('sum_req'),
                                                 fn.count('req').alias('n_req'))
Finally, you just have to check whether the two columns are equal:
df_req.filter(df_req['sum_req'] == df_req['n_req'])[['cust_id']].orderBy('cust_id').show()
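On the toy data above, this should print the three customers who met every one of their requirements at least once:

+-------+
|cust_id|
+-------+
|      1|
|      2|
|      3|
+-------+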
Answered by Ashish Singh
select cust_id from
  (select cust_id, MIN(sum_value) as m from
    (select cust_id, req, sum(req_met) as sum_value
     from <data_frame>
     group by cust_id, req) temp
   group by cust_id) temp1
where m > 0;
This will give the desired result: the innermost query sums req_met per (cust_id, req), and keeping only customers whose minimum per-requirement sum is greater than zero means every requirement was met at least once.
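A usage sketch, assuming the toy df and SparkSession from the first answer (the view name df_view is an assumption):

# Register the DataFrame as a temp view so the SQL above can run against it.
df.createOrReplaceTempView("df_view")
spark.sql("""
    select cust_id from
      (select cust_id, MIN(sum_value) as m from
        (select cust_id, req, sum(req_met) as sum_value
         from df_view group by cust_id, req) temp
       group by cust_id) temp1
    where m > 0
""").orderBy('cust_id').show()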