This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 0e2c487  [SPARK-26448][SQL][FOLLOWUP] should not normalize grouping expressions for final aggregate
0e2c487 is described below

commit 0e2c4874596269dd835bf69a5592b316345597c5
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Thu Jan 31 16:20:18 2019 +0800

    [SPARK-26448][SQL][FOLLOWUP] should not normalize grouping expressions for final aggregate

    ## What changes were proposed in this pull request?

    A follow-up of https://github.com/apache/spark/pull/23388 .

    `AggUtils.createAggregate` is not the right place to normalize the grouping expressions, as it also creates the final aggregate. The grouping expressions of the final aggregate must be attributes that refer to the grouping expressions of the partial aggregate, so they must not be rewritten there. This PR moves the normalization to the caller side of `AggUtils`.

    ## How was this patch tested?

    Existing tests.

    Closes #23692 from cloud-fan/follow.

    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../optimizer/NormalizeFloatingNumbers.scala       | 16 ++++++++------
 .../spark/sql/execution/SparkStrategies.scala      | 25 +++++++++++++++++++---
 .../spark/sql/execution/aggregate/AggUtils.scala   | 14 +++---------
 3 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index 520f24a..a5921eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -98,8 +98,10 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
   }
 
   private[sql] def normalize(expr: Expression): Expression = expr match {
-    case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
-      NormalizeNaNAndZero(expr)
+    case _ if !needNormalize(expr.dataType) => expr
+
+    case a: Alias =>
+      a.withNewChildren(Seq(normalize(a.child)))
 
     case CreateNamedStruct(children) =>
       CreateNamedStruct(children.map(normalize))
@@ -113,22 +115,22 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     case CreateMap(children) =>
       CreateMap(children.map(normalize))
 
-    case a: Alias if needNormalize(a.dataType) =>
-      a.withNewChildren(Seq(normalize(a.child)))
+    case _ if expr.dataType == FloatType || expr.dataType == DoubleType =>
+      NormalizeNaNAndZero(expr)
 
-    case _ if expr.dataType.isInstanceOf[StructType] && needNormalize(expr.dataType) =>
+    case _ if expr.dataType.isInstanceOf[StructType] =>
       val fields = expr.dataType.asInstanceOf[StructType].fields.indices.map { i =>
         normalize(GetStructField(expr, i))
       }
       CreateStruct(fields)
 
-    case _ if expr.dataType.isInstanceOf[ArrayType] && needNormalize(expr.dataType) =>
+    case _ if expr.dataType.isInstanceOf[ArrayType] =>
       val ArrayType(et, containsNull) = expr.dataType
       val lv = NamedLambdaVariable("arg", et, containsNull)
       val function = normalize(lv)
       ArrayTransform(expr, LambdaFunction(function, Seq(lv)))
 
-    case _ => expr
+    case _ => throw new IllegalStateException(s"fail to normalize $expr")
   }
 }
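To see the recursion shape the rewritten `normalize` follows, here is a minimal plain-Scala sketch. It is not the Catalyst rule itself: the toy `normalize` below dispatches on runtime values instead of `DataType`s, models structs as pairs, and every name in it is invented for illustration.

// Toy analogue of NormalizeFloatingNumbers.normalize: normalize doubles at
// the leaves and rebuild containers around the normalized children.
object RecursiveNormalizeSketch {
  def normalizeDouble(d: Double): Double =
    if (d.isNaN) Double.NaN      // collapse every NaN bit pattern to one
    else if (d == 0.0d) 0.0d     // -0.0 == 0.0, so this rewrites -0.0 to 0.0
    else d

  def normalize(v: Any): Any = v match {
    case d: Double   => normalizeDouble(d)            // leaf, like NormalizeNaNAndZero
    case a: Array[_] => a.map(normalize)              // like ArrayTransform
    case (x, y)      => (normalize(x), normalize(y))  // struct-as-pair, like CreateStruct
    case other       => other                         // nothing float-typed inside
  }

  def main(args: Array[String]): Unit = {
    val (d, a) = normalize((-0.0d, Array(-0.0d))).asInstanceOf[(Double, Array[Any])]
    assert(java.lang.Double.doubleToRawLongBits(d) == 0L)  // +0.0 bits, not -0.0
    assert(java.lang.Double.doubleToRawLongBits(a(0).asInstanceOf[Double]) == 0L)
  }
}

Note that after this change a type that reaches the final `case _` without being handled is a bug, hence the `IllegalStateException` replacing the old silent pass-through.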
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index b7cc373..edfa704 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
 import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -331,8 +332,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
 
         val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION)
 
+        // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
+        // `groupingExpressions` is not extracted during logical phase.
+        val normalizedGroupingExpressions = namedGroupingExpressions.map { e =>
+          NormalizeFloatingNumbers.normalize(e) match {
+            case n: NamedExpression => n
+            case other => Alias(other, e.name)(exprId = e.exprId)
+          }
+        }
+
         aggregate.AggUtils.planStreamingAggregation(
-          namedGroupingExpressions,
+          normalizedGroupingExpressions,
           aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]),
           rewrittenResultExpressions,
           stateVersion,
@@ -414,16 +424,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
               "Spark user mailing list.")
         }
 
+        // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
+        // `groupingExpressions` is not extracted during logical phase.
+        val normalizedGroupingExpressions = groupingExpressions.map { e =>
+          NormalizeFloatingNumbers.normalize(e) match {
+            case n: NamedExpression => n
+            case other => Alias(other, e.name)(exprId = e.exprId)
+          }
+        }
+
         val aggregateOperator =
           if (functionsWithDistinct.isEmpty) {
             aggregate.AggUtils.planAggregateWithoutDistinct(
-              groupingExpressions,
+              normalizedGroupingExpressions,
               aggregateExpressions,
               resultExpressions,
               planLater(child))
           } else {
             aggregate.AggUtils.planAggregateWithOneDistinct(
-              groupingExpressions,
+              normalizedGroupingExpressions,
              functionsWithDistinct,
              functionsWithoutDistinct,
              resultExpressions,
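The `match` added at both call sites above preserves the output schema: if normalization wraps a grouping expression in something unnamed, the result is re-aliased under the original name and `exprId`, so attribute references downstream still resolve. A self-contained toy model of that pattern follows; the `Expr`, `Named`, `AttrRef`, `Normalized`, and `Alias` types are invented stand-ins, not the Catalyst classes.

// Invented stand-ins for Catalyst's expression classes.
sealed trait Expr
sealed trait Named extends Expr { def name: String; def id: Long }
final case class AttrRef(name: String, id: Long) extends Named
final case class Normalized(child: Expr) extends Expr  // an unnamed wrapper
final case class Alias(child: Expr, name: String, id: Long) extends Named

object CallerSideNormalize {
  // Stand-in for NormalizeFloatingNumbers.normalize; wraps unconditionally
  // here so the example always exercises the re-aliasing branch.
  def normalize(e: Expr): Expr = Normalized(e)

  def normalizeGrouping(grouping: Seq[Named]): Seq[Named] =
    grouping.map { e =>
      normalize(e) match {
        case n: Named => n                           // already named: keep it
        case other    => Alias(other, e.name, e.id)  // keep name and exprId stable
      }
    }

  def main(args: Array[String]): Unit = {
    val out = normalizeGrouping(Seq(AttrRef("k", 1L)))
    // The grouping key is still named "k" with the same id, so references
    // to it elsewhere in the plan keep resolving.
    assert(out == Seq(Alias(Normalized(AttrRef("k", 1L)), "k", 1L)))
  }
}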
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 8b7556b..4d762c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -35,20 +35,12 @@ object AggUtils {
       initialInputBufferOffset: Int = 0,
       resultExpressions: Seq[NamedExpression] = Nil,
       child: SparkPlan): SparkPlan = {
-    // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
-    // `groupingExpressions` is not extracted during logical phase.
-    val normalizedGroupingExpressions = groupingExpressions.map { e =>
-      NormalizeFloatingNumbers.normalize(e) match {
-        case n: NamedExpression => n
-        case other => Alias(other, e.name)(exprId = e.exprId)
-      }
-    }
     val useHash = HashAggregateExec.supportsAggregate(
       aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))
     if (useHash) {
       HashAggregateExec(
         requiredChildDistributionExpressions = requiredChildDistributionExpressions,
-        groupingExpressions = normalizedGroupingExpressions,
+        groupingExpressions = groupingExpressions,
         aggregateExpressions = aggregateExpressions,
         aggregateAttributes = aggregateAttributes,
         initialInputBufferOffset = initialInputBufferOffset,
@@ -61,7 +53,7 @@ object AggUtils {
       if (objectHashEnabled && useObjectHash) {
         ObjectHashAggregateExec(
           requiredChildDistributionExpressions = requiredChildDistributionExpressions,
-          groupingExpressions = normalizedGroupingExpressions,
+          groupingExpressions = groupingExpressions,
           aggregateExpressions = aggregateExpressions,
           aggregateAttributes = aggregateAttributes,
           initialInputBufferOffset = initialInputBufferOffset,
@@ -70,7 +62,7 @@ object AggUtils {
       } else {
         SortAggregateExec(
           requiredChildDistributionExpressions = requiredChildDistributionExpressions,
-          groupingExpressions = normalizedGroupingExpressions,
+          groupingExpressions = groupingExpressions,
           aggregateExpressions = aggregateExpressions,
           aggregateAttributes = aggregateAttributes,
           initialInputBufferOffset = initialInputBufferOffset,
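Why the final aggregate must not be normalized again follows from the two-phase shape itself. The sketch below is a loose plain-Scala model, not Spark's execution path: grouping on raw bit patterns stands in for Spark's byte-wise binary key comparison, which is the reason -0.0 and 0.0 (and distinct NaN payloads) would otherwise land in separate groups.

// Loose model of a two-phase (partial + final) aggregation. Keys are compared
// by their raw bit pattern, standing in for Spark's binary row comparison.
object TwoPhaseAggSketch {
  def keyBits(d: Double): Long = java.lang.Double.doubleToRawLongBits(d)

  def normalize(d: Double): Double =
    if (d.isNaN) Double.NaN else if (d == 0.0d) 0.0d else d

  // Partial aggregate: the one place grouping keys get normalized.
  def partialAgg(rows: Seq[(Double, Long)]): Map[Long, Long] =
    rows.groupBy { case (k, _) => keyBits(normalize(k)) }
      .map { case (bits, rs) => bits -> rs.map(_._2).sum }

  // Final aggregate: its keys are the attributes produced by the partial
  // phase, already normalized; it must group on them as-is.
  def finalAgg(partials: Seq[Map[Long, Long]]): Map[Long, Long] =
    partials.flatten.groupBy(_._1).map { case (bits, kvs) => bits -> kvs.map(_._2).sum }

  def main(args: Array[String]): Unit = {
    val merged = finalAgg(Seq(
      partialAgg(Seq((0.0, 1L), (-0.0, 1L))),  // -0.0 joins the 0.0 group
      partialAgg(Seq((0.0, 1L)))))
    assert(merged == Map(keyBits(0.0) -> 3L))  // one group, count 3
  }
}

This is why `createAggregate` could not keep the normalization: it builds partial and final aggregates alike, and re-aliasing the final aggregate's grouping attributes would break their link to the partial aggregate's output.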