Repository: spark
Updated Branches:
  refs/heads/master 895baf8f7 -> 42c592adb


[SPARK-7320] [SQL] Add Cube / Rollup for dataframe

This is a follow-up to #6257, which broke the Maven test.

Add `cube` and `rollup` operations to DataFrame.
For example:
```scala
testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b"))
testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b"))
```

Author: Cheng Hao <hao.ch...@intel.com>

Closes #6304 from chenghao-intel/rollup and squashes the following commits:

04bb1de [Cheng Hao] move the table register/unregister into beforeAll/afterAll
a6069f1 [Cheng Hao] cancel the implicit keyword
ced4b8f [Cheng Hao] remove the unnecessary code changes
9959dfa [Cheng Hao] update the code as comments
e1d88aa [Cheng Hao] update the code as suggested
03bc3d9 [Cheng Hao] Remove the CubedData & RollupedData
5fd62d0 [Cheng Hao] hiden the CubedData & RollupedData
5ffb196 [Cheng Hao] Add Cube / Rollup for dataframe


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/42c592ad
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/42c592ad
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/42c592ad

Branch: refs/heads/master
Commit: 42c592adb381ff20832cce55e0849ed68dd7eee4
Parents: 895baf8
Author: Cheng Hao <hao.ch...@intel.com>
Authored: Wed May 20 19:58:22 2015 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Wed May 20 19:58:22 2015 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/DataFrame.scala  | 104 ++++++++++++++++++-
 .../org/apache/spark/sql/GroupedData.scala      |  92 +++++++++++-----
 .../sql/hive/HiveDataFrameAnalyticsSuite.scala  |  69 ++++++++++++
 3 files changed, 237 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/42c592ad/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index adad858..d78b4c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -685,7 +685,53 @@ class DataFrame private[sql](
    * @since 1.3.0
    */
   @scala.annotation.varargs
-  def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr))
+  def groupBy(cols: Column*): GroupedData = {
+    GroupedData(this, cols.map(_.expr), GroupedData.GroupByType)
+  }
+
+  /**
+   * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns,
+   * so we can run aggregation on them.
+   * See [[GroupedData]] for all the available aggregate functions.
+   *
+   * {{{
+   *   // Compute the average for all numeric columns rolluped by department and group.
+   *   df.rollup($"department", $"group").avg()
+   *
+   *   // Compute the max age and average salary, rolluped by department and gender.
+   *   df.rollup($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   * @group dfops
+   * @since 1.4.0
+   */
+  @scala.annotation.varargs
+  def rollup(cols: Column*): GroupedData = {
+    GroupedData(this, cols.map(_.expr), GroupedData.RollupType)
+  }
+
+  /**
+   * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns,
+   * so we can run aggregation on them.
+   * See [[GroupedData]] for all the available aggregate functions.
+   *
+   * {{{
+   *   // Compute the average for all numeric columns cubed by department and group.
+   *   df.cube($"department", $"group").avg()
+   *
+   *   // Compute the max age and average salary, cubed by department and gender.
+   *   df.cube($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   * @group dfops
+   * @since 1.4.0
+   */
+  @scala.annotation.varargs
+  def cube(cols: Column*): GroupedData = GroupedData(this, cols.map(_.expr), GroupedData.CubeType)
 
   /**
    * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
@@ -710,7 +756,61 @@ class DataFrame private[sql](
   @scala.annotation.varargs
   def groupBy(col1: String, cols: String*): GroupedData = {
     val colNames: Seq[String] = col1 +: cols
-    new GroupedData(this, colNames.map(colName => resolve(colName)))
+    GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.GroupByType)
+  }
+
+  /**
+   * Create a multi-dimensional rollup for the current [[DataFrame]] using the specified columns,
+   * so we can run aggregation on them.
+   * See [[GroupedData]] for all the available aggregate functions.
+   *
+   * This is a variant of rollup that can only group by existing columns using column names
+   * (i.e. cannot construct expressions).
+   *
+   * {{{
+   *   // Compute the average for all numeric columns rolluped by department and group.
+   *   df.rollup("department", "group").avg()
+   *
+   *   // Compute the max age and average salary, rolluped by department and gender.
+   *   df.rollup($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   * @group dfops
+   * @since 1.4.0
+   */
+  @scala.annotation.varargs
+  def rollup(col1: String, cols: String*): GroupedData = {
+    val colNames: Seq[String] = col1 +: cols
+    GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.RollupType)
+  }
+
+  /**
+   * Create a multi-dimensional cube for the current [[DataFrame]] using the specified columns,
+   * so we can run aggregation on them.
+   * See [[GroupedData]] for all the available aggregate functions.
+   *
+   * This is a variant of cube that can only group by existing columns using column names
+   * (i.e. cannot construct expressions).
+   *
+   * {{{
+   *   // Compute the average for all numeric columns cubed by department and group.
+   *   df.cube("department", "group").avg()
+   *
+   *   // Compute the max age and average salary, cubed by department and gender.
+   *   df.cube($"department", $"gender").agg(Map(
+   *     "salary" -> "avg",
+   *     "age" -> "max"
+   *   ))
+   * }}}
+   * @group dfops
+   * @since 1.4.0
+   */
+  @scala.annotation.varargs
+  def cube(col1: String, cols: String*): GroupedData = {
+    val colNames: Seq[String] = col1 +: cols
+    GroupedData(this, colNames.map(colName => resolve(colName)), GroupedData.CubeType)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/42c592ad/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 1381b9f..f730e4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -23,9 +23,40 @@ import scala.language.implicitConversions
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.analysis.Star
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
 import org.apache.spark.sql.types.NumericType
 
+/**
+ * Companion object for GroupedData
+ */
+private[sql] object GroupedData {
+  def apply(
+      df: DataFrame,
+      groupingExprs: Seq[Expression],
+      groupType: GroupType): GroupedData = {
+    new GroupedData(df, groupingExprs, groupType: GroupType)
+  }
+
+  /**
+   * The Grouping Type
+   */
+  trait GroupType
+
+  /**
+   * To indicate it's the GroupBy
+   */
+  object GroupByType extends GroupType
+
+  /**
+   * To indicate it's the CUBE
+   */
+  object CubeType extends GroupType
+
+  /**
+   * To indicate it's the ROLLUP
+   */
+  object RollupType extends GroupType
+}
 
 /**
  * :: Experimental ::
@@ -34,19 +65,37 @@ import org.apache.spark.sql.types.NumericType
  * @since 1.3.0
  */
 @Experimental
-class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) {
+class GroupedData protected[sql](
+    df: DataFrame,
+    groupingExprs: Seq[Expression],
+    private val groupType: GroupedData.GroupType) {
 
-  private[sql] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
-    val namedGroupingExprs = groupingExprs.map {
-      case expr: NamedExpression => expr
-      case expr: Expression => Alias(expr, expr.prettyString)()
+  private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
+    val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
+        val retainedExprs = groupingExprs.map {
+          case expr: NamedExpression => expr
+          case expr: Expression => Alias(expr, expr.prettyString)()
+        }
+        retainedExprs ++ aggExprs
+      } else {
+        aggExprs
+      }
+
+    groupType match {
+      case GroupedData.GroupByType =>
+        DataFrame(
+          df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan))
+      case GroupedData.RollupType =>
+        DataFrame(
+          df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates))
+      case GroupedData.CubeType =>
+        DataFrame(
+          df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates))
     }
