Python: How to filter based on array value in PySpark?

Note: This page is a translation of a popular StackOverflow question, provided under the CC BY-SA 4.0 license. If you reuse it, you must keep the same license and attribute the original authors (not the translator): http://stackoverflow.com/questions/36014082/

How to filter based on array value in PySpark?

python, apache-spark, dataframe, pyspark, apache-spark-sql

Asked by Suhas Chandramouli

My Schema:

 |-- Canonical_URL: string (nullable = true)
 |-- Certifications: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- Certification_Authority: string (nullable = true)
 |    |    |-- End: string (nullable = true)
 |    |    |-- License: string (nullable = true)
 |    |    |-- Start: string (nullable = true)
 |    |    |-- Title: string (nullable = true)
 |-- CompanyId: string (nullable = true)
 |-- Country: string (nullable = true)
 |-- vendorTags: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- score: double (nullable = true)
 |    |    |-- vendor: string (nullable = true)

I tried the query below to select nested fields from vendorTags:

df3 = sqlContext.sql("select vendorTags.vendor from globalcontacts")

How can I query the nested fields in a where clause like below in PySpark?

df3 = sqlContext.sql("select vendorTags.vendor from globalcontacts where vendorTags.vendor = 'alpha'")

or

df3 = sqlContext.sql("select vendorTags.vendor from globalcontacts where vendorTags.score > 123.123456")

Something like this.

I tried the above queries only to get the error below:

df3 = sqlContext.sql("select vendorTags.vendor from globalcontacts where vendorTags.vendor = 'alpha'")
16/03/15 13:16:02 INFO ParseDriver: Parsing command: select vendorTags.vendor from globalcontacts where vendorTags.vendor = 'alpha'
16/03/15 13:16:03 INFO ParseDriver: Parse Completed
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/lib/spark/python/pyspark/sql/context.py", line 583, in sql
    return DataFrame(self._ssql_ctx.sql(sqlQuery), self)
  File "/usr/lib/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py", line 813, in __call__
  File "/usr/lib/spark/python/pyspark/sql/utils.py", line 51, in deco
    raise AnalysisException(s.split(': ', 1)[1], stackTrace)
pyspark.sql.utils.AnalysisException: u"cannot resolve '(vendorTags.vendor = cast(alpha as double))' due to data type mismatch: differing types in '(vendorTags.vendor = cast(alpha as double))' (array<string> and double).; line 1 pos 71"
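
The root cause is visible in the error message: selecting a field through an array of structs yields an array, so vendorTags.vendor resolves to array<string> and cannot be compared to a scalar. An illustrative check (a sketch, not part of the original question):

df3 = sqlContext.sql("select vendorTags.vendor from globalcontacts")
df3.printSchema()
# root
#  |-- vendor: array (nullable = true)
#  |    |-- element: string (containsNull = true)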

Answered by zero323

For equality-based queries you can use array_contains:

df = sc.parallelize([(1, [1, 2, 3]), (2, [4, 5, 6])]).toDF(["k", "v"])
df.createOrReplaceTempView("df")

# With SQL
sqlContext.sql("SELECT * FROM df WHERE array_contains(v, 1)")

# With DSL
from pyspark.sql.functions import array_contains
df.where(array_contains("v", 1))
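
Applied to the schema in the question, vendorTags.vendor extracts an array<string>, so array_contains works on it directly (a sketch, assuming the globalcontacts table from the question):

sqlContext.sql("""
    SELECT vendorTags.vendor
    FROM globalcontacts
    WHERE array_contains(vendorTags.vendor, 'alpha')
""")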

If you want to use more complex predicates you'll have to either explode or use a UDF, for example something like this:

from pyspark.sql.types import BooleanType
from pyspark.sql.functions import udf

def exists(f):
    # Build a UDF that returns True if any array element satisfies the predicate f
    return udf(lambda xs: any(f(x) for x in xs), BooleanType())

df.where(exists(lambda x: x > 3)("v"))
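
The explode alternative mentioned above might look like this (a sketch, assuming the same df; it flattens the array to one row per element, filters, then deduplicates the keys):

from pyspark.sql.functions import col, explode

# One row per array element, keep rows matching the predicate,
# then recover the distinct keys that had at least one match
df.withColumn("x", explode("v")).where(col("x") > 3).select("k").distinct()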

In Spark 2.4 or later it is also possible to use higher-order functions:

from pyspark.sql.functions import expr

# Map each element to a boolean, then OR-reduce the results
df.where(expr("""aggregate(
    transform(v, x -> x > 3),
    false,
    (x, y) -> x or y
)"""))

or

# exists(v, pred) is true if any element of v satisfies pred
df.where(expr("""
    exists(v, x -> x > 3)
"""))

Python wrappers should be available in 3.1 (SPARK-30681).
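
Once those wrappers are available, the same filter can be written directly in Python (a sketch, assuming Spark 3.1+ and the df from above; pyspark.sql.functions.exists takes a lambda over Columns):

from pyspark.sql import functions as F

# The lambda receives a Column, so ordinary Column expressions work
df.where(F.exists("v", lambda x: x > 3))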

Answered by Hyman

In Spark 2.4 you can filter array values using the filter function in the SQL API.

https://spark.apache.org/docs/2.4.0/api/sql/index.html#filter

Here's an example in PySpark. In this example we filter out all array values that are empty strings:

from pyspark.sql.functions import expr

df = df.withColumn("ArrayColumn", expr("filter(ArrayColumn, x -> x != '')"))
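
For a self-contained illustration (a sketch; the SparkSession setup and sample data here are assumptions, not part of the original answer):

from pyspark.sql import SparkSession
from pyspark.sql.functions import expr

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(["a", "", "b"],), (["", ""],)], ["ArrayColumn"])
# Keep only the non-empty strings in each array
df.withColumn("ArrayColumn", expr("filter(ArrayColumn, x -> x != '')")).show()
# +-----------+
# |ArrayColumn|
# +-----------+
# |     [a, b]|
# |         []|
# +-----------+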