Repository: spark
Updated Branches:
  refs/heads/master 38a416f03 -> ac0b2b788


[SPARK-5588] [SQL] support select/filter by SQL expression

```
df.selectExpr('a + 1', 'abs(age)')
df.filter('age > 3')
df[ df.age > 3 ]
df[ ['age', 'name'] ]
```
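
For context, a minimal end-to-end sketch of the new API surface (not part of the commit; the session setup is illustrative and assumes the 1.3-era single-module `pyspark.sql`, where `SQLContext.inferSchema` returns a DataFrame):

```
from pyspark import SparkContext
from pyspark.sql import SQLContext, Row

sc = SparkContext("local", "spark-5588-demo")
sqlCtx = SQLContext(sc)
# Two-row DataFrame matching the doctests in the diff below.
df = sqlCtx.inferSchema(
    sc.parallelize([Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]))

df.selectExpr('age * 2', 'abs(age)').collect()  # select by SQL expression strings
df.filter('age > 3').collect()                  # filter by a SQL expression string
df[df.age > 3].collect()                        # filter by a Column predicate
df[['age', 'name']].collect()                   # projection by a list of column names
```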

Author: Davies Liu <dav...@databricks.com>

Closes #4359 from davies/select_expr and squashes the following commits:

d99856b [Davies Liu] support select/filter by SQL expression


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ac0b2b78
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ac0b2b78
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ac0b2b78

Branch: refs/heads/master
Commit: ac0b2b788ff144970d6fdbdc445367772770458d
Parents: 38a416f
Author: Davies Liu <dav...@databricks.com>
Authored: Wed Feb 4 11:34:46 2015 -0800
Committer: Reynold Xin <r...@databricks.com>
Committed: Wed Feb 4 11:34:46 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/api/python/PythonUtils.scala   | 11 +++-
 python/pyspark/sql.py                           | 53 ++++++++++++++++----
 .../main/scala/org/apache/spark/sql/Dsl.scala   | 11 ----
 3 files changed, 53 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ac0b2b78/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index b7cfc8b..acbaba6 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.api.python
 
-import java.io.{File, InputStream, IOException, OutputStream}
+import java.io.{File}
+import java.util.{List => JList}
 
+import scala.collection.JavaConversions._
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.SparkContext
@@ -44,4 +46,11 @@ private[spark] object PythonUtils {
   def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = {
     sc.parallelize(List("a", null, "b"))
   }
+
+  /**
+   * Convert a list of T into a Seq of T (for calling an API with varargs)
+   */
+  def toSeq[T](cols: JList[T]): Seq[T] = {
+    cols.toList.toSeq
+  }
 }
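
This helper exists because Py4J hands the JVM a `java.util.List`, while the Scala `DataFrame` methods being called take varargs (e.g. `select(cols: Column*)`), which need a `Seq`. A sketch of the resulting call pattern on the Python side (illustrative only; `sc`, `cols`, and `jdf` are assumed to be in scope):

```
from py4j.java_collections import ListConverter

# Python list of JVM Column handles -> java.util.List ...
jcols = ListConverter().convert([c._jc for c in cols],
                                sc._gateway._gateway_client)
# ... -> scala.Seq, which the JVM side splats into the varargs parameter.
jdf = jdf.select(sc._jvm.PythonUtils.toSeq(jcols))
```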

http://git-wip-us.apache.org/repos/asf/spark/blob/ac0b2b78/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 74305de..a266cde 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -2128,7 +2128,7 @@ class DataFrame(object):
             raise ValueError("should sort by at least one column")
         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         self._sc._gateway._gateway_client)
-        jdf = self._jdf.sort(self._sc._jvm.Dsl.toColumns(jcols))
+        jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     sortBy = sort
@@ -2159,13 +2159,20 @@ class DataFrame(object):
 
         >>> df['age'].collect()
         [Row(age=2), Row(age=5)]
+        >>> df[ ["name", "age"]].collect()
+        [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
+        >>> df[ df.age > 3 ].collect()
+        [Row(age=5, name=u'Bob')]
         """
         if isinstance(item, basestring):
             jc = self._jdf.apply(item)
             return Column(jc, self.sql_ctx)
-
-        # TODO projection
-        raise IndexError
+        elif isinstance(item, Column):
+            return self.filter(item)
+        elif isinstance(item, list):
+            return self.select(*item)
+        else:
+            raise IndexError("unexpected index: %s" % item)
 
     def __getattr__(self, name):
         """ Return the column by given name
@@ -2194,18 +2201,44 @@ class DataFrame(object):
             cols = ["*"]
         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         self._sc._gateway._gateway_client)
-        jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
+        jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+        return DataFrame(jdf, self.sql_ctx)
+
+    def selectExpr(self, *expr):
+        """
+        Selects a set of SQL expressions. This is a variant of
+        `select` that accepts SQL expressions.
+
+        >>> df.selectExpr("age * 2", "abs(age)").collect()
+        [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)]
+        """
+        jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
+        jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
         return DataFrame(jdf, self.sql_ctx)
 
     def filter(self, condition):
-        """ Filtering rows using the given condition.
+        """ Filtering rows using the given condition, which could be
+        Column expression or string of SQL expression.
+
+        where() is an alias for filter().
 
         >>> df.filter(df.age > 3).collect()
         [Row(age=5, name=u'Bob')]
         >>> df.where(df.age == 2).collect()
         [Row(age=2, name=u'Alice')]
+
+        >>> df.filter("age > 3").collect()
+        [Row(age=5, name=u'Bob')]
+        >>> df.where("age = 2").collect()
+        [Row(age=2, name=u'Alice')]
         """
-        return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx)
+        if isinstance(condition, basestring):
+            jdf = self._jdf.filter(condition)
+        elif isinstance(condition, Column):
+            jdf = self._jdf.filter(condition._jc)
+        else:
+            raise TypeError("condition should be string or Column")
+        return DataFrame(jdf, self.sql_ctx)
 
     where = filter
 
@@ -2223,7 +2256,7 @@ class DataFrame(object):
         """
         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         self._sc._gateway._gateway_client)
-        jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
+        jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
         return GroupedDataFrame(jdf, self.sql_ctx)
 
     def agg(self, *exprs):
@@ -2338,7 +2371,7 @@ class GroupedDataFrame(object):
             assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
             jcols = ListConverter().convert([c._jc for c in exprs[1:]],
                                             self.sql_ctx._sc._gateway._gateway_client)
-            jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
+            jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     @dfapi
@@ -2633,7 +2666,7 @@ class Dsl(object):
         jcols = ListConverter().convert([_to_java_column(c) for c in cols],
                                         sc._gateway._gateway_client)
         jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
-                                       sc._jvm.Dsl.toColumns(jcols))
+                                       sc._jvm.PythonUtils.toSeq(jcols))
         return Column(jc)
 
     @staticmethod

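Taken together, the sql.py changes above make the following pairs behave equivalently (a summary sketch, not part of the diff; `df` is any DataFrame):

```
df[df.age > 3]                        # same as df.filter(df.age > 3)
df[['age', 'name']]                   # same as df.select('age', 'name')
df.filter('age > 3')                  # string parsed as a SQL predicate
df.where('age = 2')                   # where() remains an alias for filter()
df.selectExpr('age * 2', 'abs(age)')  # select() variant taking SQL expression strings
```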
http://git-wip-us.apache.org/repos/asf/spark/blob/ac0b2b78/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
index 8cf59f0..50f442d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
@@ -17,11 +17,8 @@
 
 package org.apache.spark.sql
 
-import java.util.{List => JList}
-
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.{TypeTag, typeTag}
-import scala.collection.JavaConversions._
 
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions._
@@ -169,14 +166,6 @@ object Dsl {
  /** Computes the absolute value. */
   def abs(e: Column): Column = Abs(e.expr)
 
-  /**
-   * This is a private API for Python
-   * TODO: move this to a private package
-   */
-  def toColumns(cols: JList[Column]): Seq[Column] = {
-    cols.toList.toSeq
-  }
-
  //////////////////////////////////////////////////////////////////////////////////////////////
  //////////////////////////////////////////////////////////////////////////////////////////////
 

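Net effect of the Dsl.scala change: the list-to-Seq shim is no longer a Python-only private API inside sql's `Dsl`, but a general helper in core's `PythonUtils` that any Python wrapper can reach through the gateway:

```
# Before this commit (helper now removed above):
#   jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
# After:
#   jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
```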
