Python Pyspark Dataframe group by with filtering

Disclaimer: this page is a Chinese-English translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. You are free to use/share it, but you must attribute it to the original authors (not me). Original source: http://stackoverflow.com/questions/42826502/


Pyspark Dataframe group by filtering

python, apache-spark, pyspark, apache-spark-sql

Asked by Lijju Mathew

I have a data frame as below:

cust_id   req    req_met
-------   ---    -------
 1         r1      1
 1         r2      0
 1         r2      1
 2         r1      1
 3         r1      1
 3         r2      1
 4         r1      0
 5         r1      1
 5         r2      0
 5         r1      1

I have to look at each customer, see how many distinct requirements they have, and check whether each requirement has been met at least once. There can be multiple records with the same customer and requirement, some met and some not. In the above case my output should be:

cust_id
-------
  1
  2
  3

What I have done is:

# say the initial dataframe is df
from pyspark.sql import functions as F

df1 = df.groupby('cust_id').agg(
    F.countDistinct('req').alias('num_of_req'),
    F.sum('req_met').alias('sum_req_met'))

df2 = df1.filter(df1.num_of_req == df1.sum_req_met)

But in a few cases it does not produce correct results.
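
To see why, here is a minimal check of the aggregation above against the sample data (df1 refers to the snippet as rewritten here, not to code from the original question):

df1.filter(df1.cust_id == 5).show()
# +-------+----------+-----------+
# |cust_id|num_of_req|sum_req_met|
# +-------+----------+-----------+
# |      5|         2|          2|
# +-------+----------+-----------+

Customer 5 has two distinct requirements, but the duplicate met rows for r1 push sum('req_met') up to 2 as well, so the filter wrongly keeps customer 5 even though r2 was never met.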

How can this be done?

Answered by titipata

First, I'll prepare a toy dataset from the data given above:

from pyspark.sql.functions import col, udf
from pyspark.sql.types import IntegerType
import pyspark.sql.functions as fn

# assumes an existing SparkSession named `spark`
df = spark.createDataFrame([[1, 'r1', 1],
                            [1, 'r2', 0],
                            [1, 'r2', 1],
                            [2, 'r1', 1],
                            [3, 'r1', 1],
                            [3, 'r2', 1],
                            [4, 'r1', 0],
                            [5, 'r1', 1],
                            [5, 'r2', 0],
                            [5, 'r1', 1]], schema=['cust_id', 'req', 'req_met'])
df = df.withColumn('req_met', col('req_met').cast(IntegerType()))
df = df.withColumn('cust_id', col('cust_id').cast(IntegerType()))

I do the same thing: group by cust_id and req, then sum req_met. After that, I create a function to floor those sums to just 0 or 1:

def floor_req(r):
    # 1 if the requirement was met at least once, else 0
    if r >= 1:
        return 1
    else:
        return 0

udf_floor_req = udf(floor_req, IntegerType())
gr = df.groupby(['cust_id', 'req'])
df_grouped = gr.agg(fn.sum(col('req_met')).alias('sum_req_met'))
df_grouped_floor = df_grouped.withColumn('sum_req_met', udf_floor_req('sum_req_met'))
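
As a side note, the UDF is not strictly required here: since req_met only takes the values 0 and 1, the maximum per group gives the same floored value. A minimal equivalent sketch:

# max of a 0/1 column is 1 exactly when the requirement was met
# at least once for that (cust_id, req) pair
df_grouped_floor = (df.groupby(['cust_id', 'req'])
                      .agg(fn.max('req_met').alias('sum_req_met')))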

Now, we can check whether each customer has met all requirements by comparing the number of distinct requirements against the number of requirements met at least once:

df_req = df_grouped_floor.groupby('cust_id').agg(fn.sum('sum_req_met').alias('sum_req'), 
                                                 fn.count('req').alias('n_req'))
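
For the sample data, df_req should look something like this (row order may vary):

+-------+-------+-----+
|cust_id|sum_req|n_req|
+-------+-------+-----+
|      1|      2|    2|
|      2|      1|    1|
|      3|      2|    2|
|      4|      0|    1|
|      5|      1|    2|
+-------+-------+-----+

Customers 4 and 5 each have at least one requirement that was never met, so their two counts differ.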

Finally, you just have to check whether the two columns are equal:

df_req.filter(df_req['sum_req'] == df_req['n_req'])[['cust_id']].orderBy('cust_id').show()
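
For the sample data, this should print:

+-------+
|cust_id|
+-------+
|      1|
|      2|
|      3|
+-------+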

Answered by Ashish Singh

select cust_id from
    (select cust_id, min(sum_value) as m from
        (select cust_id, req, sum(req_met) as sum_value
         from <data_frame> group by cust_id, req) temp
     group by cust_id) temp1
where m > 0;

This will give the desired result.
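
To run this from PySpark, the DataFrame has to be exposed as a temporary view first. A minimal sketch, assuming the df built in the first answer (the view name requirements is arbitrary):

df.createOrReplaceTempView('requirements')
spark.sql("""
    select cust_id from
        (select cust_id, min(sum_value) as m from
            (select cust_id, req, sum(req_met) as sum_value
             from requirements group by cust_id, req) temp
         group by cust_id) temp1
    where m > 0
""").orderBy('cust_id').show()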