Repository: spark Updated Branches: refs/heads/master a8af4da12 -> 53e5251bb
[SPARK-22675][SQL] Refactoring PropagateTypes in TypeCoercion ## What changes were proposed in this pull request? PropagateTypes is called twice in TypeCoercion. We do not need to call it twice. Instead, we should call it after each change on the types. ## How was this patch tested? The existing tests Author: gatorsmile <gatorsm...@gmail.com> Closes #19874 from gatorsmile/deduplicatePropagateTypes. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/53e5251b Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/53e5251b Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/53e5251b Branch: refs/heads/master Commit: 53e5251bb36cfa36ffa9058887a9944f87f26879 Parents: a8af4da Author: gatorsmile <gatorsm...@gmail.com> Authored: Tue Dec 5 20:43:02 2017 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Tue Dec 5 20:43:02 2017 +0800 ---------------------------------------------------------------------- .../catalyst/analysis/DecimalPrecision.scala | 4 +- .../sql/catalyst/analysis/TypeCoercion.scala | 116 ++++++++++--------- .../sql/catalyst/expressions/Canonicalize.scala | 2 +- .../apache/spark/sql/SQLQueryTestSuite.scala | 2 +- .../org/apache/spark/sql/TPCDSQuerySuite.scala | 3 + .../hive/execution/HiveCompatibilitySuite.scala | 2 +- 6 files changed, 71 insertions(+), 58 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/53e5251b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala index fd2ac78..070bc54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala @@ -58,7 +58,7 @@ import org.apache.spark.sql.types._ * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE */ // scalastyle:on -object DecimalPrecision extends Rule[LogicalPlan] { +object DecimalPrecision extends TypeCoercionRule { import scala.math.{max, min} private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType @@ -78,7 +78,7 @@ object DecimalPrecision extends Rule[LogicalPlan] { PromotePrecision(Cast(e, dataType)) } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // fix decimal precision for expressions case q => q.transformExpressionsUp( decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal)) http://git-wip-us.apache.org/repos/asf/spark/blob/53e5251b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 28be955..1ee2f6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -22,6 +22,7 @@ import javax.annotation.Nullable import scala.annotation.tailrec import scala.collection.mutable +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -45,8 +46,7 @@ import org.apache.spark.sql.types._ object TypeCoercion { val typeCoercionRules = - PropagateTypes :: - InConversion :: + InConversion :: WidenSetOperationTypes :: 
PromoteStrings :: DecimalPrecision :: @@ -56,7 +56,6 @@ object TypeCoercion { IfCoercion :: StackCoercion :: Division :: - PropagateTypes :: ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: @@ -221,38 +220,6 @@ object TypeCoercion { exprs.map(_.dataType).distinct.length == 1 /** - * Applies any changes to [[AttributeReference]] data types that are made by other rules to - * instances higher in the query tree. - */ - object PropagateTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - - // No propagation required for leaf nodes. - case q: LogicalPlan if q.children.isEmpty => q - - // Don't propagate types from unresolved children. - case q: LogicalPlan if !q.childrenResolved => q - - case q: LogicalPlan => - val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap - q transformExpressions { - case a: AttributeReference => - inputMap.get(a.exprId) match { - // This can happen when an Attribute reference is born in a non-leaf node, for - // example due to a call to an external script like in the Transform operator. - // TODO: Perhaps those should actually be aliases? - case None => a - // Leave the same if the dataTypes match. - case Some(newType) if a.dataType == newType.dataType => a - case Some(newType) => - logDebug(s"Promoting $a to $newType in ${q.simpleString}") - newType - } - } - } - } - - /** * Widens numeric types and converts strings to numbers when appropriate. * * Loosely based on rules from "Hadoop: The Definitive Guide" 2nd edition, by Tom White @@ -345,7 +312,7 @@ object TypeCoercion { /** * Promotes strings that appear in arithmetic expressions. 
*/ - object PromoteStrings extends Rule[LogicalPlan] { + object PromoteStrings extends TypeCoercionRule { private def castExpr(expr: Expression, targetType: DataType): Expression = { (expr.dataType, targetType) match { case (NullType, dt) => Literal.create(null, targetType) @@ -354,7 +321,7 @@ object TypeCoercion { } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -403,7 +370,7 @@ object TypeCoercion { * operator type is found the original expression will be returned and an * Analysis Exception will be raised at the type checking phase. */ - object InConversion extends Rule[LogicalPlan] { + object InConversion extends TypeCoercionRule { private def flattenExpr(expr: Expression): Seq[Expression] = { expr match { // Multi columns in IN clause is represented as a CreateNamedStruct. @@ -413,7 +380,7 @@ object TypeCoercion { } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -512,8 +479,8 @@ object TypeCoercion { /** * This ensure that the types for various functions are as expected. */ - object FunctionArgumentConversion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + object FunctionArgumentConversion extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -602,8 +569,8 @@ object TypeCoercion { * Hive only performs integral division with the DIV operator. The arguments to / are always * converted to fractional types. 
*/ - object Division extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + object Division extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. case e if !e.childrenResolved => e @@ -624,8 +591,8 @@ object TypeCoercion { /** * Coerces the type of different branches of a CASE WHEN statement to a common type. */ - object CaseWhenCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + object CaseWhenCoercion extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => val maybeCommonType = findWiderCommonType(c.valueTypes) maybeCommonType.map { commonType => @@ -654,8 +621,8 @@ object TypeCoercion { /** * Coerces the type of different branches of If statement to a common type. */ - object IfCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + object IfCoercion extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => @@ -674,8 +641,8 @@ object TypeCoercion { /** * Coerces NullTypes in the Stack expression to the column types of the corresponding positions. 
*/ - object StackCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + object StackCoercion extends TypeCoercionRule { + override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows => Stack(children.zipWithIndex.map { // The first child is the number of rows for stack. @@ -711,8 +678,8 @@ object TypeCoercion { /** * Casts types according to the expected input types for [[Expression]]s. */ - object ImplicitTypeCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + object ImplicitTypeCasts extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e @@ -828,8 +795,8 @@ object TypeCoercion { /** * Cast WindowFrame boundaries to the type they operate upon. */ - object WindowFrameCoercion extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + object WindowFrameCoercion extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper)) if order.resolved => s.copy(frameSpecification = SpecifiedWindowFrame( @@ -850,3 +817,46 @@ object TypeCoercion { } } } + +trait TypeCoercionRule extends Rule[LogicalPlan] with Logging { + /** + * Applies any changes to [[AttributeReference]] data types that are made by the transform method + * to instances higher in the query tree. 
+ */ + def apply(plan: LogicalPlan): LogicalPlan = { + val newPlan = coerceTypes(plan) + if (plan.fastEquals(newPlan)) { + plan + } else { + propagateTypes(newPlan) + } + } + + protected def coerceTypes(plan: LogicalPlan): LogicalPlan + + private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan transformUp { + // No propagation required for leaf nodes. + case q: LogicalPlan if q.children.isEmpty => q + + // Don't propagate types from unresolved children. + case q: LogicalPlan if !q.childrenResolved => q + + case q: LogicalPlan => + val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap + q transformExpressions { + case a: AttributeReference => + inputMap.get(a.exprId) match { + // This can happen when an Attribute reference is born in a non-leaf node, for + // example due to a call to an external script like in the Transform operator. + // TODO: Perhaps those should actually be aliases? + case None => a + // Leave the same if the dataTypes match. + case Some(newType) if a.dataType == newType.dataType => a + case Some(newType) => + logDebug( + s"Promoting $a from ${a.dataType} to ${newType.dataType} in ${q.simpleString}") + newType + } + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/53e5251b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala index 65e497a..d848ba1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala @@ -31,7 +31,7 @@ package org.apache.spark.sql.catalyst.expressions * - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`. 
* - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`. */ -object Canonicalize extends { +object Canonicalize { def execute(e: Expression): Expression = { expressionReorder(ignoreNamesTypes(e)) } http://git-wip-us.apache.org/repos/asf/spark/blob/53e5251b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index e3901af..cac3e12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -300,7 +300,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { Locale.setDefault(originalLocale) // For debugging dump some statistics about how much time was spent in various optimizer rules - logWarning(RuleExecutor.dumpTimeSpent()) + logInfo(RuleExecutor.dumpTimeSpent()) } finally { super.afterAll() } http://git-wip-us.apache.org/repos/asf/spark/blob/53e5251b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala index e47d4b0..dbea036 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.resourceToString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -39,6 +40,8 @@ class TPCDSQuerySuite extends QueryTest with SharedSQLContext with 
BeforeAndAfte */ protected override def afterAll(): Unit = { try { + // For debugging dump some statistics about how much time was spent in various optimizer rules + logInfo(RuleExecutor.dumpTimeSpent()) spark.sessionState.catalog.reset() } finally { super.afterAll() http://git-wip-us.apache.org/repos/asf/spark/blob/53e5251b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 45791c6..def70a5 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -76,7 +76,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone) // For debugging dump some statistics about how much time was spent in various optimizer rules - logWarning(RuleExecutor.dumpTimeSpent()) + logInfo(RuleExecutor.dumpTimeSpent()) } finally { super.afterAll() } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org