Repository: spark
Updated Branches:
  refs/heads/branch-1.6 02748c953 -> 40a5db561


[SPARK-11410] [PYSPARK] Add python bindings for repartition and sortWithinPartitions.

Author: Nong Li <n...@databricks.com>

Closes #9504 from nongli/spark-11410.

(cherry picked from commit 1ab72b08601a1c8a674bdd3fab84d9804899b2c7)
Signed-off-by: Davies Liu <davies....@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/40a5db56
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/40a5db56
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/40a5db56

Branch: refs/heads/branch-1.6
Commit: 40a5db56169b87eecf574cf8e15e89caf4836ee4
Parents: 02748c9
Author: Nong Li <n...@databricks.com>
Authored: Fri Nov 6 15:48:20 2015 -0800
Committer: Davies Liu <davies....@gmail.com>
Committed: Fri Nov 6 15:48:45 2015 -0800

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py | 117 ++++++++++++++++++++++++++++++-----
 1 file changed, 101 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/40a5db56/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 765a451..b97c94d 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -423,6 +423,67 @@ class DataFrame(object):
         return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx)
 
     @since(1.3)
+    def repartition(self, numPartitions, *cols):
+        """
+        Returns a new :class:`DataFrame` partitioned by the given partitioning expressions. The
+        resulting DataFrame is hash partitioned.
+
+        ``numPartitions`` can be an int to specify the target number of partitions or a Column.
+        If it is a Column, it will be used as the first partitioning column. If not specified,
+        the default number of partitions is used.
+
+        .. versionchanged:: 1.6
+           Added optional arguments to specify the partitioning columns. Also made numPartitions
+           optional if partitioning columns are specified.
+
+        >>> df.repartition(10).rdd.getNumPartitions()
+        10
+        >>> data = df.unionAll(df).repartition("age")
+        >>> data.show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  2|Alice|
+        |  2|Alice|
+        |  5|  Bob|
+        |  5|  Bob|
+        +---+-----+
+        >>> data = data.repartition(7, "age")
+        >>> data.show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  5|  Bob|
+        |  5|  Bob|
+        |  2|Alice|
+        |  2|Alice|
+        +---+-----+
+        >>> data.rdd.getNumPartitions()
+        7
+        >>> data = data.repartition("name", "age")
+        >>> data.show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  5|  Bob|
+        |  5|  Bob|
+        |  2|Alice|
+        |  2|Alice|
+        +---+-----+
+        """
+        if isinstance(numPartitions, int):
+            if len(cols) == 0:
+                return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx)
+            else:
+                return DataFrame(
+                    self._jdf.repartition(numPartitions, self._jcols(*cols)), self.sql_ctx)
+        elif isinstance(numPartitions, (basestring, Column)):
+            cols = (numPartitions, ) + cols
+            return DataFrame(self._jdf.repartition(self._jcols(*cols)), self.sql_ctx)
+        else:
+            raise TypeError("numPartitions should be an int or Column")
+
+    @since(1.3)
     def distinct(self):
         """Returns a new :class:`DataFrame` containing the distinct rows in 
this :class:`DataFrame`.
 
@@ -589,6 +650,26 @@ class DataFrame(object):
                 jdf = self._jdf.join(other._jdf, on._jc, how)
         return DataFrame(jdf, self.sql_ctx)
 
+    @since(1.6)
+    def sortWithinPartitions(self, *cols, **kwargs):
+        """Returns a new :class:`DataFrame` with each partition sorted by the specified column(s).
+
+        :param cols: list of :class:`Column` or column names to sort by.
+        :param ascending: boolean or list of boolean (default True).
+            Sort ascending vs. descending. Specify list for multiple sort orders.
+            If a list is specified, length of the list must equal length of the `cols`.
+
+        >>> df.sortWithinPartitions("age", ascending=False).show()
+        +---+-----+
+        |age| name|
+        +---+-----+
+        |  2|Alice|
+        |  5|  Bob|
+        +---+-----+
+        """
+        jdf = self._jdf.sortWithinPartitions(self._sort_cols(cols, kwargs))
+        return DataFrame(jdf, self.sql_ctx)
+
     @ignore_unicode_prefix
     @since(1.3)
     def sort(self, *cols, **kwargs):
@@ -613,22 +694,7 @@ class DataFrame(object):
         >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
         [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
         """
-        if not cols:
-            raise ValueError("should sort by at least one column")
-        if len(cols) == 1 and isinstance(cols[0], list):
-            cols = cols[0]
-        jcols = [_to_java_column(c) for c in cols]
-        ascending = kwargs.get('ascending', True)
-        if isinstance(ascending, (bool, int)):
-            if not ascending:
-                jcols = [jc.desc() for jc in jcols]
-        elif isinstance(ascending, list):
-            jcols = [jc if asc else jc.desc()
-                     for asc, jc in zip(ascending, jcols)]
-        else:
-            raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending))
-
-        jdf = self._jdf.sort(self._jseq(jcols))
+        jdf = self._jdf.sort(self._sort_cols(cols, kwargs))
         return DataFrame(jdf, self.sql_ctx)
 
     orderBy = sort
@@ -650,6 +716,25 @@ class DataFrame(object):
             cols = cols[0]
         return self._jseq(cols, _to_java_column)
 
+    def _sort_cols(self, cols, kwargs):
+        """ Return a JVM Seq of Columns that describes the sort order
+        """
+        if not cols:
+            raise ValueError("should sort by at least one column")
+        if len(cols) == 1 and isinstance(cols[0], list):
+            cols = cols[0]
+        jcols = [_to_java_column(c) for c in cols]
+        ascending = kwargs.get('ascending', True)
+        if isinstance(ascending, (bool, int)):
+            if not ascending:
+                jcols = [jc.desc() for jc in jcols]
+        elif isinstance(ascending, list):
+            jcols = [jc if asc else jc.desc()
+                     for asc, jc in zip(ascending, jcols)]
+        else:
+            raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending))
+        return self._jseq(jcols)
+
     @since("1.3.1")
     def describe(self, *cols):
         """Computes statistics for numeric columns.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to