Note: this page is a translation of a Stack Overflow Q&A, provided under the CC BY-SA 4.0 license. If you reuse or share it, you must attribute it to the original authors (not the translator): http://stackoverflow.com/questions/26828815/

How to get element by Index in Spark RDD (Java)

Tags: java, apache-spark, rdd

Asked by progNewbie

I know the method rdd.first() which gives me the first element in an RDD.

There is also the method rdd.take(num), which gives me the first "num" elements.

But isn't there a way to get an element by its index?

Thanks.

Accepted answer by maasg

This should be possible by first indexing the RDD. The transformation zipWithIndex provides stable indexing, numbering each element in its original order.

Given: rdd = (a,b,c)

val withIndex = rdd.zipWithIndex // ((a,0),(b,1),(c,2))

To look up an element by index, this form is not useful yet. First we need to use the index as the key:

val indexKey = withIndex.map{case (k,v) => (v,k)}  //((0,a),(1,b),(2,c))

Now, it's possible to use the lookup action in PairRDD to find an element by key:

val b = indexKey.lookup(1) // Array(b)

If you're expecting to use lookup often on the same RDD, I'd recommend caching the indexKey RDD to improve performance.

How to do this using the Java API is an exercise left for the reader.

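For reference, here is a minimal Java sketch of the same approach (untested; it assumes Java 8 for the lambdas, a JavaSparkContext named sc, and the usual Spark Java API imports):

import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;

import scala.Tuple2;

JavaRDD<String> rdd = sc.parallelize(Arrays.asList("a", "b", "c"));

// Pair every element with its stable index: (a,0), (b,1), (c,2)
JavaPairRDD<String, Long> withIndex = rdd.zipWithIndex();

// Swap each pair so the index becomes the key: (0,a), (1,b), (2,c)
JavaPairRDD<Long, String> indexKey = withIndex.mapToPair(t -> new Tuple2<>(t._2(), t._1()));

// Cache if you plan to call lookup repeatedly on the same RDD
indexKey.cache();

// lookup is an action on pair RDDs; it returns all values stored under the given key
List<String> result = indexKey.lookup(1L); // [b]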

Answer by yonran

I tried this class to fetch an item by index. First, when you construct new IndexedFetcher(rdd, itemClass), it counts the number of elements in each partition of the RDD. Then, when you call indexedFetcher.get(n), it runs a job on only the partition that contains that index.

Note that I needed to compile this using Java 1.7 instead of 1.8; as of Spark 1.1.0, the bundled org.objectweb.asm within com.esotericsoftware.reflectasm cannot read Java 1.8 classes yet (throws IllegalStateException when you try to runJob a Java 1.8 function).

import java.io.Serializable;

import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.rdd.RDD;

import scala.reflect.ClassTag;

public static class IndexedFetcher<E> implements Serializable {
    private static final long serialVersionUID = 1L;
    public final RDD<E> rdd;
    public Integer[] elementsPerPartitions;
    private Class<?> clazz;
    public IndexedFetcher(RDD<E> rdd, Class<?> clazz){
        this.rdd = rdd;
        this.clazz = clazz;
        SparkContext context = this.rdd.context();
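        // Run a small job that counts the elements in every partition; runJob returns one count per partition.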
        ClassTag<Integer> intClassTag = scala.reflect.ClassTag$.MODULE$.<Integer>apply(Integer.class);
        elementsPerPartitions = (Integer[]) context.<E, Integer>runJob(rdd, IndexedFetcher.<E>countFunction(), intClassTag);
    }
    public static class IteratorCountFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, Integer> implements Serializable {
        private static final long serialVersionUID = 1L;
        @Override public Integer apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
            int count = 0;
            while (iterator.hasNext()) {
                count++;
                iterator.next();
            }
            return count;
        }
    }
    static <E> scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> countFunction() {
        scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> function = new IteratorCountFunction<E>();
        return function;
    }
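    // Walk the per-partition counts to find the partition that holds the requested global index,
    // then run a job on just that partition to fetch the element.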
    public E get(long index) {
        long remaining = index;
        long totalCount = 0;
        for (int partition = 0; partition < elementsPerPartitions.length; partition++) {
            if (remaining < elementsPerPartitions[partition]) {
                return getWithinPartition(partition, remaining);
            }
            remaining -= elementsPerPartitions[partition];
            totalCount += elementsPerPartitions[partition];
        }
        throw new IllegalArgumentException(String.format("Get %d within RDD that has only %d elements", index, totalCount));
    }
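    // Applied to a single partition: advances the iterator until the local (within-partition) index is reached.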
    public static class FetchWithinPartitionFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, E> implements Serializable {
        private static final long serialVersionUID = 1L;
        private final long indexWithinPartition;
        public FetchWithinPartitionFunction(long indexWithinPartition) {
            this.indexWithinPartition = indexWithinPartition;
        }
        @Override public E apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
            int count = 0;
            while (iterator.hasNext()) {
                E element = iterator.next();
                if (count == indexWithinPartition)
                    return element;
                count++;
            }
            throw new IllegalArgumentException(String.format("Fetch %d within partition that has only %d elements", indexWithinPartition, count));
        }
    }
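    // Runs a job restricted to the one partition that contains the requested index.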
    public E getWithinPartition(int partition, long indexWithinPartition) {
        System.out.format("getWithinPartition(%d, %d)%n", partition, indexWithinPartition);
        SparkContext context = rdd.context();
        scala.Function2<TaskContext, scala.collection.Iterator<E>, E> function = new FetchWithinPartitionFunction<E>(indexWithinPartition);
        scala.collection.Seq<Object> partitions = new scala.collection.mutable.WrappedArray.ofInt(new int[] {partition});
        ClassTag<E> classTag = scala.reflect.ClassTag$.MODULE$.<E>apply(this.clazz);
        E[] result = (E[]) context.<E, E>runJob(rdd, function, partitions, true, classTag);
        return result[0];
    }
}
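
A hypothetical usage sketch (assuming an existing JavaRDD<String> named javaRdd and that IndexedFetcher is compiled as shown above):

RDD<String> rdd = javaRdd.rdd(); // unwrap the underlying Scala RDD
IndexedFetcher<String> fetcher = new IndexedFetcher<>(rdd, String.class); // counts elements per partition once
String element = fetcher.get(4); // fetches the element at global index 4, running a job on only one partition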

Answer by Luke W

I got stuck on this for a while as well, so to expand on maasg's answer, here is how to look up a range of values by index from Java (you'll need to define the four variables at the top):

DataFrame df;
SQLContext sqlContext;
Long start;
Long end;

JavaPairRDD<Row, Long> indexedRDD = df.toJavaRDD().zipWithIndex();
// Keep only the rows whose index falls in [start, end), then drop the index again
JavaRDD<Row> filteredRDD = indexedRDD
        .filter((Tuple2<Row, Long> v1) -> v1._2() >= start && v1._2() < end)
        .map(v1 -> v1._1());
DataFrame filteredDataFrame = sqlContext.createDataFrame(filteredRDD, df.schema());
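
For example (hypothetical values), with start = 100L and end = 200L the resulting filteredDataFrame contains rows 100 (inclusive) through 199, and you can inspect it with:

filteredDataFrame.show();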

Remember that when you run this code your cluster will need to have Java 8 (as a lambda expression is in use).

Also, zipWithIndex is probably expensive! It has to run a Spark job to count the elements of each partition before it can assign indices (unless the RDD has only a single partition).
