Repository: spark Updated Branches: refs/heads/master 38a416f03 -> ac0b2b788
[SPARK-5588] [SQL] support select/filter by SQL expression ``` df.selectExpr('a + 1', 'abs(age)') df.filter('age > 3') df[ df.age > 3 ] df[ ['age', 'name'] ] ``` Author: Davies Liu <dav...@databricks.com> Closes #4359 from davies/select_expr and squashes the following commits: d99856b [Davies Liu] support select/filter by SQL expression Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ac0b2b78 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ac0b2b78 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ac0b2b78 Branch: refs/heads/master Commit: ac0b2b788ff144970d6fdbdc445367772770458d Parents: 38a416f Author: Davies Liu <dav...@databricks.com> Authored: Wed Feb 4 11:34:46 2015 -0800 Committer: Reynold Xin <r...@databricks.com> Committed: Wed Feb 4 11:34:46 2015 -0800 ---------------------------------------------------------------------- .../apache/spark/api/python/PythonUtils.scala | 11 +++- python/pyspark/sql.py | 53 ++++++++++++++++---- .../main/scala/org/apache/spark/sql/Dsl.scala | 11 ---- 3 files changed, 53 insertions(+), 22 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/ac0b2b78/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index b7cfc8b..acbaba6 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -17,8 +17,10 @@ package org.apache.spark.api.python -import java.io.{File, InputStream, IOException, OutputStream} +import java.io.{File} +import java.util.{List => JList} +import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext @@ -44,4 +46,11 @@ private[spark] object PythonUtils { def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = { sc.parallelize(List("a", null, "b")) } + + /** + * Convert list of T into seq of T (for calling API with varargs) + */ + def toSeq[T](cols: JList[T]): Seq[T] = { + cols.toList.toSeq + } } http://git-wip-us.apache.org/repos/asf/spark/blob/ac0b2b78/python/pyspark/sql.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 74305de..a266cde 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -2128,7 +2128,7 @@ class DataFrame(object): raise ValueError("should sort by at least one column") jcols = ListConverter().convert([_to_java_column(c) for c in cols], self._sc._gateway._gateway_client) - jdf = self._jdf.sort(self._sc._jvm.Dsl.toColumns(jcols)) + jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols)) return DataFrame(jdf, self.sql_ctx) sortBy = sort @@ -2159,13 +2159,20 @@ class DataFrame(object): >>> df['age'].collect() [Row(age=2), Row(age=5)] + >>> df[ ["name", "age"]].collect() + [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)] + >>> df[ df.age > 3 ].collect() + [Row(age=5, name=u'Bob')] """ if isinstance(item, basestring): jc = self._jdf.apply(item) return Column(jc, self.sql_ctx) - - # TODO projection - raise IndexError + elif isinstance(item, Column): + return self.filter(item) + elif isinstance(item, list): + return self.select(*item) + else: + raise IndexError("unexpected index: %s" % item) def __getattr__(self, name): """ Return the column by given name @@ -2194,18 +2201,44 @@ class DataFrame(object): cols = ["*"] jcols = ListConverter().convert([_to_java_column(c) for c in cols], self._sc._gateway._gateway_client) - jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) + jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) + return DataFrame(jdf, self.sql_ctx) + + def selectExpr(self, *expr): + """ + Selects a set of SQL expressions. This is a variant of + `select` that accepts SQL expressions. + + >>> df.selectExpr("age * 2", "abs(age)").collect() + [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)] + """ + jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client) + jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr)) return DataFrame(jdf, self.sql_ctx) def filter(self, condition): - """ Filtering rows using the given condition. + """ Filtering rows using the given condition, which could be + Column expression or string of SQL expression. + + where() is an alias for filter(). >>> df.filter(df.age > 3).collect() [Row(age=5, name=u'Bob')] >>> df.where(df.age == 2).collect() [Row(age=2, name=u'Alice')] + + >>> df.filter("age > 3").collect() + [Row(age=5, name=u'Bob')] + >>> df.where("age = 2").collect() + [Row(age=2, name=u'Alice')] """ - return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx) + if isinstance(condition, basestring): + jdf = self._jdf.filter(condition) + elif isinstance(condition, Column): + jdf = self._jdf.filter(condition._jc) + else: + raise TypeError("condition should be string or Column") + return DataFrame(jdf, self.sql_ctx) where = filter @@ -2223,7 +2256,7 @@ class DataFrame(object): """ jcols = ListConverter().convert([_to_java_column(c) for c in cols], self._sc._gateway._gateway_client) - jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) + jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) return GroupedDataFrame(jdf, self.sql_ctx) def agg(self, *exprs): @@ -2338,7 +2371,7 @@ class GroupedDataFrame(object): assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column" jcols = ListConverter().convert([c._jc for c in exprs[1:]], self.sql_ctx._sc._gateway._gateway_client) - jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols)) + jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) return DataFrame(jdf, self.sql_ctx) @dfapi @@ -2633,7 +2666,7 @@ class Dsl(object): jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client) jc = sc._jvm.Dsl.countDistinct(_to_java_column(col), - sc._jvm.Dsl.toColumns(jcols)) + sc._jvm.PythonUtils.toSeq(jcols)) return Column(jc) @staticmethod http://git-wip-us.apache.org/repos/asf/spark/blob/ac0b2b78/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala index 8cf59f0..50f442d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql -import java.util.{List => JList} - import scala.language.implicitConversions import scala.reflect.runtime.universe.{TypeTag, typeTag} -import scala.collection.JavaConversions._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions._ @@ -169,14 +166,6 @@ object Dsl { /** Computes the absolutle value. */ def abs(e: Column): Column = Abs(e.expr) - /** - * This is a private API for Python - * TODO: move this to a private package - */ - def toColumns(cols: JList[Column]): Seq[Column] = { - cols.toList.toSeq - } - ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org