This is an automated email from the ASF dual-hosted git repository. lixiao pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 0e6601a [SPARK-27747][SQL] add a logical plan link in the physical plan 0e6601a is described below commit 0e6601acdf17c770f880fbc263747779739f4c92 Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Mon May 20 13:42:25 2019 -0700 [SPARK-27747][SQL] add a logical plan link in the physical plan ## What changes were proposed in this pull request? It's pretty useful if we can convert a physical plan back to a logical plan, e.g., in https://github.com/apache/spark/pull/24389 This PR introduces a new feature to `TreeNode`, which allows `TreeNode` to carry some extra information via a mutable map, and keep the information when it's copied. The planner leverages this feature to put the logical plan into the physical plan. ## How was this patch tested? a test suite that runs all TPCDS queries and checks that some common physical plans contain the corresponding logical plans. Closes #24626 from cloud-fan/link. Lead-authored-by: Wenchen Fan <wenc...@databricks.com> Co-authored-by: Peng Bo <bo.peng1...@gmail.com> Signed-off-by: gatorsmile <gatorsm...@gmail.com> --- .../plans/logical/basicLogicalOperators.scala | 6 +- .../apache/spark/sql/catalyst/trees/TreeNode.scala | 25 +++- .../spark/sql/catalyst/trees/TreeNodeSuite.scala | 51 ++++++++ .../org/apache/spark/sql/execution/SparkPlan.scala | 9 +- .../spark/sql/execution/SparkStrategies.scala | 11 ++ .../execution/LogicalPlanTagInSparkPlanSuite.scala | 133 +++++++++++++++++++++ 6 files changed, 228 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 1925d45..a2a7eb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -1081,7 +1081,11 @@ case class OneRowRelation() extends LeafNode { override def computeStats(): Statistics = Statistics(sizeInBytes = 1) /** [[org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy()]] does not support 0-arg ctor. */ - override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = OneRowRelation() + override def makeCopy(newArgs: Array[AnyRef]): OneRowRelation = { + val newCopy = OneRowRelation() + newCopy.tags ++= this.tags + newCopy + } } /** A logical plan for `dropDuplicates`. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 84ca066..a5705d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.trees import java.util.UUID -import scala.collection.Map +import scala.collection.{mutable, Map} import scala.reflect.ClassTag import org.apache.commons.lang3.ClassUtils @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} -import org.apache.spark.sql.catalyst.util.StringUtils.{PlanStringConcat, StringConcat} +import org.apache.spark.sql.catalyst.util.StringUtils.PlanStringConcat import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -74,6 +74,10 @@ object CurrentOrigin { } } +// The name of the tree node tag. This is preferred over using string directly, as we can easily +// find all the defined tags. 
+case class TreeNodeTagName(name: String) + // scalastyle:off abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { // scalastyle:on @@ -82,6 +86,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { val origin: Origin = CurrentOrigin.get /** + * A mutable map for holding auxiliary information of this tree node. It will be carried over + * when this node is copied via `makeCopy`, or transformed via `transformUp`/`transformDown`. + */ + val tags: mutable.Map[TreeNodeTagName, Any] = mutable.Map.empty + + /** * Returns a Seq of the children of this node. * Children should not change. Immutability required for containsChild optimization */ @@ -262,6 +272,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { if (this fastEquals afterRule) { mapChildren(_.transformDown(rule)) } else { + // If the transform function replaces this node with a new one, carry over the tags. + afterRule.tags ++= this.tags afterRule.mapChildren(_.transformDown(rule)) } } @@ -275,7 +287,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { */ def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = { val afterRuleOnChildren = mapChildren(_.transformUp(rule)) - if (this fastEquals afterRuleOnChildren) { + val newNode = if (this fastEquals afterRuleOnChildren) { CurrentOrigin.withOrigin(origin) { rule.applyOrElse(this, identity[BaseType]) } @@ -284,6 +296,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { rule.applyOrElse(afterRuleOnChildren, identity[BaseType]) } } + // If the transform function replaces this node with a new one, carry over the tags. 
+ newNode.tags ++= this.tags + newNode } /** @@ -402,7 +417,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { try { CurrentOrigin.withOrigin(origin) { - defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType] + val res = defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType] + res.tags ++= this.tags + res } } catch { case e: java.lang.IllegalArgumentException => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index e7ad04f..5cfa84d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -620,4 +620,55 @@ class TreeNodeSuite extends SparkFunSuite with SQLHelper { assert(planString.startsWith("Truncated plan of")) } } + + test("tags will be carried over after copy & transform") { + withClue("makeCopy") { + val node = Dummy(None) + node.tags += TreeNodeTagName("test") -> "a" + val copied = node.makeCopy(Array(Some(Literal(1)))) + assert(copied.tags(TreeNodeTagName("test")) == "a") + } + + def checkTransform( + sameTypeTransform: Expression => Expression, + differentTypeTransform: Expression => Expression): Unit = { + val child = Dummy(None) + child.tags += TreeNodeTagName("test") -> "child" + val node = Dummy(Some(child)) + node.tags += TreeNodeTagName("test") -> "parent" + + val transformed = sameTypeTransform(node) + // Both the child and parent keep the tags + assert(transformed.tags(TreeNodeTagName("test")) == "parent") + assert(transformed.children.head.tags(TreeNodeTagName("test")) == "child") + + val transformed2 = differentTypeTransform(node) + // Both the child and parent keep the tags, even if we transform the node to a new one of + // different type. 
+ assert(transformed2.tags(TreeNodeTagName("test")) == "parent") + assert(transformed2.children.head.tags.contains(TreeNodeTagName("test"))) + } + + withClue("transformDown") { + checkTransform( + sameTypeTransform = _ transformDown { + case Dummy(None) => Dummy(Some(Literal(1))) + }, + differentTypeTransform = _ transformDown { + case Dummy(None) => Literal(1) + + }) + } + + withClue("transformUp") { + checkTransform( + sameTypeTransform = _ transformUp { + case Dummy(None) => Dummy(Some(Literal(1))) + }, + differentTypeTransform = _ transformUp { + case Dummy(None) => Literal(1) + + }) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index a89ccca..307a01a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext import org.codehaus.commons.compiler.CompileException import org.codehaus.janino.InternalCompilerException @@ -35,9 +34,15 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => GenPredicate, _} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.catalyst.trees.TreeNodeTagName import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.DataType -import org.apache.spark.util.ThreadUtils + +object SparkPlan { + // a TreeNode tag in SparkPlan, to carry its original logical plan. The planner will add this tag + // when converting a logical plan to a physical plan. 
+ val LOGICAL_PLAN_TAG_NAME = TreeNodeTagName("logical_plan") +} /** * The base class for physical operators. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 831fc73..c9db78b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -63,6 +63,17 @@ case class PlanLater(plan: LogicalPlan) extends LeafExecNode { abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SparkPlanner => + override def plan(plan: LogicalPlan): Iterator[SparkPlan] = { + super.plan(plan).map { p => + val logicalPlan = plan match { + case ReturnAnswer(rootPlan) => rootPlan + case _ => plan + } + p.tags += SparkPlan.LOGICAL_PLAN_TAG_NAME -> logicalPlan + p + } + } + /** * Plans special cases of limit operators. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala new file mode 100644 index 0000000..ca7ced5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/LogicalPlanTagInSparkPlanSuite.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import scala.reflect.ClassTag + +import org.apache.spark.sql.TPCDSQuerySuite +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Final} +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, LocalRelation, LogicalPlan, Range, Sample, Union, Window} +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins._ +import org.apache.spark.sql.execution.window.WindowExec + +class LogicalPlanTagInSparkPlanSuite extends TPCDSQuerySuite { + + override protected def checkGeneratedCode(plan: SparkPlan): Unit = { + super.checkGeneratedCode(plan) + checkLogicalPlanTag(plan) + } + + private def isFinalAgg(aggExprs: Seq[AggregateExpression]): Boolean = { + // TODO: aggregate node without aggregate expressions can also be a final aggregate, but + // currently the aggregate node doesn't have a final/partial flag. 
+ aggExprs.nonEmpty && aggExprs.forall(ae => ae.mode == Complete || ae.mode == Final) + } + + // A scan plan tree is a plan tree that has a leaf node under zero or more Project/Filter nodes. + private def isScanPlanTree(plan: SparkPlan): Boolean = plan match { + case p: ProjectExec => isScanPlanTree(p.child) + case f: FilterExec => isScanPlanTree(f.child) + case _: LeafExecNode => true + case _ => false + } + + private def checkLogicalPlanTag(plan: SparkPlan): Unit = { + plan match { + case _: HashJoin | _: BroadcastNestedLoopJoinExec | _: CartesianProductExec + | _: ShuffledHashJoinExec | _: SortMergeJoinExec => + assertLogicalPlanType[Join](plan) + + // There is no corresponding logical plan for the physical partial aggregate. + case agg: HashAggregateExec if isFinalAgg(agg.aggregateExpressions) => + assertLogicalPlanType[Aggregate](plan) + case agg: ObjectHashAggregateExec if isFinalAgg(agg.aggregateExpressions) => + assertLogicalPlanType[Aggregate](plan) + case agg: SortAggregateExec if isFinalAgg(agg.aggregateExpressions) => + assertLogicalPlanType[Aggregate](plan) + + case _: WindowExec => + assertLogicalPlanType[Window](plan) + + case _: UnionExec => + assertLogicalPlanType[Union](plan) + + case _: SampleExec => + assertLogicalPlanType[Sample](plan) + + case _: GenerateExec => + assertLogicalPlanType[Generate](plan) + + // The exchange related nodes are created after the planning, they don't have corresponding + // logical plan. + case _: ShuffleExchangeExec | _: BroadcastExchangeExec | _: ReusedExchangeExec => + assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME)) + + // The subquery exec nodes are just wrappers of the actual nodes, they don't have + // corresponding logical plan. + case _: SubqueryExec | _: ReusedSubqueryExec => + assert(!plan.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME)) + + case _ if isScanPlanTree(plan) => + // The strategies for planning scan can remove or add FilterExec/ProjectExec nodes, + // so it's not simple to check. 
Instead, we only check that the original LogicalPlan + // contains the corresponding leaf node of the SparkPlan. + // A strategy might remove the filter if it's totally pushed down, e.g.: + // logical = Project(Filter(Scan A)) + // physical = ProjectExec(ScanExec A) + // we only check that leaf nodes match between logical and physical plan (NOTE(review): the isInstanceOf results below are not asserted). + val logicalLeaves = getLogicalPlan(plan).collectLeaves() + val physicalLeaves = plan.collectLeaves() + assert(logicalLeaves.length == 1) + assert(physicalLeaves.length == 1) + physicalLeaves.head match { + case _: RangeExec => logicalLeaves.head.isInstanceOf[Range] + case _: DataSourceScanExec => logicalLeaves.head.isInstanceOf[LogicalRelation] + case _: InMemoryTableScanExec => logicalLeaves.head.isInstanceOf[InMemoryRelation] + case _: LocalTableScanExec => logicalLeaves.head.isInstanceOf[LocalRelation] + case _: ExternalRDDScanExec[_] => logicalLeaves.head.isInstanceOf[ExternalRDD[_]] + case _: BatchScanExec => logicalLeaves.head.isInstanceOf[DataSourceV2Relation] + case _ => + } + // Do not need to check the children recursively. + return + + case _ => + } + + plan.children.foreach(checkLogicalPlanTag) + plan.subqueries.foreach(checkLogicalPlanTag) + } + + private def getLogicalPlan(node: SparkPlan): LogicalPlan = { + assert(node.tags.contains(SparkPlan.LOGICAL_PLAN_TAG_NAME), + node.getClass.getSimpleName + " does not have a logical plan link") + node.tags(SparkPlan.LOGICAL_PLAN_TAG_NAME).asInstanceOf[LogicalPlan] + } + + private def assertLogicalPlanType[T <: LogicalPlan : ClassTag](node: SparkPlan): Unit = { + val logicalPlan = getLogicalPlan(node) + val expectedCls = implicitly[ClassTag[T]].runtimeClass + assert(expectedCls == logicalPlan.getClass) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org