-    DataFrame(
-      df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
   }
 
   private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression)
-    : Seq[NamedExpression] = {
+    : DataFrame = {
 
     val columnExprs = if (colNames.isEmpty) {
       // No columns specified. Use all numeric columns.
@@ -63,10 +112,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
         namedExpr
       }
     }
-    columnExprs.map { c =>
+    toDF(columnExprs.map { c =>
       val a = f(c)
       Alias(a, a.prettyString)()
-    }
+    })
   }
 
   private[this] def strToExpr(expr: String): (Expression => Expression) = {
@@ -119,10 +168,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
    * @since 1.3.0
    */
   def agg(exprs: Map[String, String]): DataFrame = {
-    exprs.map { case (colName, expr) =>
+    toDF(exprs.map { case (colName, expr) =>
       val a = strToExpr(expr)(df(colName).expr)
       Alias(a, a.prettyString)()
-    }.toSeq
+    }.toSeq)
   }
 
   /**
@@ -175,19 +224,10 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
    */
   @scala.annotation.varargs
   def agg(expr: Column, exprs: Column*): DataFrame = {
-    val aggExprs = (expr +: exprs).map(_.expr).map {
+    toDF((expr +: exprs).map(_.expr).map {
       case expr: NamedExpression => expr
       case expr: Expression => Alias(expr, expr.prettyString)()
-    }
-    if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
-      val retainedExprs = groupingExprs.map {
-        case expr: NamedExpression => expr
-        case expr: Expression => Alias(expr, expr.prettyString)()
-      }
-      DataFrame(df.sqlContext, Aggregate(groupingExprs, retainedExprs ++ aggExprs, df.logicalPlan))
-    } else {
-      DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan))
-    }
+    })
   }
 
   /**
@@ -196,7 +236,7 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
    *
    * @since 1.3.0
    */
-  def count(): DataFrame = Seq(Alias(Count(Literal(1)), "count")())
+  def count(): DataFrame = toDF(Seq(Alias(Count(Literal(1)), "count")()))
 
   /**
    * Compute the average value for each numeric columns for each group. This is an alias for `avg`.
@@ -256,5 +296,5 @@ class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression])
   @scala.annotation.varargs
   def sum(colNames: String*): DataFrame = {
     aggregateNumericColumns(colNames:_*)(Sum)
-  }    
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/42c592ad/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
new file mode 100644
index 0000000..99de146
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.scalatest.BeforeAndAfterAll
+
+case class TestData2Int(a: Int, b: Int)
+
+// TODO ideally we should put the test suite into the package `sql`, as
+// `hive` package is optional in compiling, however, `SQLContext.sql` doesn't
+// support the `cube` or `rollup` yet.
+class HiveDataFrameAnalyticsSuite extends QueryTest with BeforeAndAfterAll {
+  val testData =
+    TestHive.sparkContext.parallelize(
+      TestData2Int(1, 2) ::
+        TestData2Int(2, 4) :: Nil).toDF()
+
+  override def beforeAll() {
+    TestHive.registerDataFrameAsTable(testData, "mytable")
+  }
+
+  override def afterAll(): Unit = {
+    TestHive.dropTempTable("mytable")
+  }
+
+  test("rollup") {
+    checkAnswer(
+      testData.rollup($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+      sql("select a + b, b, sum(a - b) from mytable group by a + b, b with rollup").collect()
+    )
+
+    checkAnswer(
+      testData.rollup("a", "b").agg(sum("b")),
+      sql("select a, b, sum(b) from mytable group by a, b with rollup").collect()
+    )
+  }
+
+  test("cube") {
+    checkAnswer(
+      testData.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")),
+      sql("select a + b, b, sum(a - b) from mytable group by a + b, b with cube").collect()
+    )
+
+    checkAnswer(
+      testData.cube("a", "b").agg(sum("b")),
+      sql("select a, b, sum(b) from mytable group by a, b with cube").collect()
+    )
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to