Repository: spark
Updated Branches:
  refs/heads/branch-1.6 b79c1bd1e -> ff7d869c4


[SPARK-9830][SPARK-11641][SQL][FOLLOW-UP] Remove AggregateExpression1 and update toString of Exchange

https://issues.apache.org/jira/browse/SPARK-9830

This is the follow-up PR for https://github.com/apache/spark/pull/9556 to 
address davies' comments.

Author: Yin Huai <yh...@databricks.com>

Closes #9607 from yhuai/removeAgg1-followup.

(cherry picked from commit 3121e78168808c015fb21da8b0d44bb33649fb81)
Signed-off-by: Reynold Xin <r...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ff7d869c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ff7d869c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ff7d869c

Branch: refs/heads/branch-1.6
Commit: ff7d869c47d8fbb88d0a0e3f9c431e6ff1e45390
Parents: b79c1bd
Author: Yin Huai <yh...@databricks.com>
Authored: Tue Nov 10 16:25:22 2015 -0800
Committer: Reynold Xin <r...@databricks.com>
Committed: Tue Nov 10 16:25:29 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |   2 +-
 .../sql/catalyst/analysis/CheckAnalysis.scala   |  58 ++++++---
 .../expressions/aggregate/Average.scala         |   2 +-
 .../aggregate/CentralMomentAgg.scala            |   2 +-
 .../catalyst/expressions/aggregate/Stddev.scala |   2 +-
 .../catalyst/expressions/aggregate/Sum.scala    |   2 +-
 .../catalyst/analysis/AnalysisErrorSuite.scala  | 127 +++++++++++++++----
 .../scala/org/apache/spark/sql/SQLConf.scala    |   1 +
 .../apache/spark/sql/execution/Exchange.scala   |   8 +-
 .../apache/spark/sql/execution/commands.scala   |  10 ++
 10 files changed, 160 insertions(+), 54 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ff7d869c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b1e1439..a9cd9a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -532,7 +532,7 @@ class Analyzer(
                 case min: Min if isDistinct =>
                   AggregateExpression(min, Complete, isDistinct = false)
                // We get an aggregate function, we need to wrap it in an AggregateExpression.
-                case agg2: AggregateFunction => AggregateExpression(agg2, Complete, isDistinct)
+                case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct)
                // This function is not an aggregate function, just return the resolved one.
                case other => other
              }

http://git-wip-us.apache.org/repos/asf/spark/blob/ff7d869c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 8322e99..5a4b0c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -110,17 +110,21 @@ trait CheckAnalysis {
           case Aggregate(groupingExprs, aggregateExprs, child) =>
            def checkValidAggregateExpression(expr: Expression): Unit = expr match {
              case aggExpr: AggregateExpression =>
-                // TODO: Is it possible that the child of a agg function is another
-                // agg function?
-                aggExpr.aggregateFunction.children.foreach {
-                  // This is just a sanity check, our analysis rule PullOutNondeterministic should
-                  // already pull out those nondeterministic expressions and evaluate them in
-                  // a Project node.
-                  case child if !child.deterministic =>
+                aggExpr.aggregateFunction.children.foreach { child =>
+                  child.foreach {
+                    case agg: AggregateExpression =>
+                      failAnalysis(
+                        s"It is not allowed to use an aggregate function in the argument of " +
+                          s"another aggregate function. Please use the inner aggregate function " +
+                          s"in a sub-query.")
+                    case other => // OK
+                  }
+
+                  if (!child.deterministic) {
                    failAnalysis(
                      s"nondeterministic expression ${expr.prettyString} should not " +
                        s"appear in the arguments of an aggregate function.")
-                  case child => // OK
+                  }
                }
              case e: Attribute if !groupingExprs.exists(_.semanticEquals(e)) =>
                 failAnalysis(
@@ -133,19 +137,33 @@ trait CheckAnalysis {
               case e => e.children.foreach(checkValidAggregateExpression)
             }
 
+            def checkSupportedGroupingDataType(
+                expressionString: String,
+                dataType: DataType): Unit = dataType match {
+              case BinaryType =>
+                failAnalysis(s"expression $expressionString cannot be used in " +
+                  s"grouping expression because it is in binary type or its inner field is " +
+                  s"in binary type")
+              case a: ArrayType =>
+                failAnalysis(s"expression $expressionString cannot be used in " +
+                  s"grouping expression because it is in array type or its inner field is " +
+                  s"in array type")
+              case m: MapType =>
+                failAnalysis(s"expression $expressionString cannot be used in " +
+                  s"grouping expression because it is in map type or its inner field is " +
+                  s"in map type")
+              case s: StructType =>
+                s.fields.foreach { f =>
+                  checkSupportedGroupingDataType(expressionString, f.dataType)
+                }
+              case udt: UserDefinedType[_] =>
+                checkSupportedGroupingDataType(expressionString, udt.sqlType)
+              case _ => // OK
+            }
+
             def checkValidGroupingExprs(expr: Expression): Unit = {
-              expr.dataType match {
-                case BinaryType =>
-                  failAnalysis(s"binary type expression ${expr.prettyString} cannot be used " +
-                    "in grouping expression")
-                case a: ArrayType =>
-                  failAnalysis(s"array type expression ${expr.prettyString} cannot be used " +
-                    "in grouping expression")
-                case m: MapType =>
-                  failAnalysis(s"map type expression ${expr.prettyString} cannot be used " +
-                    "in grouping expression")
-                case _ => // OK
-              }
+              checkSupportedGroupingDataType(expr.prettyString, expr.dataType)
+
               if (!expr.deterministic) {
                 // This is just a sanity check, our analysis rule PullOutNondeterministic should
                 // already pull out those nondeterministic expressions and evaluate them in

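For illustration (not part of the commit): a minimal sketch of what the tightened
checks reject, assuming a hypothetical table t with columns a: int, b: int and
c: array<int>:

  // Nested aggregate functions now fail analysis up front with:
  //   "It is not allowed to use an aggregate function in the argument of
  //    another aggregate function. Please use the inner aggregate function
  //    in a sub-query."
  sqlContext.sql("SELECT sum(sum(b)) FROM t GROUP BY a")

  // Grouping by a binary/array/map column, or by anything whose inner fields
  // have those types (structs and UDTs are checked recursively), now fails with:
  //   "expression c cannot be used in grouping expression ..."
  sqlContext.sql("SELECT count(*) FROM t GROUP BY c")

  // The rewrite suggested by the first error: compute the inner aggregate in a
  // sub-query, then aggregate over its result.
  sqlContext.sql("SELECT sum(s) FROM (SELECT sum(b) AS s FROM t GROUP BY a) tmp")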
http://git-wip-us.apache.org/repos/asf/spark/blob/ff7d869c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 7f9e503..94ac4bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -34,7 +34,7 @@ case class Average(child: Expression) extends DeclarativeAggregate {
   // Return data type.
   override def dataType: DataType = resultType
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
 
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForNumericExpr(child.dataType, "function average")

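A note on this change (the same simplification appears in CentralMomentAgg and
Stddev below): a one-element TypeCollection accepts exactly what its single
member accepts, so Seq(TypeCollection(NumericType)) and Seq(NumericType) are
interchangeable here. A hedged sketch of the unchanged behavior, assuming a
hypothetical table orders:

  // Still accepted: avg takes any NumericType input.
  sqlContext.sql("SELECT avg(price) FROM orders")
  // Still rejected by checkInputDataTypes / checkForNumericExpr:
  sqlContext.sql("SELECT avg(customer_name) FROM orders")  // AnalysisException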
http://git-wip-us.apache.org/repos/asf/spark/blob/ff7d869c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 984ce7f..de5872a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -57,7 +57,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
 
   override def dataType: DataType = DoubleType
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
 
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName")

http://git-wip-us.apache.org/repos/asf/spark/blob/ff7d869c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
index 5b9eb7a..2748009 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
@@ -50,7 +50,7 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
 
   override def dataType: DataType = resultType
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType))
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
 
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForNumericExpr(child.dataType, "function stddev")

http://git-wip-us.apache.org/repos/asf/spark/blob/ff7d869c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index c005ec9..cfb042e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -32,7 +32,7 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
   override def dataType: DataType = resultType
 
   override def inputTypes: Seq[AbstractDataType] =
-    Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))
+    Seq(TypeCollection(LongType, DoubleType, DecimalType))
 
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForNumericExpr(child.dataType, "function sum")

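Dropping NullType from Sum's accepted types should be behavior-neutral: as far
as I can tell, catalyst's ImplicitTypeCasts rule casts a NullType input to the
expected type's default concrete type before this check is consulted. A sketch
under that assumption (table t is hypothetical):

  // The null literal is implicitly cast (to LongType, the collection's first
  // and therefore default concrete type) during analysis, so this still works:
  sqlContext.sql("SELECT sum(null) FROM t")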
http://git-wip-us.apache.org/repos/asf/spark/blob/ff7d869c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 5a2368e..2e7c3bd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -23,8 +23,59 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData}
 import org.apache.spark.sql.types._
 
+import scala.beans.{BeanProperty, BeanInfo}
+
+@BeanInfo
+private[sql] case class GroupableData(@BeanProperty data: Int)
+
+private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {
+
+  override def sqlType: DataType = IntegerType
+
+  override def serialize(obj: Any): Int = {
+    obj match {
+      case groupableData: GroupableData => groupableData.data
+    }
+  }
+
+  override def deserialize(datum: Any): GroupableData = {
+    datum match {
+      case data: Int => GroupableData(data)
+    }
+  }
+
+  override def userClass: Class[GroupableData] = classOf[GroupableData]
+
+  private[spark] override def asNullable: GroupableUDT = this
+}
+
+@BeanInfo
+private[sql] case class UngroupableData(@BeanProperty data: Array[Int])
+
+private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {
+
+  override def sqlType: DataType = ArrayType(IntegerType)
+
+  override def serialize(obj: Any): ArrayData = {
+    obj match {
+      case groupableData: UngroupableData => new GenericArrayData(groupableData.data)
+    }
+  }
+
+  override def deserialize(datum: Any): UngroupableData = {
+    datum match {
+      case data: Array[Int] => UngroupableData(data)
+    }
+  }
+
+  override def userClass: Class[UngroupableData] = classOf[UngroupableData]
+
+  private[spark] override def asNullable: UngroupableUDT = this
+}
+
 case class TestFunction(
     children: Seq[Expression],
     inputTypes: Seq[AbstractDataType])
@@ -194,39 +245,65 @@ class AnalysisErrorSuite extends AnalysisTest {
     assert(error.message.contains("Conflicting attributes"))
   }
 
-  test("aggregation can't work on binary and map types") {
-    val plan =
-      Aggregate(
-        AttributeReference("a", BinaryType)(exprId = ExprId(2)) :: Nil,
-        Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
-        LocalRelation(
-          AttributeReference("a", BinaryType)(exprId = ExprId(2)),
-          AttributeReference("b", IntegerType)(exprId = ExprId(1))))
+  test("check grouping expression data types") {
+    def checkDataType(dataType: DataType, shouldSuccess: Boolean): Unit = {
+      val plan =
+        Aggregate(
+          AttributeReference("a", dataType)(exprId = ExprId(2)) :: Nil,
+          Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+          LocalRelation(
+            AttributeReference("a", dataType)(exprId = ExprId(2)),
+            AttributeReference("b", IntegerType)(exprId = ExprId(1))))
+
+      shouldSuccess match {
+        case true =>
+          assertAnalysisSuccess(plan, true)
+        case false =>
+          assertAnalysisError(plan, "expression a cannot be used in grouping expression" :: Nil)
+      }
 
-    assertAnalysisError(plan,
-      "binary type expression a cannot be used in grouping expression" :: Nil)
+    }
 
-    val plan2 =
-      Aggregate(
-        AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)) :: Nil,
-        Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
-        LocalRelation(
-          AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
-          AttributeReference("b", IntegerType)(exprId = ExprId(1))))
+    val supportedDataTypes = Seq(
+      StringType,
+      NullType, BooleanType,
+      ByteType, ShortType, IntegerType, LongType,
+      FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
+      DateType, TimestampType,
+      new StructType()
+        .add("f1", FloatType, nullable = true)
+        .add("f2", StringType, nullable = true),
+      new GroupableUDT())
+    supportedDataTypes.foreach { dataType =>
+      checkDataType(dataType, shouldSuccess = true)
+    }
 
-    assertAnalysisError(plan2,
-      "map type expression a cannot be used in grouping expression" :: Nil)
+    val unsupportedDataTypes = Seq(
+      BinaryType,
+      ArrayType(IntegerType),
+      MapType(StringType, LongType),
+      new StructType()
+        .add("f1", FloatType, nullable = true)
+        .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true),
+      new UngroupableUDT())
+    unsupportedDataTypes.foreach { dataType =>
+      checkDataType(dataType, shouldSuccess = false)
+    }
+  }
 
-    val plan3 =
+  test("we should fail analysis when we find nested aggregate functions") {
+    val plan =
       Aggregate(
-        AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)) :: Nil,
-        Alias(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1))), "c")() :: Nil,
+        AttributeReference("a", IntegerType)(exprId = ExprId(2)) :: Nil,
+        Alias(sum(sum(AttributeReference("b", IntegerType)(exprId = ExprId(1)))), "c")() :: Nil,
         LocalRelation(
-          AttributeReference("a", ArrayType(IntegerType))(exprId = ExprId(2)),
+          AttributeReference("a", IntegerType)(exprId = ExprId(2)),
           AttributeReference("b", IntegerType)(exprId = ExprId(1))))
 
-    assertAnalysisError(plan3,
-      "array type expression a cannot be used in grouping expression" :: Nil)
+    assertAnalysisError(
+      plan,
+      "It is not allowed to use an aggregate function in the argument of " +
+        "another aggregate function." :: Nil)
   }
 
   test("Join can't work on binary and map types") {

http://git-wip-us.apache.org/repos/asf/spark/blob/ff7d869c/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 89e196c..57d7d30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -474,6 +474,7 @@ private[spark] object SQLConf {
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
     val EXTERNAL_SORT = "spark.sql.planner.externalSort"
+    val USE_SQL_AGGREGATE2 = "spark.sql.useAggregate2"
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ff7d869c/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index a4ce328..b733b26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -44,14 +44,14 @@ case class Exchange(
   override def nodeName: String = {
     val extraInfo = coordinator match {
       case Some(exchangeCoordinator) if exchangeCoordinator.isEstimated =>
-        "Shuffle"
+        s"(coordinator id: ${System.identityHashCode(coordinator)})"
       case Some(exchangeCoordinator) if !exchangeCoordinator.isEstimated =>
-        "May shuffle"
-      case None => "Shuffle without coordinator"
+        s"(coordinator id: ${System.identityHashCode(coordinator)})"
+      case None => ""
     }
 
     val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange"
-    s"$simpleNodeName($extraInfo)"
+    s"${simpleNodeName}${extraInfo}"
   }
 
   /**

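The net effect on plan strings, sketched (the coordinator id is an identity
hash code, so the value below is illustrative; table t is hypothetical):

  // Before: TungstenExchange(Shuffle), TungstenExchange(May shuffle),
  //         TungstenExchange(Shuffle without coordinator)
  // After:  TungstenExchange(coordinator id: 1049760364)  -- with a coordinator
  //         TungstenExchange                              -- without one
  println(sqlContext.sql("SELECT a, count(*) FROM t GROUP BY a").queryExecution.executedPlan)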
http://git-wip-us.apache.org/repos/asf/spark/blob/ff7d869c/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index e5f60b1..8b2755a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -111,6 +111,16 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm
       }
       (keyValueOutput, runFunc)
 
+    case Some((SQLConf.Deprecated.USE_SQL_AGGREGATE2, Some(value))) =>
+      val runFunc = (sqlContext: SQLContext) => {
+        logWarning(
+          s"Property ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} is deprecated and " +
+            s"will be ignored. ${SQLConf.Deprecated.USE_SQL_AGGREGATE2} will " +
+            s"continue to be true.")
+        Seq(Row(SQLConf.Deprecated.USE_SQL_AGGREGATE2, "true"))
+      }
+      (keyValueOutput, runFunc)
+
     // Configures a single property.
     case Some((key, Some(value))) =>
       val runFunc = (sqlContext: SQLContext) => {

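Resulting behavior for the deprecated key, sketched (output layout approximate):

  // Logs the deprecation warning, then reports the key as pinned to "true":
  sqlContext.sql("SET spark.sql.useAggregate2=false").show()
  // +-----------------------+-----+
  // |                    key|value|
  // +-----------------------+-----+
  // |spark.sql.useAggregate2| true|
  // +-----------------------+-----+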
