This is an automated email from the ASF dual-hosted git repository.

lixiao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1a68fc3  [SPARK-27816][SQL] make TreeNode tag type safe
1a68fc3 is described below

commit 1a68fc38f0aafb9015c499b3f9f7fbe63739e909
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Thu May 23 11:53:21 2019 -0700

    [SPARK-27816][SQL] make TreeNode tag type safe
    
    ## What changes were proposed in this pull request?
    
    Add type parameter to `TreeNodeTag`.
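
    As a rough sketch (not from the patch itself), the intended use of the typed API added here, where `node` is any `TreeNode` and the tag name is made up:

    ```scala
    val tag = TreeNodeTag[String]("example")
    node.setTagValue(tag, "some value")   // OK: the value type matches the tag
    // node.setTagValue(tag, 42)          // no longer compiles: Int is not String
    val value: Option[String] = node.getTagValue(tag)
    ```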
    
    ## How was this patch tested?
    
    Existing tests.
    
    Closes #24687 from cloud-fan/tag.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: gatorsmile <gatorsm...@gmail.com>
---
 .../plans/logical/basicLogicalOperators.scala       |  2 +-
 .../apache/spark/sql/catalyst/trees/TreeNode.scala  | 21 ++++++++++++++++-----
 .../spark/sql/catalyst/trees/TreeNodeSuite.scala    | 18 ++++++++++--------
 .../org/apache/spark/sql/execution/SparkPlan.scala  |  5 +++--
 .../spark/sql/execution/SparkStrategies.scala       |  2 +-
 .../execution/LogicalPlanTagInSparkPlanSuite.scala  | 11 +++++------
 6 files changed, 36 insertions(+), 23 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index a2a7eb1..4350f91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1083,7 +1083,7 @@ case class OneRowRelation() extends LeafNode {
   /** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */
   override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = {
     val newCopy = OneRowRelation()
-    newCopy.tags ++= this.tags
+    newCopy.copyTagsFrom(this)
     newCopy
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index a5705d0..cd5dfb7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -74,9 +74,8 @@ object CurrentOrigin {
   }
 }
 
-// The name of the tree node tag. This is preferred over using string directly, as we can easily
-// find all the defined tags.
-case class TreeNodeTagName(name: String)
+// A tag of a `TreeNode`, which defines name and type
+case class TreeNodeTag[T](name: String)
 
 // scalastyle:off
 abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
@@ -89,7 +88,19 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
   * A mutable map for holding auxiliary information of this tree node. It will be carried over
   * when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`.
    */
-  val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty
+  private val tags: mutable.Map[TreeNodeTag[_], Any] = mutable.Map.empty
+
+  protected def copyTagsFrom(other: BaseType): Unit = {
+    tags ++= other.tags
+  }
+
+  def setTagValue[T](tag: TreeNodeTag[T], value: T): Unit = {
+    tags(tag) = value
+  }
+
+  def getTagValue[T](tag: TreeNodeTag[T]): Option[T] = {
+    tags.get(tag).map(_.asInstanceOf[T])
+  }
 
   /**
    * Returns a Seq of the children of this node.
@@ -418,7 +429,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     try {
       CurrentOrigin.withOrigin(origin) {
         val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
-        res.tags ++= this.tags
+        res.copyTagsFrom(this)
         res
       }
     } catch {
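
(Aside, not part of the commit: a self-contained sketch of the pattern introduced above. The backing map still holds `Any`, but `setTagValue` only accepts a value of the tag's type parameter, so the `asInstanceOf` inside `getTagValue` can only fail if two tags share a name but not a type.)

```scala
import scala.collection.mutable

// Standalone model of the typed-tag mechanism, for illustration only.
case class Tag[T](name: String)

class Node {
  // The map erases value types; safety is enforced at the API boundary.
  private val tags: mutable.Map[Tag[_], Any] = mutable.Map.empty

  def setTagValue[T](tag: Tag[T], value: T): Unit = tags(tag) = value

  def getTagValue[T](tag: Tag[T]): Option[T] =
    tags.get(tag).map(_.asInstanceOf[T])
}

object Demo extends App {
  val n = new Node
  val count = Tag[Int]("count")
  n.setTagValue(count, 3)
  assert(n.getTagValue(count) == Some(3))
}
```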
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 5cfa84d..744d522 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -622,31 +622,33 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
   }
 
   test("tags will be carried over after copy & transform") {
+    val tag = TreeNodeTag[String]("test")
+
     withClue("makeCopy") {
       val node = Dummy(None)
-      node.tags += TreeNodeTagName("test") -> "a"
+      node.setTagValue(tag, "a")
       val copied = node.makeCopy(Array(Some(Literal(1))))
-      assert(copied.tags(TreeNodeTagName("test")) == "a")
+      assert(copied.getTagValue(tag) == Some("a"))
     }
 
     def checkTransform(
         sameTypeTransform: Expression => Expression,
         differentTypeTransform: Expression => Expression): Unit = {
       val child = Dummy(None)
-      child.tags += TreeNodeTagName("test") -> "child"
+      child.setTagValue(tag, "child")
       val node = Dummy(Some(child))
-      node.tags += TreeNodeTagName("test") -> "parent"
+      node.setTagValue(tag, "parent")
 
       val transformed = sameTypeTransform(node)
       // Both the child and parent keep the tags
-      assert(transformed.tags(TreeNodeTagName("test")) == "parent")
-      assert(transformed.children.head.tags(TreeNodeTagName("test")) == "child")
+      assert(transformed.getTagValue(tag) == Some("parent"))
+      assert(transformed.children.head.getTagValue(tag) == Some("child"))
 
       val transformed2 = differentTypeTransform(node)
       // Both the child and parent keep the tags, even if we transform the node to a new one of
       // different type.
-      assert(transformed2.tags(TreeNodeTagName("test")) == "parent")
-      assert(transformed2.children.head.tags.contains(TreeNodeTagName("test")))
+      assert(transformed2.getTagValue(tag) == Some("parent"))
+      assert(transformed2.children.head.getTagValue(tag) == Some("child"))
     }
 
     withClue("transformDown") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 307a01a5..ddcf61b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -33,15 +33,16 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.trees.TreeNodeTagName
+import org.apache.spark.sql.catalyst.trees.TreeNodeTag
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.DataType
 
 object SparkPlan {
  // a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag
  // when converting a logical plan to a physical plan.
-  val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan")
+  val LOGICAL_PLAN_TAG = TreeNodeTag[LogicalPlan]("logical_plan")
 }
 
 /**
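
(Aside, not part of the commit: with the tag now typed as `TreeNodeTag[LogicalPlan]`, reading the link yields an `Option[LogicalPlan]` with no cast at the call site. A hypothetical read, for any `plan: SparkPlan`:)

```scala
val linked: Option[LogicalPlan] = plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG)
linked.foreach(lp => println(s"planned from logical node: ${lp.nodeName}"))
```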
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c9db78b..c403149 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -69,7 +69,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         case ReturnAnswer(rootPlan) => rootPlan
         case _ => plan
       }
-      p.tags += SparkPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan
+      p.setTagValue(SparkPlan.LOGICAL_PLAN_TAG, logicalPlan)
       p
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
index ca7ced5..b35348b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
@@ -21,7 +21,6 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.sql.TPCDSQuerySuite
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final}
-import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window}
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
 import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
@@ -81,12 +80,12 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
       // The exchange related nodes are created after the planning, they don't have corresponding
       // logical plan.
       case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec =>
-        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
+        assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty)
 
       // The subquery exec nodes are just wrappers of the actual nodes, they don't have
       // corresponding logical plan.
       case _: SubqueryExec | _: ReusedSubqueryExec =>
-        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
+        assert(plan.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isEmpty)
 
       case _ if isScanPlanTree(plan) =>
         // The strategies for planning scan can remove or add FilterExec/ProjectExec nodes,
@@ -120,9 +119,9 @@ class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
   }
 
   private def getLogicalPlan(node: SparkPlan): LogicalPlan = {
-    assert(node.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME),
-      node.getClass.getSimpleName + " does not have a logical plan link")
-    node.tags(SparkPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
+    node.getTagValue(SparkPlan.LOGICAL_PLAN_TAG).getOrElse {
+      fail(node.getClass.getSimpleName + " does not have a logical plan link")
+    }
   }
 
   private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = {

