Repository: spark
Updated Branches:
  refs/heads/master a8af4da12 -> 53e5251bb


[SPARK-22675][SQL] Refactoring PropagateTypes in TypeCoercion

## What changes were proposed in this pull request?
PropagateTypes is called twice in TypeCoercion. We do not need to call it
twice; instead, we should call it after each rule that changes the types.
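
In sketch form, the fix turns type propagation into a post-step that every
coercion rule runs itself. A self-contained sketch of the control flow
(hypothetical Plan type, not the Spark API; the real trait is in the diff
below):

    object CoercionFlowSketch {
      // Toy stand-in for a query plan: just a map from attribute name to type.
      final case class Plan(types: Map[String, String])

      trait CoercionRule {
        // Each concrete rule implements only the coercion itself.
        protected def coerceTypes(plan: Plan): Plan

        final def apply(plan: Plan): Plan = {
          val newPlan = coerceTypes(plan)
          if (newPlan == plan) plan      // nothing changed: skip propagation
          else propagateTypes(newPlan)   // changed: push updated types upward
        }

        // Stand-in for the real pass that rewrites attribute references upward.
        private def propagateTypes(plan: Plan): Plan = plan
      }
    }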

## How was this patch tested?
The existing tests

Author: gatorsmile <gatorsm...@gmail.com>

Closes #19874 from gatorsmile/deduplicatePropagateTypes.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/53e5251b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/53e5251b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/53e5251b

Branch: refs/heads/master
Commit: 53e5251bb36cfa36ffa9058887a9944f87f26879
Parents: a8af4da
Author: gatorsmile <gatorsm...@gmail.com>
Authored: Tue Dec 5 20:43:02 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Tue Dec 5 20:43:02 2017 +0800

----------------------------------------------------------------------
 .../catalyst/analysis/DecimalPrecision.scala    |   4 +-
 .../sql/catalyst/analysis/TypeCoercion.scala    | 116 ++++++++++---------
 .../sql/catalyst/expressions/Canonicalize.scala |   2 +-
 .../apache/spark/sql/SQLQueryTestSuite.scala    |   2 +-
 .../org/apache/spark/sql/TPCDSQuerySuite.scala  |   3 +
 .../hive/execution/HiveCompatibilitySuite.scala |   2 +-
 6 files changed, 71 insertions(+), 58 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/53e5251b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
index fd2ac78..070bc54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecision.scala
@@ -58,7 +58,7 @@ import org.apache.spark.sql.types._
  * - FLOAT and DOUBLE cause fixed-length decimals to turn into DOUBLE
  */
 // scalastyle:on
-object DecimalPrecision extends Rule[LogicalPlan] {
+object DecimalPrecision extends TypeCoercionRule {
   import scala.math.{max, min}
 
   private def isFloat(t: DataType): Boolean = t == FloatType || t == DoubleType
@@ -78,7 +78,7 @@ object DecimalPrecision extends Rule[LogicalPlan] {
     PromotePrecision(Cast(e, dataType))
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+  override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     // fix decimal precision for expressions
     case q => q.transformExpressionsUp(
       decimalAndDecimal.orElse(integralAndDecimalLiteral).orElse(nondecimalAndDecimal))
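
The same mechanical migration is applied to every rule in this commit.
Schematically (SomeRule is a placeholder name, not a rule in the codebase):

    import org.apache.spark.sql.catalyst.analysis.TypeCoercionRule
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule

    // Before: a plain analyzer rule; PropagateTypes had to run as a separate pass.
    object SomeRuleBefore extends Rule[LogicalPlan] {
      def apply(plan: LogicalPlan): LogicalPlan = plan  // coercion logic lived here
    }

    // After: TypeCoercionRule's apply() wraps coerceTypes and propagates types itself.
    object SomeRuleAfter extends TypeCoercionRule {
      override protected def coerceTypes(plan: LogicalPlan): LogicalPlan =
        plan  // the same coercion logic moves here
    }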

http://git-wip-us.apache.org/repos/asf/spark/blob/53e5251b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 28be955..1ee2f6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -22,6 +22,7 @@ import javax.annotation.Nullable
 import scala.annotation.tailrec
 import scala.collection.mutable
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -45,8 +46,7 @@ import org.apache.spark.sql.types._
 object TypeCoercion {
 
   val typeCoercionRules =
-    PropagateTypes ::
-      InConversion ::
+    InConversion ::
       WidenSetOperationTypes ::
       PromoteStrings ::
       DecimalPrecision ::
@@ -56,7 +56,6 @@ object TypeCoercion {
       IfCoercion ::
       StackCoercion ::
       Division ::
-      PropagateTypes ::
       ImplicitTypeCasts ::
       DateTimeOperations ::
       WindowFrameCoercion ::
@@ -221,38 +220,6 @@ object TypeCoercion {
     exprs.map(_.dataType).distinct.length == 1
 
   /**
-   * Applies any changes to [[AttributeReference]] data types that are made by other rules to
-   * instances higher in the query tree.
-   */
-  object PropagateTypes extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-
-      // No propagation required for leaf nodes.
-      case q: LogicalPlan if q.children.isEmpty => q
-
-      // Don't propagate types from unresolved children.
-      case q: LogicalPlan if !q.childrenResolved => q
-
-      case q: LogicalPlan =>
-        val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap
-        q transformExpressions {
-          case a: AttributeReference =>
-            inputMap.get(a.exprId) match {
-              // This can happen when an Attribute reference is born in a non-leaf node, for
-              // example due to a call to an external script like in the Transform operator.
-              // TODO: Perhaps those should actually be aliases?
-              case None => a
-              // Leave the same if the dataTypes match.
-              case Some(newType) if a.dataType == newType.dataType => a
-              case Some(newType) =>
-                logDebug(s"Promoting $a to $newType in ${q.simpleString}")
-                newType
-            }
-        }
-    }
-  }
-
-  /**
    * Widens numeric types and converts strings to numbers when appropriate.
    *
    * Loosely based on rules from "Hadoop: The Definitive Guide" 2nd edition, by Tom White
@@ -345,7 +312,7 @@ object TypeCoercion {
   /**
    * Promotes strings that appear in arithmetic expressions.
    */
-  object PromoteStrings extends Rule[LogicalPlan] {
+  object PromoteStrings extends TypeCoercionRule {
     private def castExpr(expr: Expression, targetType: DataType): Expression = {
       (expr.dataType, targetType) match {
         case (NullType, dt) => Literal.create(null, targetType)
@@ -354,7 +321,7 @@ object TypeCoercion {
       }
     }
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes whose children have not been resolved yet.
       case e if !e.childrenResolved => e
 
@@ -403,7 +370,7 @@ object TypeCoercion {
    *    operator type is found the original expression will be returned and an
    *    Analysis Exception will be raised at the type checking phase.
    */
-  object InConversion extends Rule[LogicalPlan] {
+  object InConversion extends TypeCoercionRule {
     private def flattenExpr(expr: Expression): Seq[Expression] = {
       expr match {
        // Multiple columns in an IN clause are represented as a CreateNamedStruct.
@@ -413,7 +380,7 @@ object TypeCoercion {
       }
     }
 
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes whose children have not been resolved yet.
       case e if !e.childrenResolved => e
 
@@ -512,8 +479,8 @@ object TypeCoercion {
   /**
   * This ensures that the types for various functions are as expected.
    */
-  object FunctionArgumentConversion extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+  object FunctionArgumentConversion extends TypeCoercionRule {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes whose children have not been resolved yet.
       case e if !e.childrenResolved => e
 
@@ -602,8 +569,8 @@ object TypeCoercion {
    * Hive only performs integral division with the DIV operator. The arguments to / are always
    * converted to fractional types.
    */
-  object Division extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+  object Division extends TypeCoercionRule {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes which have not been resolved yet,
       // as this is an extra rule which should be applied last.
       case e if !e.childrenResolved => e
@@ -624,8 +591,8 @@ object TypeCoercion {
   /**
    * Coerces the type of different branches of a CASE WHEN statement to a common type.
    */
-  object CaseWhenCoercion extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+  object CaseWhenCoercion extends TypeCoercionRule {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
         val maybeCommonType = findWiderCommonType(c.valueTypes)
         maybeCommonType.map { commonType =>
@@ -654,8 +621,8 @@ object TypeCoercion {
   /**
    * Coerces the type of different branches of If statement to a common type.
    */
-  object IfCoercion extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+  object IfCoercion extends TypeCoercionRule {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       case e if !e.childrenResolved => e
       // Find tightest common type for If, if the true value and false value have different types.
       case i @ If(pred, left, right) if left.dataType != right.dataType =>
@@ -674,8 +641,8 @@ object TypeCoercion {
   /**
    * Coerces NullTypes in the Stack expression to the column types of the corresponding positions.
    */
-  object StackCoercion extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+  object StackCoercion extends TypeCoercionRule {
+    override def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
       case s @ Stack(children) if s.childrenResolved && s.hasFoldableNumRows =>
         Stack(children.zipWithIndex.map {
           // The first child is the number of rows for stack.
@@ -711,8 +678,8 @@ object TypeCoercion {
   /**
    * Casts types according to the expected input types for [[Expression]]s.
    */
-  object ImplicitTypeCasts extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+  object ImplicitTypeCasts extends TypeCoercionRule {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       // Skip nodes whose children have not been resolved yet.
       case e if !e.childrenResolved => e
 
@@ -828,8 +795,8 @@ object TypeCoercion {
   /**
    * Cast WindowFrame boundaries to the type they operate upon.
    */
-  object WindowFrameCoercion extends Rule[LogicalPlan] {
-    def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+  object WindowFrameCoercion extends TypeCoercionRule {
+    override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
       case s @ WindowSpecDefinition(_, Seq(order), SpecifiedWindowFrame(RangeFrame, lower, upper))
           if order.resolved =>
         s.copy(frameSpecification = SpecifiedWindowFrame(
@@ -850,3 +817,46 @@ object TypeCoercion {
     }
   }
 }
+
+trait TypeCoercionRule extends Rule[LogicalPlan] with Logging {
+  /**
+   * Applies any changes to [[AttributeReference]] data types that are made by the transform method
+   * to instances higher in the query tree.
+   */
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    val newPlan = coerceTypes(plan)
+    if (plan.fastEquals(newPlan)) {
+      plan
+    } else {
+      propagateTypes(newPlan)
+    }
+  }
+
+  protected def coerceTypes(plan: LogicalPlan): LogicalPlan
+
+  private def propagateTypes(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    // No propagation required for leaf nodes.
+    case q: LogicalPlan if q.children.isEmpty => q
+
+    // Don't propagate types from unresolved children.
+    case q: LogicalPlan if !q.childrenResolved => q
+
+    case q: LogicalPlan =>
+      val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap
+      q transformExpressions {
+        case a: AttributeReference =>
+          inputMap.get(a.exprId) match {
+            // This can happen when an Attribute reference is born in a non-leaf node, for
+            // example due to a call to an external script like in the Transform operator.
+            // TODO: Perhaps those should actually be aliases?
+            case None => a
+            // Leave the same if the dataTypes match.
+            case Some(newType) if a.dataType == newType.dataType => a
+            case Some(newType) =>
+              logDebug(
+                s"Promoting $a from ${a.dataType} to ${newType.dataType} in ${q.simpleString}")
+              newType
+          }
+      }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/53e5251b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
index 65e497a..d848ba1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala
@@ -31,7 +31,7 @@ package org.apache.spark.sql.catalyst.expressions
  *  - [[EqualTo]] and [[EqualNullSafe]] are reordered by `hashCode`.
  *  - Other comparisons ([[GreaterThan]], [[LessThan]]) are reversed by `hashCode`.
  */
-object Canonicalize extends {
+object Canonicalize {
   def execute(e: Expression): Expression = {
     expressionReorder(ignoreNamesTypes(e))
   }
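
A side note on the Canonicalize change above: in Scala 2 the `extends` keyword
before a bare template body is optional in the grammar, so the old
`object Canonicalize extends { ... }` compiled and meant the same as
`object Canonicalize { ... }`; the fix simply drops the misleading keyword.
A hypothetical illustration (not from the codebase):

    // Both compile in Scala 2 and define equivalent objects; the stray
    // `extends` before the body adds nothing and only confuses readers.
    object WithStrayExtends extends { def answer: Int = 42 }
    object Plain { def answer: Int = 42 }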

http://git-wip-us.apache.org/repos/asf/spark/blob/53e5251b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index e3901af..cac3e12 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -300,7 +300,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
       Locale.setDefault(originalLocale)
 
       // For debugging dump some statistics about how much time was spent in various optimizer rules
-      logWarning(RuleExecutor.dumpTimeSpent())
+      logInfo(RuleExecutor.dumpTimeSpent())
     } finally {
       super.afterAll()
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/53e5251b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala
index e47d4b0..dbea036 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.catalyst.util.resourceToString
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
@@ -39,6 +40,8 @@ class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfte
    */
   protected override def afterAll(): Unit = {
     try {
+      // For debugging dump some statistics about how much time was spent in various optimizer rules
+      logInfo(RuleExecutor.dumpTimeSpent())
       spark.sessionState.catalog.reset()
     } finally {
       super.afterAll()

http://git-wip-us.apache.org/repos/asf/spark/blob/53e5251b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 45791c6..def70a5 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -76,7 +76,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
       TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, originalSessionLocalTimeZone)
 
       // For debugging dump some statistics about how much time was spent in various optimizer rules
-      logWarning(RuleExecutor.dumpTimeSpent())
+      logInfo(RuleExecutor.dumpTimeSpent())
     } finally {
       super.afterAll()
     }
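
A note on the three test-suite changes: they route RuleExecutor.dumpTimeSpent()
to INFO (and add the call to TPCDSQuerySuite), presumably so the per-rule
timing of the refactored coercion rules can be inspected without emitting
warnings. A minimal usage sketch (assuming a suite whose base classes mix in
Logging, as these do):

    import org.apache.spark.sql.catalyst.rules.RuleExecutor

    // Dump cumulative per-rule timing collected across everything the
    // suite ran; typically called once from afterAll().
    logInfo(RuleExecutor.dumpTimeSpent())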

