Repository: spark
Updated Branches:
  refs/heads/master dc48ba9f9 -> c84d91692


[SPARK-6957] [SPARK-6958] [SQL] improve API compatibility to pandas

```
select(['colA', 'colB'])

groupby(['colA', 'colB'])
groupby([df.colA, df.colB])

df.sort('A', ascending=True)
df.sort(['A', 'B'], ascending=True)
df.sort(['A', 'B'], ascending=[1, 0])
```
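
For context, a minimal end-to-end sketch of the new call forms (assuming a live SQLContext named `sqlCtx`, as in the PySpark doctests; the data and column names are illustrative):

```
df = sqlCtx.createDataFrame([(2, 'Alice'), (5, 'Bob')], ['age', 'name'])

df.select(['age', 'name'])                   # list of names, like pandas
df.groupby(['name', df.age]).count()         # lowercase alias of groupBy
df.sort(['age', 'name'], ascending=[0, 1])   # per-column sort order
```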

cc rxin

Author: Davies Liu <dav...@databricks.com>

Closes #5544 from davies/compatibility and squashes the following commits:

4944058 [Davies Liu] add docstrings
adb2816 [Davies Liu] Merge branch 'master' of github.com:apache/spark into compatibility
bcbbcab [Davies Liu] support ascending as list
8dabdf0 [Davies Liu] improve API compatibility to pandas


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c84d9169
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c84d9169
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c84d9169

Branch: refs/heads/master
Commit: c84d91692aa25c01882bcc3f9fd5de3cfa786195
Parents: dc48ba9
Author: Davies Liu <dav...@databricks.com>
Authored: Fri Apr 17 11:29:27 2015 -0500
Committer: Reynold Xin <r...@databricks.com>
Committed: Fri Apr 17 11:29:27 2015 -0500

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py | 96 +++++++++++++++++++++++++-----------
 python/pyspark/sql/functions.py | 11 ++---
 python/pyspark/sql/tests.py     |  2 +-
 3 files changed, 70 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c84d9169/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index b9a3e6c..326d22e 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -485,13 +485,17 @@ class DataFrame(object):
         return DataFrame(jdf, self.sql_ctx)
 
     @ignore_unicode_prefix
-    def sort(self, *cols):
+    def sort(self, *cols, **kwargs):
         """Returns a new :class:`DataFrame` sorted by the specified column(s).
 
-        :param cols: list of :class:`Column` to sort by.
+        :param cols: list of :class:`Column` or column names to sort by.
+        :param ascending: sort by ascending order or not, could be bool, int
+             or list of bool, int (default: True).
 
         >>> df.sort(df.age.desc()).collect()
         [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+        >>> df.sort("age", ascending=False).collect()
+        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
         >>> df.orderBy(df.age.desc()).collect()
         [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
         >>> from pyspark.sql.functions import *
@@ -499,16 +503,42 @@ class DataFrame(object):
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         >>> df.orderBy(desc("age"), "name").collect()
         [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+        >>> df.orderBy(["age", "name"], ascending=[0, 1]).collect()
+        [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
         """
         if not cols:
             raise ValueError("should sort by at least one column")
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                        self._sc._gateway._gateway_client)
-        jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
+        if len(cols) == 1 and isinstance(cols[0], list):
+            cols = cols[0]
+        jcols = [_to_java_column(c) for c in cols]
+        ascending = kwargs.get('ascending', True)
+        if isinstance(ascending, (bool, int)):
+            if not ascending:
+                jcols = [jc.desc() for jc in jcols]
+        elif isinstance(ascending, list):
+            jcols = [jc if asc else jc.desc()
+                     for asc, jc in zip(ascending, jcols)]
+        else:
+            raise TypeError("ascending can only be bool or list, but got %s" % type(ascending))
+
+        jdf = self._jdf.sort(self._jseq(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     orderBy = sort
 
+    def _jseq(self, cols, converter=None):
+        """Return a JVM Seq of Columns from a list of Column or names"""
+        return _to_seq(self.sql_ctx._sc, cols, converter)
+
+    def _jcols(self, *cols):
+        """Return a JVM Seq of Columns from a list of Column or column names
+
+        If `cols` has only one list in it, cols[0] will be used as the list.
+        """
+        if len(cols) == 1 and isinstance(cols[0], list):
+            cols = cols[0]
+        return self._jseq(cols, _to_java_column)
+
     def describe(self, *cols):
         """Computes statistics for numeric columns.
 
@@ -523,9 +553,7 @@ class DataFrame(object):
         min     2
         max     5
         """
-        cols = ListConverter().convert(cols,
-                                       self.sql_ctx._sc._gateway._gateway_client)
-        jdf = self._jdf.describe(self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols))
+        jdf = self._jdf.describe(self._jseq(cols))
         return DataFrame(jdf, self.sql_ctx)
 
     @ignore_unicode_prefix
@@ -607,9 +635,7 @@ class DataFrame(object):
         >>> df.select(df.name, (df.age + 10).alias('age')).collect()
         [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
         """
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                        self._sc._gateway._gateway_client)
-        jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+        jdf = self._jdf.select(self._jcols(*cols))
         return DataFrame(jdf, self.sql_ctx)
 
     def selectExpr(self, *expr):
@@ -620,8 +646,9 @@ class DataFrame(object):
         >>> df.selectExpr("age * 2", "abs(age)").collect()
         [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
         """
-        jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
-        jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
+        if len(expr) == 1 and isinstance(expr[0], list):
+            expr = expr[0]
+        jdf = self._jdf.selectExpr(self._jseq(expr))
         return DataFrame(jdf, self.sql_ctx)
 
     @ignore_unicode_prefix
@@ -659,6 +686,8 @@ class DataFrame(object):
         so we can run aggregation on them. See :class:`GroupedData`
         for all the available aggregate functions.
 
+        :func:`groupby` is an alias for :func:`groupBy`.
+
         :param cols: list of columns to group by.
            Each element should be a column name (string) or an expression (:class:`Column`).
 
@@ -668,12 +697,14 @@ class DataFrame(object):
         [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
         >>> df.groupBy(df.name).avg().collect()
         [Row(name=u'Alice', AVG(age)=2.0), Row(name=u'Bob', AVG(age)=5.0)]
+        >>> df.groupBy(['name', df.age]).count().collect()
+        [Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
         """
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                        self._sc._gateway._gateway_client)
-        jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+        jdf = self._jdf.groupBy(self._jcols(*cols))
         return GroupedData(jdf, self.sql_ctx)
 
+    groupby = groupBy
+
     def agg(self, *exprs):
         """ Aggregate on the entire :class:`DataFrame` without groups
         (shorthand for ``df.groupBy.agg()``).
@@ -744,9 +775,7 @@ class DataFrame(object):
         if thresh is None:
             thresh = len(subset) if how == 'any' else 1
 
-        cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
-        cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
-        return DataFrame(self._jdf.na().drop(thresh, cols), self.sql_ctx)
+        return DataFrame(self._jdf.na().drop(thresh, self._jseq(subset)), self.sql_ctx)
 
     def fillna(self, value, subset=None):
         """Replace null values, alias for ``na.fill()``.
@@ -799,9 +828,7 @@ class DataFrame(object):
             elif not isinstance(subset, (list, tuple)):
                raise ValueError("subset should be a list or tuple of column names")
 
-            cols = ListConverter().convert(subset, self.sql_ctx._sc._gateway._gateway_client)
-            cols = self.sql_ctx._sc._jvm.PythonUtils.toSeq(cols)
-            return DataFrame(self._jdf.na().fill(value, cols), self.sql_ctx)
+            return DataFrame(self._jdf.na().fill(value, self._jseq(subset)), self.sql_ctx)
 
     @ignore_unicode_prefix
     def withColumn(self, colName, col):
@@ -862,10 +889,8 @@ def dfapi(f):
 
 def df_varargs_api(f):
     def _api(self, *args):
-        jargs = ListConverter().convert(args,
-                                        self.sql_ctx._sc._gateway._gateway_client)
         name = f.__name__
-        jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs))
+        jdf = getattr(self._jdf, name)(_to_seq(self.sql_ctx._sc, args))
         return DataFrame(jdf, self.sql_ctx)
     _api.__name__ = f.__name__
     _api.__doc__ = f.__doc__
@@ -912,9 +937,8 @@ class GroupedData(object):
         else:
             # Columns
            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
-            jcols = ListConverter().convert([c._jc for c in exprs[1:]],
-                                            self.sql_ctx._sc._gateway._gateway_client)
-            jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+            jdf = self._jdf.agg(exprs[0]._jc,
+                                _to_seq(self.sql_ctx._sc, [c._jc for c in exprs[1:]]))
         return DataFrame(jdf, self.sql_ctx)
 
     @dfapi
@@ -1006,6 +1030,19 @@ def _to_java_column(col):
     return jcol
 
 
+def _to_seq(sc, cols, converter=None):
+    """
+    Convert a list of Column (or names) into a JVM Seq of Column.
+
+    An optional `converter` could be used to convert items in `cols`
+    into JVM Column objects.
+    """
+    if converter:
+        cols = [converter(c) for c in cols]
+    jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
+    return sc._jvm.PythonUtils.toSeq(jcols)
+
+
 def _unary_op(name, doc="unary operator"):
     """ Create a method for given unary operator """
     def _(self):
@@ -1177,8 +1214,7 @@ class Column(object):
             cols = cols[0]
        cols = [c._jc if isinstance(c, Column) else _create_column_from_literal(c) for c in cols]
         sc = SparkContext._active_spark_context
-        jcols = ListConverter().convert(cols, sc._gateway._gateway_client)
-        jc = getattr(self._jc, "in")(sc._jvm.PythonUtils.toSeq(jcols))
+        jc = getattr(self._jc, "in")(_to_seq(sc, cols))
         return Column(jc)
 
     # order
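
The ascending-handling added to `sort` above is self-contained enough to illustrate without a JVM. A pure-Python rendition of just that normalization step (the `_order_columns` helper is hypothetical, standing in for the Py4J column calls):

```
def _order_columns(cols, ascending=True):
    # Mirrors sort()'s kwarg handling: a scalar bool/int applies to every
    # column; a list is zipped pairwise against the columns.
    if isinstance(ascending, (bool, int)):
        return [(c, bool(ascending)) for c in cols]
    elif isinstance(ascending, list):
        return [(c, bool(asc)) for c, asc in zip(cols, ascending)]
    else:
        raise TypeError("ascending can only be bool or list, but got %s"
                        % type(ascending))

print(_order_columns(['A', 'B'], ascending=[1, 0]))
# -> [('A', True), ('B', False)]
```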

http://git-wip-us.apache.org/repos/asf/spark/blob/c84d9169/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 1d65369..bb47923 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -23,13 +23,11 @@ import sys
 if sys.version < "3":
     from itertools import imap as map
 
-from py4j.java_collections import ListConverter
-
 from pyspark import SparkContext
 from pyspark.rdd import _prepare_for_python_RDD
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
 from pyspark.sql.types import StringType
-from pyspark.sql.dataframe import Column, _to_java_column
+from pyspark.sql.dataframe import Column, _to_java_column, _to_seq
 
 
 __all__ = ['countDistinct', 'approxCountDistinct', 'udf']
@@ -87,8 +85,7 @@ def countDistinct(col, *cols):
     [Row(c=2)]
     """
     sc = SparkContext._active_spark_context
-    jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client)
-    jc = sc._jvm.functions.countDistinct(_to_java_column(col), sc._jvm.PythonUtils.toSeq(jcols))
+    jc = sc._jvm.functions.countDistinct(_to_java_column(col), _to_seq(sc, cols, _to_java_column))
     return Column(jc)
 
 
@@ -138,9 +135,7 @@ class UserDefinedFunction(object):
 
     def __call__(self, *cols):
         sc = SparkContext._active_spark_context
-        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
-                                        sc._gateway._gateway_client)
-        jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
+        jc = self._judf.apply(_to_seq(sc, cols, _to_java_column))
         return Column(jc)
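
Both refactored call sites keep their public behavior; a hedged usage sketch (assuming a DataFrame `df` with columns `age` and `name`, as in the doctests):

```
from pyspark.sql.functions import countDistinct, udf
from pyspark.sql.types import StringType

# countDistinct's varargs now travel through _to_seq()
df.agg(countDistinct(df.age, df.name).alias('c')).collect()

# so do the columns passed to a UDF call
shout = udf(lambda s: s.upper(), StringType())
df.select(shout(df.name)).collect()
```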
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/c84d9169/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 6691e8c..aa3aa1d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -282,7 +282,7 @@ class SQLTests(ReusedPySparkTestCase):
             StructField("struct1", StructType([StructField("b", ShortType(), 
False)]), False),
             StructField("list1", ArrayType(ByteType(), False), False),
             StructField("null1", DoubleType(), True)])
-        df = self.sqlCtx.applySchema(rdd, schema)
+        df = self.sqlCtx.createDataFrame(rdd, schema)
        results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
                                    x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
         r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
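
The test change swaps the deprecated `applySchema` for `createDataFrame`. A minimal sketch of the equivalent call with an explicit schema (simplified columns, hypothetical data, assuming `sc` and `sqlCtx` are already set up):

```
from pyspark.sql.types import StructType, StructField, ByteType, DoubleType

schema = StructType([
    StructField("byte1", ByteType(), False),
    StructField("null1", DoubleType(), True)])
rdd = sc.parallelize([(127, None)])
df = sqlCtx.createDataFrame(rdd, schema)
```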

