This is an automated email from the ASF dual-hosted git repository.

lixiao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0e6601a  [SPARK-27747][SQL] add a logical plan link in the physical plan
0e6601a is described below

commit 0e6601acdf17c770f880fbc263747779739f4c92
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Mon May 20 13:42:25 2019 -0700

    [SPARK-27747][SQL] add a logical plan link in the physical plan
    
    ## What changes were proposed in this pull request?
    
    It's pretty useful to be able to convert a physical plan back to a logical plan, e.g., in https://github.com/apache/spark/pull/24389
    
    This PR introduces a new feature to `TreeNode`, which allows a `TreeNode` to carry some extra information via a mutable map and to keep that information when it's copied.
    
    The planner leverages this feature to put the logical plan into the physical plan.
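    
    A rough sketch of the mechanism (the `TreeNodeTagName` case class and the `tags` map are what this patch adds; `Dummy` is the test-only expression from `TreeNodeSuite`, used here purely for illustration):
    
    ```scala
    val node = Dummy(None)
    node.tags += TreeNodeTagName("test") -> "a"
    
    // The tag survives makeCopy as well as transformDown/transformUp.
    val copied = node.makeCopy(Array(Some(Literal(1))))
    assert(copied.tags(TreeNodeTagName("test")) == "a")
    ```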
    
    ## How was this patch tested?
    
    A test suite that runs all TPCDS queries and checks that some common physical plans contain the corresponding logical plans.
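    
    For example, a caller can read the logical plan back from a planned `SparkPlan` through the new tag (a minimal sketch; `df` is a hypothetical DataFrame):
    
    ```scala
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    
    val physical = df.queryExecution.sparkPlan
    // The planner attaches the originating logical plan to every physical plan it returns.
    val logical = physical.tags(SparkPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
    ```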
    
    Closes #24626 from cloud-fan/link.
    
    Lead-authored-by: Wenchen Fan <wenc...@databricks.com>
    Co-authored-by: Peng Bo <bo.peng1...@gmail.com>
    Signed-off-by: gatorsmile <gatorsm...@gmail.com>
---
 .../plans/logical/basicLogicalOperators.scala      |   6 +-
 .../apache/spark/sql/catalyst/trees/TreeNode.scala |  25 +++-
 .../spark/sql/catalyst/trees/TreeNodeSuite.scala   |  51 ++++++++
 .../org/apache/spark/sql/execution/SparkPlan.scala |   9 +-
 .../spark/sql/execution/SparkStrategies.scala      |  11 ++
 .../execution/LogicalPlanTagInSparkPlanSuite.scala | 133 +++++++++++++++++++++
 6 files changed, 228 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 1925d45..a2a7eb1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1081,7 +1081,11 @@ case class OneRowRelation() extends LeafNode {
   override def computeStats(): Statistics = Statistics(sizeInBytes = 1)
 
   /** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */
-  override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = OneRowRelation()
+  override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = {
+    val newCopy = OneRowRelation()
+    newCopy.tags ++= this.tags
+    newCopy
+  }
 }
 
 /** A logical plan for `dropDuplicates`. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 84ca066..a5705d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.trees
 
 import java.util.UUID
 
-import scala.collection.Map
+import scala.collection.{mutable, Map}
 import scala.reflect.ClassTag
 
 import org.apache.commons.lang3.ClassUtils
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.JoinType
 import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning}
-import org.apache.spark.sql.catalyst.util.StringUtils.{PlanStringConcat, StringConcat}
+import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -74,6 +74,10 @@ object CurrentOrigin {
   }
 }
 
+// The name of the tree node tag. This is preferred over using string directly, as we can easily
+// find all the defined tags.
+case class TreeNodeTagName(name: String)
+
 // scalastyle:off
 abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 // scalastyle:on
@@ -82,6 +86,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
   val origin: Origin = CurrentOrigin.get
 
   /**
+   * A mutable map for holding auxiliary information of this tree node. It will be carried over
+   * when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`.
+   */
+  val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty
+
+  /**
    * Returns a Seq of the children of this node.
   * Children should not change. Immutability required for containsChild optimization
    */
@@ -262,6 +272,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     if (this fastEquals afterRule) {
       mapChildren(_.transformDown(rule))
     } else {
+      // If the transform function replaces this node with a new one, carry over the tags.
+      afterRule.tags ++= this.tags
       afterRule.mapChildren(_.transformDown(rule))
     }
   }
@@ -275,7 +287,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
    */
   def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
     val afterRuleOnChildren = mapChildren(_.transformUp(rule))
-    if (this fastEquals afterRuleOnChildren) {
+    val newNode = if (this fastEquals afterRuleOnChildren) {
       CurrentOrigin.withOrigin(origin) {
         rule.applyOrElse(this, identity[BaseType])
       }
@@ -284,6 +296,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
         rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
       }
     }
+    // If the transform function replaces this node with a new one, carry over the tags.
+    newNode.tags ++= this.tags
+    newNode
   }
 
   /**
@@ -402,7 +417,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 
     try {
       CurrentOrigin.withOrigin(origin) {
-        defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
+        val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType]
+        res.tags ++= this.tags
+        res
       }
     } catch {
       case e: java.lang.IllegalArgumentException =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index e7ad04f..5cfa84d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -620,4 +620,55 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper {
       assert(planString.startsWith("Truncated plan of"))
     }
   }
+
+  test("tags will be carried over after copy & transform") {
+    withClue("makeCopy") {
+      val node = Dummy(None)
+      node.tags += TreeNodeTagName("test") -> "a"
+      val copied = node.makeCopy(Array(Some(Literal(1))))
+      assert(copied.tags(TreeNodeTagName("test")) == "a")
+    }
+
+    def checkTransform(
+        sameTypeTransform: Expression => Expression,
+        differentTypeTransform: Expression => Expression): Unit = {
+      val child = Dummy(None)
+      child.tags += TreeNodeTagName("test") -> "child"
+      val node = Dummy(Some(child))
+      node.tags += TreeNodeTagName("test") -> "parent"
+
+      val transformed = sameTypeTransform(node)
+      // Both the child and parent keep the tags
+      assert(transformed.tags(TreeNodeTagName("test")) == "parent")
+      assert(transformed.children.head.tags(TreeNodeTagName("test")) == "child")
+
+      val transformed2 = differentTypeTransform(node)
+      // Both the child and parent keep the tags, even if we transform the node to a new one of
+      // different type.
+      assert(transformed2.tags(TreeNodeTagName("test")) == "parent")
+      assert(transformed2.children.head.tags.contains(TreeNodeTagName("test")))
+    }
+
+    withClue("transformDown") {
+      checkTransform(
+        sameTypeTransform = _ transformDown {
+          case Dummy(None) => Dummy(Some(Literal(1)))
+        },
+        differentTypeTransform = _ transformDown {
+          case Dummy(None) => Literal(1)
+
+        })
+    }
+
+    withClue("transformUp") {
+      checkTransform(
+        sameTypeTransform = _ transformUp {
+          case Dummy(None) => Dummy(Some(Literal(1)))
+        },
+        differentTypeTransform = _ transformUp {
+          case Dummy(None) => Literal(1)
+
+        })
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index a89ccca..307a01a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.ExecutionContext
 
 import org.codehaus.commons.compiler.CompileException
 import org.codehaus.janino.InternalCompilerException
@@ -35,9 +34,15 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.trees.TreeNodeTagName
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.ThreadUtils
+
+object SparkPlan {
+  // a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag
+  // when converting a logical plan to a physical plan.
+  val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan")
+}
 
 /**
  * The base class for physical operators.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 831fc73..c9db78b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -63,6 +63,17 @@ case class PlanLater(plan: LogicalPlan) extends LeafExecNode {
 abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SparkPlanner =>
 
+  override def plan(plan: LogicalPlan): Iterator[SparkPlan] = {
+    super.plan(plan).map { p =>
+      val logicalPlan = plan match {
+        case ReturnAnswer(rootPlan) => rootPlan
+        case _ => plan
+      }
+      p.tags += SparkPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan
+      p
+    }
+  }
+
   /**
    * Plans special cases of limit operators.
    */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
new file mode 100644
index 0000000..ca7ced5
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.sql.TPCDSQuerySuite
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window}
+import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
+import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec}
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
+import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.execution.window.WindowExec
+
+class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite {
+
+  override protected def checkGeneratedCode(plan: SparkPlan): Unit = {
+    super.checkGeneratedCode(plan)
+    checkLogicalPlanTag(plan)
+  }
+
+  private def isFinalAgg(aggExprs: Seq[AggregateExpression]): Boolean = {
+    // TODO: aggregate node without aggregate expressions can also be a final aggregate, but
+    // currently the aggregate node doesn't have a final/partial flag.
+    aggExprs.nonEmpty && aggExprs.forall(ae => ae.mode == Complete || ae.mode == Final)
+  }
+
+  // A scan plan tree is a plan tree that has a leaf node under zero or more Project/Filter nodes.
+  private def isScanPlanTree(plan: SparkPlan): Boolean = plan match {
+    case p: ProjectExec => isScanPlanTree(p.child)
+    case f: FilterExec => isScanPlanTree(f.child)
+    case _: LeafExecNode => true
+    case _ => false
+  }
+
+  private def checkLogicalPlanTag(plan: SparkPlan): Unit = {
+    plan match {
+      case _: HashJoin | _: BroadcastNestedLoopJoinExec | _: CartesianProductExec
+           | _: ShuffledHashJoinExec | _: SortMergeJoinExec =>
+        assertLogicalPlanType[Join](plan)
+
+      // There is no corresponding logical plan for the physical partial aggregate.
+      case agg: HashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
+        assertLogicalPlanType[Aggregate](plan)
+      case agg: ObjectHashAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
+        assertLogicalPlanType[Aggregate](plan)
+      case agg: SortAggregateExec if isFinalAgg(agg.aggregateExpressions) =>
+        assertLogicalPlanType[Aggregate](plan)
+
+      case _: WindowExec =>
+        assertLogicalPlanType[Window](plan)
+
+      case _: UnionExec =>
+        assertLogicalPlanType[Union](plan)
+
+      case _: SampleExec =>
+        assertLogicalPlanType[Sample](plan)
+
+      case _: GenerateExec =>
+        assertLogicalPlanType[Generate](plan)
+
+      // The exchange related nodes are created after the planning, they don't have corresponding
+      // logical plan.
+      case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec =>
+        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
+
+      // The subquery exec nodes are just wrappers of the actual nodes, they don't have
+      // corresponding logical plan.
+      case _: SubqueryExec | _: ReusedSubqueryExec =>
+        assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME))
+
+      case _ if isScanPlanTree(plan) =>
+        // The strategies for planning scan can remove or add FilterExec/ProjectExec nodes,
+        // so it's not simple to check. Instead, we only check that the origin LogicalPlan
+        // contains the corresponding leaf node of the SparkPlan.
+        // a strategy might remove the filter if it's totally pushed down, e.g.:
+        //   logical = Project(Filter(Scan A))
+        //   physical = ProjectExec(ScanExec A)
+        // we only check that leaf nodes match between logical and physical plan.
+        val logicalLeaves = getLogicalPlan(plan).collectLeaves()
+        val physicalLeaves = plan.collectLeaves()
+        assert(logicalLeaves.length == 1)
+        assert(physicalLeaves.length == 1)
+        physicalLeaves.head match {
+          case _: RangeExec => logicalLeaves.head.isInstanceOf[Range]
+          case _: DataSourceScanExec => logicalLeaves.head.isInstanceOf[LogicalRelation]
+          case _: InMemoryTableScanExec => logicalLeaves.head.isInstanceOf[InMemoryRelation]
+          case _: LocalTableScanExec => logicalLeaves.head.isInstanceOf[LocalRelation]
+          case _: ExternalRDDScanExec[_] => logicalLeaves.head.isInstanceOf[ExternalRDD[_]]
+          case _: BatchScanExec => logicalLeaves.head.isInstanceOf[DataSourceV2Relation]
+          case _ =>
+        }
+        // Do not need to check the children recursively.
+        return
+
+      case _ =>
+    }
+
+    plan.children.foreach(checkLogicalPlanTag)
+    plan.subqueries.foreach(checkLogicalPlanTag)
+  }
+
+  private def getLogicalPlan(node: SparkPlan): LogicalPlan = {
+    assert(node.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME),
+      node.getClass.getSimpleName + " does not have a logical plan link")
+    node.tags(SparkPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan]
+  }
+
+  private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = {
+    val logicalPlan = getLogicalPlan(node)
+    val expectedCls = implicitly[ClassTag[T]].runtimeClass
+    assert(expectedCls == logicalPlan.getClass)
+  }
+}

