This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0e2c487  [SPARK-26448][SQL][FOLLOWUP] should not normalize grouping expressions for final aggregate
0e2c487 is described below

commit 0e2c4874596269dd835bf69a5592b316345597c5
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Thu Jan 31 16:20:18 2019 +0800

    [SPARK-26448][SQL][FOLLOWUP] should not normalize grouping expressions for final aggregate
    
    ## What changes were proposed in this pull request?
    
    A followup of https://github.com/apache/spark/pull/23388 .
    
    `AggUtils.createAggregate` is not the right place to normalize the grouping
    expressions, as the final aggregate is also created by it. The grouping
    expressions of the final aggregate should be attributes that refer to the
    grouping expressions in the partial aggregate.
    
    This PR moves the normalization to the caller side of `AggUtils`.
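
    (Editor's aside, not part of the original PR text.) The reason grouping
    keys need floating-point normalization at all: `0.0` and `-0.0` compare
    equal but have different bit patterns, so a bit-based hash aggregate
    would otherwise split them into two groups. A minimal standalone sketch
    of that effect, assuming nothing beyond the JDK:

    ```scala
    // Editor's sketch: why -0.0 must be canonicalized before hashing.
    // `normalize` mirrors what Catalyst's NormalizeNaNAndZero does for the
    // zero case (the real rule also canonicalizes NaN).
    object FloatNormalizationDemo {
      def normalize(d: Double): Double = if (d == 0.0d) 0.0d else d

      def main(args: Array[String]): Unit = {
        val keys = Seq(0.0d, -0.0d)
        // Bit patterns differ, so bit-based hashing sees two distinct keys:
        println(keys.map(java.lang.Double.doubleToRawLongBits).distinct.size)                // 2
        // After normalization both zeros collapse into a single group:
        println(keys.map(normalize).map(java.lang.Double.doubleToRawLongBits).distinct.size) // 1
      }
    }
    ```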
    
    ## How was this patch tested?
    
    Existing tests.
    
    Closes #23692 from cloud-fan/follow.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../optimizer/NormalizeFloatingNumbers.scala       | 16 ++++++++------
 .../spark/sql/execution/SparkStrategies.scala      | 25 +++++++++++++++++++---
 .../spark/sql/execution/aggregate/AggUtils.scala   | 14 +++---------
 3 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index 520f24a..a5921eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -98,8 +98,10 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
   }
 
   private[sql] def normalize(expr: Expression): Expression = expr match {
-    case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
-      NormalizeNaNAndZero(expr)
+    case _ if !needNormalize(expr.dataType) => expr
+
+    case a: Alias =>
+      a.withNewChildren(Seq(normalize(a.child)))
 
     case CreateNamedStruct(children) =>
       CreateNamedStruct(children.map(normalize))
@@ -113,22 +115,22 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     case CreateMap(children) =>
       CreateMap(children.map(normalize))
 
-    case a: Alias if needNormalize(a.dataType) =>
-      a.withNewChildren(Seq(normalize(a.child)))
+    case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
+      NormalizeNaNAndZero(expr)
 
-    case _ if expr.dataType.isInstanceOf[StructType] && needNormalize(expr.dataType) =>
+    case _ if expr.dataType.isInstanceOf[StructType] =>
      val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
         normalize(GetStructField(expr, i))
       }
       CreateStruct(fields)
 
-    case _ if expr.dataType.isInstanceOf[ArrayType] && needNormalize(expr.dataType) =>
+    case _ if expr.dataType.isInstanceOf[ArrayType] =>
       val ArrayType(et, containsNull) = expr.dataType
       val lv = NamedLambdaVariable("arg", et, containsNull)
       val function = normalize(lv)
       ArrayTransform(expr, LambdaFunction(function, Seq(lv)))
 
-    case _ => expr
+    case _ => throw new IllegalStateException(s"fail to normalize $expr")
   }
 }
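
(Editor's note.) The reordering above makes the rule's contract explicit:
one `!needNormalize` fast path up front, recursive descent into aliases,
structs, and arrays, and a hard failure instead of a silent pass-through for
anything unhandled. A standalone sketch of that shape, pasteable into a
Scala REPL, with toy types standing in for Catalyst's (names are the
editor's, not Spark's):

```scala
// Toy stand-ins for Catalyst data types; illustrative only.
trait DType
case object LongT extends DType
case object DoubleT extends DType
final case class StructT(fields: Seq[DType]) extends DType
final case class ArrayT(elem: DType) extends DType

def needNormalize(t: DType): Boolean = t match {
  case DoubleT     => true
  case StructT(fs) => fs.exists(needNormalize)
  case ArrayT(e)   => needNormalize(e)
  case _           => false
}

// Mirrors the reordered rule: guard once at the top, recurse, fail loudly.
def normalize(t: DType): String = t match {
  case _ if !needNormalize(t) => "<as-is>"
  case DoubleT                => "NormalizeNaNAndZero(...)"
  case StructT(fs)            => fs.map(normalize).mkString("struct(", ", ", ")")
  case ArrayT(e)              => s"transform(<arr>, x => ${normalize(e)})"
  case other                  => throw new IllegalStateException(s"fail to normalize $other")
}

normalize(StructT(Seq(LongT, DoubleT)))
// res: String = struct(<as-is>, NormalizeNaNAndZero(...))
```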
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index b7cc373..edfa704 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -331,8 +332,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
        val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION)
 
+        // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
+        // `groupingExpressions` is not extracted during logical phase.
+        val normalizedGroupingExpressions = namedGroupingExpressions.map { e =>
+          NormalizeFloatingNumbers.normalize(e) match {
+            case n: NamedExpression => n
+            case other => Alias(other, e.name)(exprId = e.exprId)
+          }
+        }
+
         aggregate.AggUtils.planStreamingAggregation(
-          namedGroupingExpressions,
+          normalizedGroupingExpressions,
          aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]),
           rewrittenResultExpressions,
           stateVersion,
@@ -414,16 +424,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
               "Spark user mailing list.")
         }
 
+        // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
+        // `groupingExpressions` is not extracted during logical phase.
+        val normalizedGroupingExpressions = groupingExpressions.map { e =>
+          NormalizeFloatingNumbers.normalize(e) match {
+            case n: NamedExpression => n
+            case other => Alias(other, e.name)(exprId = e.exprId)
+          }
+        }
+
         val aggregateOperator =
           if (functionsWithDistinct.isEmpty) {
             aggregate.AggUtils.planAggregateWithoutDistinct(
-              groupingExpressions,
+              normalizedGroupingExpressions,
               aggregateExpressions,
               resultExpressions,
               planLater(child))
           } else {
             aggregate.AggUtils.planAggregateWithOneDistinct(
-              groupingExpressions,
+              normalizedGroupingExpressions,
               functionsWithDistinct,
               functionsWithoutDistinct,
               resultExpressions,
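
(Editor's note.) Both hunks above use the same idiom: `normalize` may wrap a
named grouping expression in an unnamed one, so the result is re-aliased
under the original `exprId`, keeping downstream attribute references (the
result expressions and the final aggregate) resolvable. Extracted as a
helper for clarity; the Catalyst classes are real, but the wiring is the
editor's illustration, not the planner's exact code:

```scala
import org.apache.spark.sql.catalyst.expressions.{Alias, NamedExpression}
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers

// Normalize a grouping expression while preserving its identity: if
// normalization un-names it, re-alias under the ORIGINAL name and exprId
// so references elsewhere in the plan still bind.
def normalizeKeepingId(e: NamedExpression): NamedExpression =
  NormalizeFloatingNumbers.normalize(e) match {
    case n: NamedExpression => n
    case other              => Alias(other, e.name)(exprId = e.exprId)
  }
```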
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 8b7556b..4d762c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -35,20 +35,12 @@ object AggUtils {
       initialInputBufferOffset: Int = 0,
       resultExpressions: Seq[NamedExpression] = Nil,
       child: SparkPlan): SparkPlan = {
-    // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
-    // `groupingExpressions` is not extracted during logical phase.
-    val normalizedGroupingExpressions = groupingExpressions.map { e =>
-      NormalizeFloatingNumbers.normalize(e) match {
-        case n: NamedExpression => n
-        case other => Alias(other, e.name)(exprId = e.exprId)
-      }
-    }
     val useHash = HashAggregateExec.supportsAggregate(
       aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
     if (useHash) {
       HashAggregateExec(
        requiredChildDistributionExpressions = requiredChildDistributionExpressions,
-        groupingExpressions = normalizedGroupingExpressions,
+        groupingExpressions = groupingExpressions,
         aggregateExpressions = aggregateExpressions,
         aggregateAttributes = aggregateAttributes,
         initialInputBufferOffset = initialInputBufferOffset,
@@ -61,7 +53,7 @@ object AggUtils {
       if (objectHashEnabled && useObjectHash) {
         ObjectHashAggregateExec(
          requiredChildDistributionExpressions = requiredChildDistributionExpressions,
-          groupingExpressions = normalizedGroupingExpressions,
+          groupingExpressions = groupingExpressions,
           aggregateExpressions = aggregateExpressions,
           aggregateAttributes = aggregateAttributes,
           initialInputBufferOffset = initialInputBufferOffset,
@@ -70,7 +62,7 @@ object AggUtils {
       } else {
         SortAggregateExec(
          requiredChildDistributionExpressions = requiredChildDistributionExpressions,
-          groupingExpressions = normalizedGroupingExpressions,
+          groupingExpressions = groupingExpressions,
           aggregateExpressions = aggregateExpressions,
           aggregateAttributes = aggregateAttributes,
           initialInputBufferOffset = initialInputBufferOffset,
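
(Editor's note.) A quick way to observe the net effect, runnable in
spark-shell against a build with this commit (column names are
illustrative): the two signed zeros should land in one group, and
`explain()` should show the normalization expression only in the partial
aggregate's grouping keys, with the final aggregate grouping by the plain
attribute the partial stage outputs.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

// Group by a Double column containing both signed zeros.
val grouped = Seq(0.0d, -0.0d, 1.5d).toDF("d").groupBy($"d").count()
grouped.show()     // two groups: 0.0 (count 2) and 1.5 (count 1)
grouped.explain()  // normalization only in the partial aggregate's keys
```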

