Repository: spark
Updated Branches:
  refs/heads/branch-1.6 0419fd361 -> 3f40af574


[SPARK-11946][SQL] Audit pivot API for 1.6.

Currently, pivot's signatures look like:

```scala
@scala.annotation.varargs
def pivot(pivotColumn: Column, values: Column*): GroupedData

@scala.annotation.varargs
def pivot(pivotColumn: String, values: Any*): GroupedData
```

I think we can remove the one that takes "Column" types, since callers
should always be passing in literals. It would also be clearer if the
values were not varargs, but rather a Seq or java.util.List.
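
With this change, the varargs overloads are replaced by the following signatures
(see GroupedData.scala in the diff below):

```scala
def pivot(pivotColumn: String): GroupedData
def pivot(pivotColumn: String, values: Seq[Any]): GroupedData
def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData
```

A minimal usage sketch, assuming a DataFrame df with year, course and earnings
columns (like the courseSales test data used in the suites below):

```scala
// Explicit list of pivot values: no extra pass over the data is needed.
df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")

// Without values: Spark first computes the distinct values of the pivot column,
// capped by SQLConf.DATAFRAME_PIVOT_MAX_VALUES to prevent unintended OOM errors.
df.groupBy("year").pivot("course").sum("earnings")
```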

I also made similar changes for Python.

Author: Reynold Xin <r...@databricks.com>

Closes #9929 from rxin/SPARK-11946.

(cherry picked from commit f3152722791b163fa66597b3684009058195ba33)
Signed-off-by: Reynold Xin <r...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3f40af57
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3f40af57
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3f40af57

Branch: refs/heads/branch-1.6
Commit: 3f40af574ed14909d30c10cad24d747fba49b783
Parents: 0419fd3
Author: Reynold Xin <r...@databricks.com>
Authored: Tue Nov 24 12:54:37 2015 -0800
Committer: Reynold Xin <r...@databricks.com>
Committed: Tue Nov 24 12:54:44 2015 -0800

----------------------------------------------------------------------
 .../apache/spark/scheduler/DAGScheduler.scala   |   1 -
 python/pyspark/sql/group.py                     |  12 +-
 .../sql/catalyst/expressions/literals.scala     |   1 +
 .../org/apache/spark/sql/GroupedData.scala      | 154 +++++++++++--------
 .../apache/spark/sql/JavaDataFrameSuite.java    |  16 ++
 .../apache/spark/sql/DataFramePivotSuite.scala  |  21 +--
 .../org/apache/spark/sql/test/SQLTestData.scala |   1 +
 7 files changed, 125 insertions(+), 81 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3f40af57/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index ae725b4..77a184d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1574,7 +1574,6 @@ class DAGScheduler(
   }
 
   def stop() {
-    logInfo("Stopping DAGScheduler")
     messageScheduler.shutdownNow()
     eventProcessLoop.stop()
     taskScheduler.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/3f40af57/python/pyspark/sql/group.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 227f40b..d8ed7eb 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -168,20 +168,24 @@ class GroupedData(object):
         """
 
     @since(1.6)
-    def pivot(self, pivot_col, *values):
+    def pivot(self, pivot_col, values=None):
         """Pivots a column of the current DataFrame and preform the specified 
aggregation.
 
         :param pivot_col: Column to pivot
         :param values: Optional list of values of pivotColumn that will be 
translated to columns in
             the output data frame. If values are not provided the method with 
do an immediate call
             to .distinct() on the pivot column.
-        >>> df4.groupBy("year").pivot("course", "dotNET", 
"Java").sum("earnings").collect()
+
+        >>> df4.groupBy("year").pivot("course", ["dotNET", 
"Java"]).sum("earnings").collect()
         [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, 
dotNET=48000, Java=30000)]
+
         >>> df4.groupBy("year").pivot("course").sum("earnings").collect()
         [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, 
dotNET=48000)]
         """
-        jgd = self._jdf.pivot(_to_java_column(pivot_col),
-                              _to_seq(self.sql_ctx._sc, values, 
_create_column_from_literal))
+        if values is None:
+            jgd = self._jdf.pivot(pivot_col)
+        else:
+            jgd = self._jdf.pivot(pivot_col, values)
         return GroupedData(jgd, self.sql_ctx)
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3f40af57/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index e34fd49..68ec688 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -44,6 +44,7 @@ object Literal {
     case a: Array[Byte] => Literal(a, BinaryType)
     case i: CalendarInterval => Literal(i, CalendarIntervalType)
     case null => Literal(null, NullType)
+    case v: Literal => v
     case _ =>
       throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/3f40af57/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 63dd7fb..ee7150c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAli
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
-import org.apache.spark.sql.types.{StringType, NumericType}
+import org.apache.spark.sql.types.NumericType
 
 
 /**
@@ -282,74 +282,96 @@ class GroupedData protected[sql](
   }
 
   /**
-    * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified
-    * aggregation.
-    * {{{
-    *   // Compute the sum of earnings for each year by course with each course as a separate column
-    *   df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
-    *   // Or without specifying column values
-    *   df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
-    * }}}
-    * @param pivotColumn Column to pivot
-    * @param values Optional list of values of pivotColumn that will be translated to columns in the
-    *               output data frame. If values are not provided the method with do an immediate
-    *               call to .distinct() on the pivot column.
-    * @since 1.6.0
-    */
-  @scala.annotation.varargs
-  def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match {
-    case _: GroupedData.PivotType =>
-      throw new UnsupportedOperationException("repeated pivots are not supported")
-    case GroupedData.GroupByType =>
-      val pivotValues = if (values.nonEmpty) {
-        values.map {
-          case Column(literal: Literal) => literal
-          case other =>
-            throw new UnsupportedOperationException(
-              s"The values of a pivot must be literals, found $other")
-        }
-      } else {
-        // This is to prevent unintended OOM errors when the number of distinct values is large
-        val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
-        // Get the distinct values of the column and sort them so its consistent
-        val values = df.select(pivotColumn)
-          .distinct()
-          .sort(pivotColumn)
-          .map(_.get(0))
-          .take(maxValues + 1)
-          .map(Literal(_)).toSeq
-        if (values.length > maxValues) {
-          throw new RuntimeException(
-            s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
-              "this could indicate an error. " +
-              "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
-              s"to at least the number of distinct values of the pivot column.")
-        }
-        values
-      }
-      new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
-    case _ =>
-      throw new UnsupportedOperationException("pivot is only supported after a groupBy")
+   * Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
+   * There are two versions of the pivot function: one that requires the caller to specify the list
+   * of distinct values to pivot on, and one that does not. The latter is more concise but less
+   * efficient, because Spark needs to first compute the list of distinct values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings")
+   * }}}
+   *
+   * @param pivotColumn Name of the column to pivot.
+   * @since 1.6.0
+   */
+  def pivot(pivotColumn: String): GroupedData = {
+    // This is to prevent unintended OOM errors when the number of distinct values is large
+    val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
+    // Get the distinct values of the column and sort them so it's consistent
+    val values = df.select(pivotColumn)
+      .distinct()
+      .sort(pivotColumn)
+      .map(_.get(0))
+      .take(maxValues + 1)
+      .toSeq
+
+    if (values.length > maxValues) {
+      throw new AnalysisException(
+        s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
+          "this could indicate an error. " +
+          s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " +
+          "to at least the number of distinct values of the pivot column.")
+    }
+
+    pivot(pivotColumn, values)
   }
 
   /**
-    * Pivots a column of the current [[DataFrame]] and preform the specified aggregation.
-    * {{{
-    *   // Compute the sum of earnings for each year by course with each course as a separate column
-    *   df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
-    *   // Or without specifying column values
-    *   df.groupBy("year").pivot("course").sum("earnings")
-    * }}}
-    * @param pivotColumn Column to pivot
-    * @param values Optional list of values of pivotColumn that will be translated to columns in the
-    *               output data frame. If values are not provided the method with do an immediate
-    *               call to .distinct() on the pivot column.
-    * @since 1.6.0
-    */
-  @scala.annotation.varargs
-  def pivot(pivotColumn: String, values: Any*): GroupedData = {
-    val resolvedPivotColumn = Column(df.resolve(pivotColumn))
-    pivot(resolvedPivotColumn, values.map(functions.lit): _*)
+   * Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
+   * There are two versions of the pivot function: one that requires the caller to specify the list
+   * of distinct values to pivot on, and one that does not. The latter is more concise but less
+   * efficient, because Spark needs to first compute the list of distinct values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings")
+   * }}}
+   *
+   * @param pivotColumn Name of the column to pivot.
+   * @param values List of values that will be translated to columns in the output DataFrame.
+   * @since 1.6.0
+   */
+  def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = {
+    groupType match {
+      case GroupedData.GroupByType =>
+        new GroupedData(
+          df,
+          groupingExprs,
+          GroupedData.PivotType(df.resolve(pivotColumn), values.map(Literal.apply)))
+      case _: GroupedData.PivotType =>
+        throw new UnsupportedOperationException("repeated pivots are not supported")
+      case _ =>
+        throw new UnsupportedOperationException("pivot is only supported after a groupBy")
+    }
+  }
+
+  /**
+   * Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
+   * There are two versions of the pivot function: one that requires the caller to specify the list
+   * of distinct values to pivot on, and one that does not. The latter is more concise but less
+   * efficient, because Spark needs to first compute the list of distinct values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings");
+   * }}}
+   *
+   * @param pivotColumn Name of the column to pivot.
+   * @param values List of values that will be translated to columns in the output DataFrame.
+   * @since 1.6.0
+   */
+  def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = {
+    pivot(pivotColumn, values.asScala)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/3f40af57/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 567bddd..a12fed3 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -282,4 +282,20 @@ public class JavaDataFrameSuite {
     Assert.assertEquals(1, actual[1].getLong(0));
     Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13);
   }
+
+  @Test
+  public void pivot() {
+    DataFrame df = context.table("courseSales");
+    Row[] actual = df.groupBy("year")
+      .pivot("course", Arrays.<Object>asList("dotNET", "Java"))
+      .agg(sum("earnings")).orderBy("year").collect();
+
+    Assert.assertEquals(2012, actual[0].getInt(0));
+    Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01);
+    Assert.assertEquals(20000.0, actual[0].getDouble(2), 0.01);
+
+    Assert.assertEquals(2013, actual[1].getInt(0));
+    Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01);
+    Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3f40af57/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
index 0c23d14..fc53aba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -25,7 +25,7 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
 
   test("pivot courses with literals") {
     checkAnswer(
-      courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+      courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
         .agg(sum($"earnings")),
       Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
     )
@@ -33,14 +33,15 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
 
   test("pivot year with literals") {
     checkAnswer(
-      courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")),
+      courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
       Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     )
   }
 
   test("pivot courses with literals and multiple aggregations") {
     checkAnswer(
-      courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+      courseSales.groupBy($"year")
+        .pivot("course", Seq("dotNET", "Java"))
         .agg(sum($"earnings"), avg($"earnings")),
       Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
         Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
@@ -49,14 +50,14 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
 
   test("pivot year with string values (cast)") {
     checkAnswer(
-      courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"),
+      courseSales.groupBy("course").pivot("year", Seq("2012", "2013")).sum("earnings"),
       Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     )
   }
 
   test("pivot year with int values") {
     checkAnswer(
-      courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"),
+      courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).sum("earnings"),
       Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     )
   }
@@ -64,22 +65,22 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
   test("pivot courses with no values") {
     // Note Java comes before dotNet in sorted order
     checkAnswer(
-      courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
+      courseSales.groupBy("year").pivot("course").agg(sum($"earnings")),
       Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
     )
   }
 
   test("pivot year with no values") {
     checkAnswer(
-      courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
+      courseSales.groupBy("course").pivot("year").agg(sum($"earnings")),
       Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     )
   }
 
-  test("pivot max values inforced") {
+  test("pivot max values enforced") {
     sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
-    intercept[RuntimeException](
-      courseSales.groupBy($"year").pivot($"course")
+    intercept[AnalysisException](
+      courseSales.groupBy("year").pivot("course")
     )
     sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
       SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)

http://git-wip-us.apache.org/repos/asf/spark/blob/3f40af57/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index abad0d7..83c63e0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -281,6 +281,7 @@ private[sql] trait SQLTestData { self =>
     person
     salary
     complexData
+    courseSales
   }
 }
 

