Repository: spark
Updated Branches:
  refs/heads/master 48f42781d -> 4fa2fda88


[SPARK-2871] [PySpark] add RDD.lookup(key)

RDD.lookup(key)

        Return the list of values in the RDD for key `key`. This operation
        is done efficiently if the RDD has a known partitioner by only
        searching the partition that the key maps to.

        >>> l = range(1000)
        >>> rdd = sc.parallelize(zip(l, l), 10)
        >>> rdd.lookup(42)  # slow
        [42]
        >>> sorted = rdd.sortByKey()
        >>> sorted.lookup(42)  # fast
        [42]

It also cleans up the code in rdd.py and fixes several bugs related to preservesPartitioning.
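
A hypothetical usage sketch of the fast path, assuming an active SparkContext `sc`
(Python 2, matching this codebase): the preservesPartitioning fixes let the
partitioner survive filter(), so lookup() can still restrict the job to a single
partition.

    rdd = sc.parallelize(zip(range(1000), range(1000)), 10).sortByKey()

    # filter() now calls mapPartitions(func, preservesPartitioning=True),
    # so the sorted RDD's partitioner is kept on the filtered RDD and
    # lookup() only runs a job on the partition that holds the key.
    evens = rdd.filter(lambda (k, v): v % 2 == 0)
    evens.lookup(42)    # [42], without scanning every partition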

Author: Davies Liu <davies....@gmail.com>

Closes #2093 from davies/lookup and squashes the following commits:

1789cd4 [Davies Liu] `f` in foreach could be generator or not.
2871b80 [Davies Liu] Merge branch 'master' into lookup
c6390ea [Davies Liu] address all comments
0f1bce8 [Davies Liu] add test case for lookup()
be0e8ba [Davies Liu] fix preservesPartitioning
eb1305d [Davies Liu] add RDD.lookup(key)
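
To illustrate the foreach()/foreachPartition() tweak in 1789cd4 above (an
illustrative sketch, assuming an active SparkContext `sc`): the partition
function may now either return nothing or be written as a generator.

    def plain(iterator):
        # Returns None. Before this patch the forced evaluation expected
        # an iterable result, so a trailing "yield None" was required.
        for x in iterator:
            print x

    def gen(iterator):
        for x in iterator:
            print x
        yield None  # generator-style functions keep working as before

    rdd = sc.parallelize([1, 2, 3, 4, 5], 2)
    rdd.foreachPartition(plain)
    rdd.foreachPartition(gen)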


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4fa2fda8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4fa2fda8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4fa2fda8

Branch: refs/heads/master
Commit: 4fa2fda88fc7beebb579ba808e400113b512533b
Parents: 48f4278
Author: Davies Liu <davies....@gmail.com>
Authored: Wed Aug 27 13:18:33 2014 -0700
Committer: Josh Rosen <joshro...@apache.org>
Committed: Wed Aug 27 13:18:33 2014 -0700

----------------------------------------------------------------------
 python/pyspark/rdd.py | 211 +++++++++++++++++----------------------------
 1 file changed, 79 insertions(+), 132 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4fa2fda8/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 3191974..2d80fad 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -147,76 +147,6 @@ class BoundedFloat(float):
         return obj
 
 
-class MaxHeapQ(object):
-
-    """
-    An implementation of MaxHeap.
-
-    >>> import pyspark.rdd
-    >>> heap = pyspark.rdd.MaxHeapQ(5)
-    >>> [heap.insert(i) for i in range(10)]
-    [None, None, None, None, None, None, None, None, None, None]
-    >>> sorted(heap.getElements())
-    [0, 1, 2, 3, 4]
-    >>> heap = pyspark.rdd.MaxHeapQ(5)
-    >>> [heap.insert(i) for i in range(9, -1, -1)]
-    [None, None, None, None, None, None, None, None, None, None]
-    >>> sorted(heap.getElements())
-    [0, 1, 2, 3, 4]
-    >>> heap = pyspark.rdd.MaxHeapQ(1)
-    >>> [heap.insert(i) for i in range(9, -1, -1)]
-    [None, None, None, None, None, None, None, None, None, None]
-    >>> heap.getElements()
-    [0]
-    """
-
-    def __init__(self, maxsize):
-        # We start from q[1], so its children are always  2 * k
-        self.q = [0]
-        self.maxsize = maxsize
-
-    def _swim(self, k):
-        while (k > 1) and (self.q[k / 2] < self.q[k]):
-            self._swap(k, k / 2)
-            k = k / 2
-
-    def _swap(self, i, j):
-        t = self.q[i]
-        self.q[i] = self.q[j]
-        self.q[j] = t
-
-    def _sink(self, k):
-        N = self.size()
-        while 2 * k <= N:
-            j = 2 * k
-            # Here we test if both children are greater than parent
-            # if not swap with larger one.
-            if j < N and self.q[j] < self.q[j + 1]:
-                j = j + 1
-            if(self.q[k] > self.q[j]):
-                break
-            self._swap(k, j)
-            k = j
-
-    def size(self):
-        return len(self.q) - 1
-
-    def insert(self, value):
-        if (self.size()) < self.maxsize:
-            self.q.append(value)
-            self._swim(self.size())
-        else:
-            self._replaceRoot(value)
-
-    def getElements(self):
-        return self.q[1:]
-
-    def _replaceRoot(self, value):
-        if(self.q[1] > value):
-            self.q[1] = value
-            self._sink(1)
-
-
 def _parse_memory(s):
     """
     Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
@@ -248,6 +178,7 @@ class RDD(object):
         self.ctx = ctx
         self._jrdd_deserializer = jrdd_deserializer
         self._id = jrdd.id()
+        self._partitionFunc = None
 
     def _toPickleSerialization(self):
         if (self._jrdd_deserializer == PickleSerializer() or
@@ -325,8 +256,6 @@ class RDD(object):
         checkpointFile = self._jrdd.rdd().getCheckpointFile()
         if checkpointFile.isDefined():
             return checkpointFile.get()
-        else:
-            return None
 
     def map(self, f, preservesPartitioning=False):
         """
@@ -366,7 +295,7 @@ class RDD(object):
         """
         def func(s, iterator):
             return f(iterator)
-        return self.mapPartitionsWithIndex(func)
+        return self.mapPartitionsWithIndex(func, preservesPartitioning)
 
     def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
         """
@@ -416,7 +345,7 @@ class RDD(object):
         """
         def func(iterator):
             return ifilter(f, iterator)
-        return self.mapPartitions(func)
+        return self.mapPartitions(func, True)
 
     def distinct(self):
         """
@@ -561,7 +490,7 @@ class RDD(object):
         """
         return self.map(lambda v: (v, None)) \
             .cogroup(other.map(lambda v: (v, None))) \
-            .filter(lambda x: (len(x[1][0]) != 0) and (len(x[1][1]) != 0)) \
+            .filter(lambda (k, vs): all(vs)) \
             .keys()
 
     def _reserialize(self, serializer=None):
@@ -616,7 +545,7 @@ class RDD(object):
         if numPartitions == 1:
             if self.getNumPartitions() > 1:
                 self = self.coalesce(1)
-            return self.mapPartitions(sortPartition)
+            return self.mapPartitions(sortPartition, True)
 
         # first compute the boundary of each part via sampling: we want to partition
         # the key-space into bins such that the bins have roughly the same
@@ -721,8 +650,8 @@ class RDD(object):
         def processPartition(iterator):
             for x in iterator:
                 f(x)
-            yield None
-        self.mapPartitions(processPartition).collect()  # Force evaluation
+            return iter([])
+        self.mapPartitions(processPartition).count()  # Force evaluation
 
     def foreachPartition(self, f):
         """
@@ -731,10 +660,15 @@ class RDD(object):
         >>> def f(iterator):
         ...      for x in iterator:
         ...           print x
-        ...      yield None
         >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f)
         """
-        self.mapPartitions(f).collect()  # Force evaluation
+        def func(it):
+            r = f(it)
+            try:
+                return iter(r)
+            except TypeError:
+                return iter([])
+        self.mapPartitions(func).count()  # Force evaluation
 
     def collect(self):
         """
@@ -767,18 +701,23 @@ class RDD(object):
         15
         >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
         10
+        >>> sc.parallelize([]).reduce(add)
+        Traceback (most recent call last):
+            ...
+        ValueError: Can not reduce() empty RDD
         """
         def func(iterator):
-            acc = None
-            for obj in iterator:
-                if acc is None:
-                    acc = obj
-                else:
-                    acc = f(obj, acc)
-            if acc is not None:
-                yield acc
+            iterator = iter(iterator)
+            try:
+                initial = next(iterator)
+            except StopIteration:
+                return
+            yield reduce(f, iterator, initial)
+
         vals = self.mapPartitions(func).collect()
-        return reduce(f, vals)
+        if vals:
+            return reduce(f, vals)
+        raise ValueError("Can not reduce() empty RDD")
 
     def fold(self, zeroValue, op):
         """
@@ -1081,7 +1020,7 @@ class RDD(object):
             yield counts
 
         def mergeMaps(m1, m2):
-            for (k, v) in m2.iteritems():
+            for k, v in m2.iteritems():
                 m1[k] += v
             return m1
         return self.mapPartitions(countPartition).reduce(mergeMaps)
@@ -1117,24 +1056,10 @@ class RDD(object):
         [10, 9, 7, 6, 5, 4]
         """
 
-        def topNKeyedElems(iterator, key_=None):
-            q = MaxHeapQ(num)
-            for k in iterator:
-                if key_ is not None:
-                    k = (key_(k), k)
-                q.insert(k)
-            yield q.getElements()
-
-        def unKey(x, key_=None):
-            if key_ is not None:
-                x = [i[1] for i in x]
-            return x
-
         def merge(a, b):
-            return next(topNKeyedElems(a + b))
-        result = self.mapPartitions(
-            lambda i: topNKeyedElems(i, key)).reduce(merge)
-        return sorted(unKey(result, key), key=key)
+            return heapq.nsmallest(num, a + b, key)
+
+        return self.mapPartitions(lambda it: [heapq.nsmallest(num, it, key)]).reduce(merge)
 
     def take(self, num):
         """
@@ -1174,13 +1099,13 @@ class RDD(object):
             left = num - len(items)
 
             def takeUpToNumLeft(iterator):
+                iterator = iter(iterator)
                 taken = 0
                 while taken < left:
                     yield next(iterator)
                     taken += 1
 
-            p = range(
-                partsScanned, min(partsScanned + numPartsToTry, totalParts))
+            p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
             res = self.context.runJob(self, takeUpToNumLeft, p, True)
 
             items += res
@@ -1194,8 +1119,15 @@ class RDD(object):
 
         >>> sc.parallelize([2, 3, 4]).first()
         2
+        >>> sc.parallelize([]).first()
+        Traceback (most recent call last):
+            ...
+        ValueError: RDD is empty
         """
-        return self.take(1)[0]
+        rs = self.take(1)
+        if rs:
+            return rs[0]
+        raise ValueError("RDD is empty")
 
     def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
         """
@@ -1420,13 +1352,13 @@ class RDD(object):
         """
         def reducePartition(iterator):
             m = {}
-            for (k, v) in iterator:
-                m[k] = v if k not in m else func(m[k], v)
+            for k, v in iterator:
+                m[k] = func(m[k], v) if k in m else v
             yield m
 
         def mergeMaps(m1, m2):
-            for (k, v) in m2.iteritems():
-                m1[k] = v if k not in m1 else func(m1[k], v)
+            for k, v in m2.iteritems():
+                m1[k] = func(m1[k], v) if k in m1 else v
             return m1
         return self.mapPartitions(reducePartition).reduce(mergeMaps)
 
@@ -1523,7 +1455,7 @@ class RDD(object):
             buckets = defaultdict(list)
             c, batch = 0, min(10 * numPartitions, 1000)
 
-            for (k, v) in iterator:
+            for k, v in iterator:
                 buckets[partitionFunc(k) % numPartitions].append((k, v))
                 c += 1
 
@@ -1546,7 +1478,7 @@ class RDD(object):
                         batch = max(batch / 1.5, 1)
                     c = 0
 
-            for (split, items) in buckets.iteritems():
+            for split, items in buckets.iteritems():
                 yield pack_long(split)
                 yield outputSerializer.dumps(items)
 
@@ -1616,7 +1548,7 @@ class RDD(object):
             merger.mergeCombiners(iterator)
             return merger.iteritems()
 
-        return shuffled.mapPartitions(_mergeCombiners)
+        return shuffled.mapPartitions(_mergeCombiners, True)
 
     def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
         """
@@ -1680,7 +1612,6 @@ class RDD(object):
         return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
                                  numPartitions).mapValues(lambda x: ResultIterable(x))
 
-    # TODO: add tests
     def flatMapValues(self, f):
         """
         Pass each value in the key-value pair RDD through a flatMap function
@@ -1770,9 +1701,8 @@ class RDD(object):
         [('b', 4), ('b', 5)]
         """
         def filter_func((key, vals)):
-            return len(vals[0]) > 0 and len(vals[1]) == 0
-        map_func = lambda (key, vals): [(key, val) for val in vals[0]]
-        return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)
+            return vals[0] and not vals[1]
+        return self.cogroup(other, numPartitions).filter(filter_func).flatMapValues(lambda x: x[0])
 
     def subtract(self, other, numPartitions=None):
         """
@@ -1785,7 +1715,7 @@ class RDD(object):
         """
         # note: here 'True' is just a placeholder
         rdd = other.map(lambda x: (x, True))
-        return self.map(lambda x: (x, True)).subtractByKey(rdd).map(lambda tpl: tpl[0])
+        return self.map(lambda x: (x, True)).subtractByKey(rdd, numPartitions).keys()
 
     def keyBy(self, f):
         """
@@ -1925,9 +1855,8 @@ class RDD(object):
         Return the name of this RDD.
         """
         name_ = self._jrdd.name()
-        if not name_:
-            return None
-        return name_.encode('utf-8')
+        if name_:
+            return name_.encode('utf-8')
 
     def setName(self, name):
         """
@@ -1945,9 +1874,8 @@ class RDD(object):
         A description of this RDD and its recursive dependencies for debugging.
         """
         debug_string = self._jrdd.toDebugString()
-        if not debug_string:
-            return None
-        return debug_string.encode('utf-8')
+        if debug_string:
+            return debug_string.encode('utf-8')
 
     def getStorageLevel(self):
         """
@@ -1982,10 +1910,28 @@ class RDD(object):
         else:
             return self.getNumPartitions()
 
-    # TODO: `lookup` is disabled because we can't make direct comparisons based
-    # on the key; we need to compare the hash of the key to the hash of the
-    # keys in the pairs.  This could be an expensive operation, since those
-    # hashes aren't retained.
+    def lookup(self, key):
+        """
+        Return the list of values in the RDD for key `key`. This operation
+        is done efficiently if the RDD has a known partitioner by only
+        searching the partition that the key maps to.
+
+        >>> l = range(1000)
+        >>> rdd = sc.parallelize(zip(l, l), 10)
+        >>> rdd.lookup(42)  # slow
+        [42]
+        >>> sorted = rdd.sortByKey()
+        >>> sorted.lookup(42)  # fast
+        [42]
+        >>> sorted.lookup(1024)
+        []
+        """
+        values = self.filter(lambda (k, v): k == key).values()
+
+        if self._partitionFunc is not None:
+            return self.ctx.runJob(values, lambda x: x, [self._partitionFunc(key)], False)
+
+        return values.collect()
 
     def _is_pickled(self):
         """ Return this RDD is serialized by Pickle or not. """
@@ -2096,6 +2042,7 @@ class PipelinedRDD(RDD):
         self._jrdd_val = None
         self._jrdd_deserializer = self.ctx.serializer
         self._bypass_serializer = False
+        self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
 
     @property
     def _jrdd(self):

