Repository: spark
Updated Branches:
  refs/heads/master a3afa4a1b -> 5c78be7a5


[SPARK-5799][SQL] Compute aggregation function on specified numeric columns

Compute an aggregation function on the specified numeric columns only, rather than on every numeric column. For example:

    val df = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, 
"d")).toDataFrame("key", "value1", "value2", "rest")
    df.groupBy("key").min("value2")

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #4592 from viirya/specific_cols_agg and squashes the following commits:

9446896 [Liang-Chi Hsieh] For comments.
314c4cd [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into specific_cols_agg
353fad7 [Liang-Chi Hsieh] For python unit tests.
54ed0c4 [Liang-Chi Hsieh] Address comments.
b079e6b [Liang-Chi Hsieh] Remove duplicate codes.
55100fb [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into specific_cols_agg
880c2ac [Liang-Chi Hsieh] Fix Python style checks.
4c63a01 [Liang-Chi Hsieh] Fix pyspark.
b1a24fc [Liang-Chi Hsieh] Address comments.
2592f29 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into specific_cols_agg
27069c3 [Liang-Chi Hsieh] Combine functions and add varargs annotation.
371a3f7 [Liang-Chi Hsieh] Compute aggregation function on specified numeric columns.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5c78be7a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5c78be7a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5c78be7a

Branch: refs/heads/master
Commit: 5c78be7a515fc2fc92cda0517318e7b5d85762f4
Parents: a3afa4a
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Mon Feb 16 10:06:11 2015 -0800
Committer: Reynold Xin <r...@databricks.com>
Committed: Mon Feb 16 10:06:11 2015 -0800

----------------------------------------------------------------------
 python/pyspark/sql/dataframe.py                 | 74 ++++++++++++++++----
 python/pyspark/sql/functions.py                 |  2 +
 .../org/apache/spark/sql/DataFrameImpl.scala    |  4 +-
 .../org/apache/spark/sql/GroupedData.scala      | 57 ++++++++++++---
 .../org/apache/spark/sql/DataFrameSuite.scala   | 12 ++++
 5 files changed, 123 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5c78be7a/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1438fe5..28a59e7 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -664,6 +664,18 @@ def dfapi(f):
     return _api
 
 
+def df_varargs_api(f):
+    def _api(self, *args):
+        jargs = ListConverter().convert(args,
+                                        self.sql_ctx._sc._gateway._gateway_client)
+        name = f.__name__
+        jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs))
+        return DataFrame(jdf, self.sql_ctx)
+    _api.__name__ = f.__name__
+    _api.__doc__ = f.__doc__
+    return _api
+
+
 class GroupedData(object):
 
     """
@@ -714,30 +726,60 @@ class GroupedData(object):
         [Row(age=2, count=1), Row(age=5, count=1)]
         """
 
-    @dfapi
-    def mean(self):
+    @df_varargs_api
+    def mean(self, *cols):
         """Compute the average value for each numeric columns
-        for each group. This is an alias for `avg`."""
+        for each group. This is an alias for `avg`.
 
-    @dfapi
-    def avg(self):
+        >>> df.groupBy().mean('age').collect()
+        [Row(AVG(age#0)=3.5)]
+        >>> df3.groupBy().mean('age', 'height').collect()
+        [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+        """
+
+    @df_varargs_api
+    def avg(self, *cols):
         """Compute the average value for each numeric columns
-        for each group."""
+        for each group.
 
-    @dfapi
-    def max(self):
+        >>> df.groupBy().avg('age').collect()
+        [Row(AVG(age#0)=3.5)]
+        >>> df3.groupBy().avg('age', 'height').collect()
+        [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+        """
+
+    @df_varargs_api
+    def max(self, *cols):
         """Compute the max value for each numeric columns for
-        each group. """
+        each group.
 
-    @dfapi
-    def min(self):
+        >>> df.groupBy().max('age').collect()
+        [Row(MAX(age#0)=5)]
+        >>> df3.groupBy().max('age', 'height').collect()
+        [Row(MAX(age#4)=5, MAX(height#5)=85)]
+        """
+
+    @df_varargs_api
+    def min(self, *cols):
         """Compute the min value for each numeric column for
-        each group."""
+        each group.
 
-    @dfapi
-    def sum(self):
+        >>> df.groupBy().min('age').collect()
+        [Row(MIN(age#0)=2)]
+        >>> df3.groupBy().min('age', 'height').collect()
+        [Row(MIN(age#4)=2, MIN(height#5)=80)]
+        """
+
+    @df_varargs_api
+    def sum(self, *cols):
         """Compute the sum for each numeric columns for each
-        group."""
+        group.
+
+        >>> df.groupBy().sum('age').collect()
+        [Row(SUM(age#0)=7)]
+        >>> df3.groupBy().sum('age', 'height').collect()
+        [Row(SUM(age#4)=7, SUM(height#5)=165)]
+        """
 
 
 def _create_column_from_literal(literal):
@@ -945,6 +987,8 @@ def _test():
     globs['sqlCtx'] = SQLContext(sc)
     globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
     globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
+                                  Row(name='Bob', age=5, height=85)]).toDF()
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.dataframe, globs=globs,
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)

http://git-wip-us.apache.org/repos/asf/spark/blob/5c78be7a/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 39aa550..d0e0906 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -158,6 +158,8 @@ def _test():
     globs['sqlCtx'] = SQLContext(sc)
     globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
     globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+    globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
+                                  Row(name='Bob', age=5, height=85)]).toDF()
     (failure_count, test_count) = doctest.testmod(
         pyspark.sql.dataframe, globs=globs,
         optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)

http://git-wip-us.apache.org/repos/asf/spark/blob/5c78be7a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 7b7efbe..9eb0c13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -88,12 +88,12 @@ private[sql] class DataFrameImpl protected[sql](
     }
   }
 
-  protected[sql] def numericColumns: Seq[Expression] = {
+  protected[sql] def numericColumns(): Seq[Expression] = {
     schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
       queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get
     }
   }
 
   override def toDF(colNames: String*): DataFrame = {
     require(schema.size == colNames.size,
       "The number of columns doesn't match.\n" +

http://git-wip-us.apache.org/repos/asf/spark/blob/5c78be7a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 0868013..a5a677b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -23,6 +23,8 @@ import scala.collection.JavaConversions._
 import org.apache.spark.sql.catalyst.analysis.Star
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.types.NumericType
+
 
 
 /**
@@ -39,13 +41,30 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
       df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
   }
 
-  private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = {
-    df.numericColumns.map { c =>
+  private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression)
+    : Seq[NamedExpression] = {
+
+    val columnExprs = if (colNames.isEmpty) {
+      // No columns specified. Use all numeric columns.
+      df.numericColumns
+    } else {
+      // Make sure all specified columns are numeric
+      colNames.map { colName =>
+        val namedExpr = df.resolve(colName)
+        if (!namedExpr.dataType.isInstanceOf[NumericType]) {
+          throw new AnalysisException(
+            s""""$colName" is not a numeric column. """ +
+            "Aggregation function can only be performed on a numeric column.")
+        }
+        namedExpr
+      }
+    }
+    columnExprs.map { c =>
       val a = f(c)
       Alias(a, a.toString)()
     }
   }
 
   private[this] def strToExpr(expr: String): (Expression => Expression) = {
     expr.toLowerCase match {
       case "avg" | "average" | "mean" => Average
@@ -152,30 +171,50 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
   /**
   * Compute the average value for each numeric column for each group. This is an alias for `avg`.
    * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the average values for them.
    */
-  def mean(): DataFrame = aggregateNumericColumns(Average)
-
+  @scala.annotation.varargs
+  def mean(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames:_*)(Average)
+  }
+
   /**
   * Compute the max value for each numeric column for each group.
    * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the max values for them.
    */
-  def max(): DataFrame = aggregateNumericColumns(Max)
+  @scala.annotation.varargs
+  def max(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames:_*)(Max)
+  }
 
   /**
   * Compute the mean value for each numeric column for each group.
    * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the mean values for them.
    */
-  def avg(): DataFrame = aggregateNumericColumns(Average)
+  @scala.annotation.varargs
+  def avg(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames:_*)(Average)
+  }
 
   /**
    * Compute the min value for each numeric column for each group.
    * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the min values for them.
    */
-  def min(): DataFrame = aggregateNumericColumns(Min)
+  @scala.annotation.varargs
+  def min(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames:_*)(Min)
+  }
 
   /**
   * Compute the sum for each numeric column for each group.
    * The resulting [[DataFrame]] will also contain the grouping columns.
+   * When specified columns are given, only compute the sum for them.
    */
-  def sum(): DataFrame = aggregateNumericColumns(Sum)
+  @scala.annotation.varargs
+  def sum(colNames: String*): DataFrame = {
+    aggregateNumericColumns(colNames:_*)(Sum)
+  }
 }
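
As the validation in aggregateNumericColumns above shows, naming a non-numeric column fails at analysis time. A minimal sketch of the error path, reusing the df from the commit message (the call site is illustrative; the message text comes from the patch):

    // throws org.apache.spark.sql.AnalysisException:
    //   "rest" is not a numeric column. Aggregation function can only be performed on a numeric column.
    df.groupBy("key").avg("rest")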

http://git-wip-us.apache.org/repos/asf/spark/blob/5c78be7a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f0cd436..524571d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -162,6 +162,18 @@ class DataFrameSuite extends QueryTest {
       testData2.groupBy("a").agg(Map("b" -> "sum")),
       Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
     )
+
+    val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
+      .toDF("key", "value1", "value2", "rest")
+
+    checkAnswer(
+      df1.groupBy("key").min(),
+      df1.groupBy("key").min("value1", "value2").collect()
+    )
+    checkAnswer(
+      df1.groupBy("key").min("value2"),
+      Seq(Row("a", 0), Row("b", 4))
+    )
   }
 
   test("agg without groups") {

