Original URL: http://stackoverflow.com/questions/34833653/
Warning: this content is provided under the CC BY-SA 4.0 license. You are free to use and share it, but you must attribute it to the original authors (not me):
Stack Overflow
filter spark dataframe with row field that is an array of strings
Question by navicore
Using Spark 1.5 and Scala 2.10.6
I'm trying to filter a DataFrame on a field "tags" that is an array of strings. I'm looking for all rows that have the tag 'private'.
val report = df.select("*")
.where(df("tags").contains("private"))
I'm getting:
Exception in thread "main" org.apache.spark.sql.AnalysisException: cannot resolve 'Contains(tags, private)' due to data type mismatch: argument 1 requires string type, however, 'tags' is of array type.;
Is the filter method better suited?
UPDATED:
The data comes from the Cassandra adapter, but here is a minimal example that shows what I'm trying to do and reproduces the error above:
import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, SQLContext}

def testData(sc: SparkContext): DataFrame = {
  val stringRDD = sc.parallelize(Seq(
    """{ "name": "ed", "tags": ["red", "private"] }""",
    """{ "name": "fred", "tags": ["public", "blue"] }"""))
  val sqlContext = new SQLContext(sc)
  // read.json infers the schema: name is a string, tags is array<string>
  sqlContext.read.json(stringRDD)
}

def run(sc: SparkContext): Unit = {
  val df1 = testData(sc)
  df1.show()
  // Fails with the AnalysisException above: Column.contains expects a string column
  val report = df1.select("*")
    .where(df1("tags").contains("private"))
  report.show()
}
UPDATED: The tags array can be of any length, and the 'private' tag can be in any position.
UPDATED: One solution that works is a UDF:
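import scala.collection.mutable
import org.apache.spark.sql.functions.udf

// true when the array holds an element exactly equal to "private"
val filterPriv = udf { (tags: mutable.WrappedArray[String]) => tags.contains("private") }
val report = df1.filter(filterPriv(df1("tags")))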
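The body of the filterPriv UDF is an ordinary Scala function, so its logic can be checked without Spark. A minimal sketch, assuming some rows may have a missing or null "tags" field: a Scala UDF then receives null, and tags.contains would throw a NullPointerException. Wrapping the argument in Option makes the predicate null-safe. The object and method names are hypothetical; the parameter is typed as Seq[String], which mutable.WrappedArray[String] conforms to in Scala 2.10 to 2.12.

```scala
object PrivateTagPredicate {
  // Same test as the filterPriv UDF body, but a null array (missing or null
  // "tags" field) yields false instead of throwing a NullPointerException.
  def hasPrivateTag(tags: Seq[String]): Boolean =
    Option(tags).exists(_.contains("private"))

  def main(args: Array[String]): Unit = {
    println(hasPrivateTag(Seq("red", "private"))) // true
    println(hasPrivateTag(Seq("public", "blue"))) // false
    println(hasPrivateTag(null))                  // false
  }
}
```

The same function can be passed to udf in place of the original lambda, so the filter call itself stays unchanged.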
Answer by Robert Dodier
I think it will work if you use where(array_contains(...)). Here's my result:
scala> import org.apache.spark.SparkContext
import org.apache.spark.SparkContext
scala> import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.DataFrame
scala> def testData (sc: SparkContext): DataFrame = {
| val stringRDD = sc.parallelize(Seq
| ("""{ "name": "ned", "tags": ["blue", "big", "private"] }""",
| """{ "name": "albert", "tags": ["private", "lumpy"] }""",
| """{ "name": "zed", "tags": ["big", "private", "square"] }""",
| """{ "name": "jed", "tags": ["green", "small", "round"] }""",
| """{ "name": "ed", "tags": ["red", "private"] }""",
| """{ "name": "fred", "tags": ["public", "blue"] }"""))
| val sqlContext = new org.apache.spark.sql.SQLContext(sc)
| import sqlContext.implicits._
| sqlContext.read.json(stringRDD)
| }
testData: (sc: org.apache.spark.SparkContext)org.apache.spark.sql.DataFrame
scala>
| val df = testData (sc)
df: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>]
scala> val report = df.select ("*").where (array_contains (df("tags"), "private"))
report: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>]
scala> report.show
+------+--------------------+
| name| tags|
+------+--------------------+
| ned|[blue, big, private]|
|albert| [private, lumpy]|
| zed|[big, private, sq...|
| ed| [red, private]|
+------+--------------------+
Note that it works if you write where(array_contains(df("tags"), "private")), but if you write where(df("tags").array_contains("private")) (more directly analogous to what you wrote originally), it fails with "array_contains is not a member of org.apache.spark.sql.Column". That is because array_contains is a function in org.apache.spark.sql.functions, not a method on Column. Looking at the source code for Column, I see there's some code to handle contains (constructing a Contains instance for it) but nothing for array_contains. Maybe that's an oversight.
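The type error in the question comes from two different notions of "contains": Column.contains builds a substring test on a string column, while array_contains tests whether an array holds an element equal to the given value. A plain-Scala illustration of that difference, with no Spark involved (String.contains and Seq.contains stand in for the two Spark operations; the object and value names are hypothetical):

```scala
object ContainsSemantics {
  val singleTag: String = "privateer"
  val tagArray: Seq[String] = Seq("privateer", "blue")

  // Substring test, analogous to Column.contains on a string column.
  val substringMatch: Boolean = singleTag.contains("private")

  // Element-equality test, analogous to array_contains on an array column.
  val elementMatch: Boolean = tagArray.contains("private")

  def main(args: Array[String]): Unit = {
    println(s"substring match: $substringMatch") // substring match: true
    println(s"element match: $elementMatch")     // element match: false
  }
}
```

So a tag such as "privateer" passes a substring test but not array_contains, which is usually what a tag filter wants.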
Answer by Aravind Yarram
You can use an ordinal to refer to an element of the JSON array, e.g. in your case df("tags")(0). Here is a working sample:
scala> val stringRDD = sc.parallelize(Seq("""
| { "name": "ed",
| "tags": ["private"]
| }""",
| """{ "name": "fred",
| "tags": ["public"]
| }""")
| )
stringRDD: org.apache.spark.rdd.RDD[String] = ParallelCollectionRDD[87] at parallelize at <console>:22
scala> import sqlContext.implicits._
import sqlContext.implicits._
scala> sqlContext.read.json(stringRDD)
res28: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>]
scala> val df=sqlContext.read.json(stringRDD)
df: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>]
scala> df.columns
res29: Array[String] = Array(name, tags)
scala> df.dtypes
res30: Array[(String, String)] = Array((name,StringType), (tags,ArrayType(StringType,true)))
scala> val report = df.select("*").where(df("tags")(0).contains("private"))
report: org.apache.spark.sql.DataFrame = [name: string, tags: array<string>]
scala> report.show
+----+-------------+
|name| tags|
+----+-------------+
| ed|List(private)|
+----+-------------+
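Note that df("tags")(0) inspects only the first element of the array, while the question says 'private' can appear at any position. A plain-Scala model of the two filters over the six rows from the previous answer, not Spark code (the Rec case class and the helper names are hypothetical; headOption treats an empty array as no match):

```scala
object FirstElementVsAnyElement {
  final case class Rec(name: String, tags: Seq[String])

  val data: Seq[Rec] = Seq(
    Rec("ned", Seq("blue", "big", "private")),
    Rec("albert", Seq("private", "lumpy")),
    Rec("zed", Seq("big", "private", "square")),
    Rec("jed", Seq("green", "small", "round")),
    Rec("ed", Seq("red", "private")),
    Rec("fred", Seq("public", "blue"))
  )

  // Models df("tags")(0).contains("private"): first element only, substring match.
  def firstTagMatches(r: Rec): Boolean =
    r.tags.headOption.exists(_.contains("private"))

  // Models array_contains(df("tags"), "private"): any element, exact match.
  def anyTagMatches(r: Rec): Boolean =
    r.tags.contains("private")

  val firstOnly: Seq[String] = data.filter(firstTagMatches).map(_.name)
  val anyPosition: Seq[String] = data.filter(anyTagMatches).map(_.name)

  def main(args: Array[String]): Unit = {
    println(firstOnly)   // List(albert)
    println(anyPosition) // List(ned, albert, zed, ed)
  }
}
```

The any-position result matches the rows returned by array_contains in the previous answer, whereas the ordinal approach keeps only albert.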

