Repository: spark Updated Branches: refs/heads/master 7703b46d2 -> 299d297e2
[SPARK-24500][SQL] Make sure streams are materialized during Tree transforms. ## What changes were proposed in this pull request? If you construct catalyst trees using `scala.collection.immutable.Stream` you can run into situations where valid transformations do not seem to have any effect. There are two causes for this behavior: - `Stream` is evaluated lazily. Note that default implementation will generally only evaluate a function for the first element (this makes testing a bit tricky). - `TreeNode` and `QueryPlan` use side effects to detect if a tree has changed. Mapping over a stream is lazy and does not need to trigger this side effect. If this happens the node will invalidly assume that it did not change and return itself instead if the newly created node (this is for GC reasons). This PR fixes this issue by forcing materialization on streams in `TreeNode` and `QueryPlan`. ## How was this patch tested? Unit tests were added to `TreeNodeSuite` and `LogicalPlanSuite`. An integration test was added to the `PlannerSuite` Author: Herman van Hovell <hvanhov...@databricks.com> Closes #21539 from hvanhovell/SPARK-24500. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/299d297e Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/299d297e Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/299d297e Branch: refs/heads/master Commit: 299d297e250ca3d46616a97e4256aa9ad6a135e5 Parents: 7703b46 Author: Herman van Hovell <hvanhov...@databricks.com> Authored: Wed Jun 13 07:09:48 2018 -0700 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Wed Jun 13 07:09:48 2018 -0700 ---------------------------------------------------------------------- .../spark/sql/catalyst/plans/QueryPlan.scala | 1 + .../spark/sql/catalyst/trees/TreeNode.scala | 122 +++++++++---------- .../sql/catalyst/plans/LogicalPlanSuite.scala | 20 ++- .../sql/catalyst/trees/TreeNodeSuite.scala | 25 +++- .../spark/sql/execution/PlannerSuite.scala | 11 +- 5 files changed, 109 insertions(+), 70 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/299d297e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 64cb8c7..e431c95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -119,6 +119,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT case Some(value) => Some(recursiveTransform(value)) case m: Map[_, _] => m case d: DataType => d // Avoid unpacking Structs + case stream: Stream[_] => stream.map(recursiveTransform).force case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other case null => null http://git-wip-us.apache.org/repos/asf/spark/blob/299d297e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 9c7d47f..becfa8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -199,44 +199,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { var changed = false val remainingNewChildren = newChildren.toBuffer val remainingOldChildren = children.toBuffer + def mapTreeNode(node: TreeNode[_]): TreeNode[_] = { + val newChild = remainingNewChildren.remove(0) + val oldChild = remainingOldChildren.remove(0) + if (newChild fastEquals oldChild) { + oldChild + } else { + changed = true + newChild + } + } + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) + case nonChild: AnyRef => nonChild + case null => null + } val newArgs = mapProductIterator { case s: StructType => s // Don't convert struct types to some other type of Seq[StructField] // Handle Seq[TreeNode] in TreeNode parameters. - case s: Seq[_] => s.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } - case nonChild: AnyRef => nonChild - case null => null - } - case m: Map[_, _] => m.mapValues { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } - case nonChild: AnyRef => nonChild - case null => null - }.view.force // `mapValues` is lazy and we need to force it to materialize - case arg: TreeNode[_] if containsChild(arg) => - val newChild = remainingNewChildren.remove(0) - val oldChild = remainingOldChildren.remove(0) - if (newChild fastEquals oldChild) { - oldChild - } else { - changed = true - newChild - } + case s: Stream[_] => + // Stream is lazy so we need to force materialization + s.map(mapChild).force + case s: Seq[_] => + s.map(mapChild) + case m: Map[_, _] => + // `mapValues` is lazy and we need to force it to materialize + m.mapValues(mapChild).view.force + case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg) case nonChild: AnyRef => nonChild case null => null } @@ -301,6 +290,37 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { def mapChildren(f: BaseType => BaseType): BaseType = { if (children.nonEmpty) { var changed = false + def mapChild(child: Any): Any = child match { + case arg: TreeNode[_] if containsChild(arg) => + val newChild = f(arg.asInstanceOf[BaseType]) + if (!(newChild fastEquals arg)) { + changed = true + newChild + } else { + arg + } + case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => + val newChild1 = if (containsChild(arg1)) { + f(arg1.asInstanceOf[BaseType]) + } else { + arg1.asInstanceOf[BaseType] + } + + val newChild2 = if (containsChild(arg2)) { + f(arg2.asInstanceOf[BaseType]) + } else { + arg2.asInstanceOf[BaseType] + } + + if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { + changed = true + (newChild1, newChild2) + } else { + tuple + } + case other => other + } + val newArgs = mapProductIterator { case arg: TreeNode[_] if containsChild(arg) => val newChild = f(arg.asInstanceOf[BaseType]) @@ -330,36 +350,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case other => other }.view.force // `mapValues` is lazy and we need to force it to materialize case d: DataType => d // Avoid unpacking Structs - case args: Traversable[_] => args.map { - case arg: TreeNode[_] if containsChild(arg) => - val newChild = f(arg.asInstanceOf[BaseType]) - if (!(newChild fastEquals arg)) { - changed = true - newChild - } else { - arg - } - case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) => - val newChild1 = if (containsChild(arg1)) { - f(arg1.asInstanceOf[BaseType]) - } else { - arg1.asInstanceOf[BaseType] - } - - val newChild2 = if (containsChild(arg2)) { - f(arg2.asInstanceOf[BaseType]) - } else { - arg2.asInstanceOf[BaseType] - } - - if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) { - changed = true - (newChild1, newChild2) - } else { - tuple - } - case other => other - } + case args: Stream[_] => args.map(mapChild).force // Force materialization on stream + case args: Traversable[_] => args.map(mapChild) case nonChild: AnyRef => nonChild case null => null } http://git-wip-us.apache.org/repos/asf/spark/blob/299d297e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index 1404174..bf569cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Coalesce, Literal, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType @@ -101,4 +101,22 @@ class LogicalPlanSuite extends SparkFunSuite { assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true) assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming) } + + test("transformExpressions works with a Stream") { + val id1 = NamedExpression.newExprId + val id2 = NamedExpression.newExprId + val plan = Project(Stream( + Alias(Literal(1), "a")(exprId = id1), + Alias(Literal(2), "b")(exprId = id2)), + OneRowRelation()) + val result = plan.transformExpressions { + case Literal(v: Int, IntegerType) if v != 1 => + Literal(v + 1, IntegerType) + } + val expected = Project(Stream( + Alias(Literal(1), "a")(exprId = id1), + Alias(Literal(3), "b")(exprId = id2)), + OneRowRelation()) + assert(result.sameResult(expected)) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/299d297e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 84d0ba7..b7092f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -29,14 +29,14 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.dsl.expressions.DslString import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union} import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition} -import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback { @@ -574,4 +574,25 @@ class TreeNodeSuite extends SparkFunSuite { val right = JsonMethods.parse(rightJson) assert(left == right) } + + test("transform works on stream of children") { + val before = Coalesce(Stream(Literal(1), Literal(2))) + // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the + // situation in which the TreeNode.mapChildren function's change detection is not triggered. A + // stream's first element is typically materialized, so in order to not trip the TreeNode change + // detection logic, we should not change the first element in the sequence. + val result = before.transform { + case Literal(v: Int, IntegerType) if v != 1 => + Literal(v + 1, IntegerType) + } + val expected = Coalesce(Stream(Literal(1), Literal(3))) + assert(result === expected) + } + + test("withNewChildren on stream of children") { + val before = Coalesce(Stream(Literal(1), Literal(2))) + val result = before.withNewChildren(Stream(Literal(1), Literal(3))) + val expected = Coalesce(Stream(Literal(1), Literal(3))) + assert(result === expected) + } } http://git-wip-us.apache.org/repos/asf/spark/blob/299d297e/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 98a50fb..ed0ff1b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -21,8 +21,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} @@ -679,6 +679,13 @@ class PlannerSuite extends SharedSQLContext { } assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0)) } + + test("SPARK-24500: create union with stream of children") { + val df = Union(Stream( + Range(1, 1, 1, 1), + Range(1, 2, 1, 1))) + df.queryExecution.executedPlan.execute() + } } // Used for unit-testing EnsureRequirements --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org