Repository: spark
Updated Branches:
  refs/heads/master 7703b46d2 -> 299d297e2


[SPARK-24500][SQL] Make sure streams are materialized during Tree transforms.

## What changes were proposed in this pull request?
If you construct catalyst trees using `scala.collection.immutable.Stream` you
can run into situations where valid transformations do not seem to have any
effect. There are two causes for this behavior (see the sketch after this
list):
- `Stream` is evaluated lazily. Note that the default implementation will
generally only evaluate the mapped function for the first element (this makes
testing a bit tricky).
- `TreeNode` and `QueryPlan` use side effects to detect whether a tree has
changed. Mapping over a stream is lazy and does not necessarily trigger this
side effect. When that happens, the node incorrectly assumes that it did not
change and returns itself instead of the newly created node (the old node is
reused for GC reasons).
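
A minimal sketch of the failure mode in plain Scala (assuming Scala 2.11/2.12
`Stream` semantics; `StreamLazinessDemo` and the `sideEffects` counter are
illustrative, not Spark code):

```scala
object StreamLazinessDemo extends App {
  var sideEffects = 0

  // Stream.map evaluates the head eagerly but defers the tail, so the
  // mapped function runs once here instead of three times.
  val mapped = Stream(1, 2, 3).map { x => sideEffects += 1; x + 1 }
  println(sideEffects) // 1 -- only the first element was evaluated

  // force materializes the whole stream, evaluating the remaining elements.
  mapped.force
  println(sideEffects) // 3
}
```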

This PR fixes this issue by forcing materialization on streams in `TreeNode` 
and `QueryPlan`.
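
Roughly, the pattern being fixed looks like the following standalone sketch
(`transformOnce` and `mapChild` are hypothetical names; the actual `TreeNode`
and `QueryPlan` changes are in the diff below):

```scala
// Illustrative only: side-effect-based change detection over children.
def transformOnce(children: Seq[Int], f: Int => Int): Seq[Int] = {
  var changed = false
  def mapChild(c: Int): Int = {
    val newChild = f(c)
    if (newChild != c) changed = true // side effect records the change
    newChild
  }
  val newChildren = children match {
    // Without .force the map stays lazy, `changed` can remain false, and
    // the caller would wrongly reuse the old, untransformed children.
    case s: Stream[Int] => s.map(mapChild).force
    case other => other.map(mapChild)
  }
  if (changed) newChildren else children
}
```

Note that a stream's head is typically materialized eagerly, so the bug only
surfaces when the transformation leaves the first child unchanged.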

## How was this patch tested?
Unit tests were added to `TreeNodeSuite` and `LogicalPlanSuite`. An integration
test was added to `PlannerSuite`.

Author: Herman van Hovell <hvanhov...@databricks.com>

Closes #21539 from hvanhovell/SPARK-24500.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/299d297e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/299d297e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/299d297e

Branch: refs/heads/master
Commit: 299d297e250ca3d46616a97e4256aa9ad6a135e5
Parents: 7703b46
Author: Herman van Hovell <hvanhov...@databricks.com>
Authored: Wed Jun 13 07:09:48 2018 -0700
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Wed Jun 13 07:09:48 2018 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/plans/QueryPlan.scala    |   1 +
 .../spark/sql/catalyst/trees/TreeNode.scala     | 122 +++++++++----------
 .../sql/catalyst/plans/LogicalPlanSuite.scala   |  20 ++-
 .../sql/catalyst/trees/TreeNodeSuite.scala      |  25 +++-
 .../spark/sql/execution/PlannerSuite.scala      |  11 +-
 5 files changed, 109 insertions(+), 70 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/299d297e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 64cb8c7..e431c95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -119,6 +119,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
       case Some(value) => Some(recursiveTransform(value))
       case m: Map[_, _] => m
       case d: DataType => d // Avoid unpacking Structs
+      case stream: Stream[_] => stream.map(recursiveTransform).force
       case seq: Traversable[_] => seq.map(recursiveTransform)
       case other: AnyRef => other
       case null => null

http://git-wip-us.apache.org/repos/asf/spark/blob/299d297e/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 9c7d47f..becfa8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -199,44 +199,33 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
     var changed = false
     val remainingNewChildren = newChildren.toBuffer
     val remainingOldChildren = children.toBuffer
+    def mapTreeNode(node: TreeNode[_]): TreeNode[_] = {
+      val newChild = remainingNewChildren.remove(0)
+      val oldChild = remainingOldChildren.remove(0)
+      if (newChild fastEquals oldChild) {
+        oldChild
+      } else {
+        changed = true
+        newChild
+      }
+    }
+    def mapChild(child: Any): Any = child match {
+      case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg)
+      case nonChild: AnyRef => nonChild
+      case null => null
+    }
     val newArgs = mapProductIterator {
      case s: StructType => s // Don't convert struct types to some other type of Seq[StructField]
       // Handle Seq[TreeNode] in TreeNode parameters.
-      case s: Seq[_] => s.map {
-        case arg: TreeNode[_] if containsChild(arg) =>
-          val newChild = remainingNewChildren.remove(0)
-          val oldChild = remainingOldChildren.remove(0)
-          if (newChild fastEquals oldChild) {
-            oldChild
-          } else {
-            changed = true
-            newChild
-          }
-        case nonChild: AnyRef => nonChild
-        case null => null
-      }
-      case m: Map[_, _] => m.mapValues {
-        case arg: TreeNode[_] if containsChild(arg) =>
-          val newChild = remainingNewChildren.remove(0)
-          val oldChild = remainingOldChildren.remove(0)
-          if (newChild fastEquals oldChild) {
-            oldChild
-          } else {
-            changed = true
-            newChild
-          }
-        case nonChild: AnyRef => nonChild
-        case null => null
-      }.view.force // `mapValues` is lazy and we need to force it to materialize
-      case arg: TreeNode[_] if containsChild(arg) =>
-        val newChild = remainingNewChildren.remove(0)
-        val oldChild = remainingOldChildren.remove(0)
-        if (newChild fastEquals oldChild) {
-          oldChild
-        } else {
-          changed = true
-          newChild
-        }
+      case s: Stream[_] =>
+        // Stream is lazy so we need to force materialization
+        s.map(mapChild).force
+      case s: Seq[_] =>
+        s.map(mapChild)
+      case m: Map[_, _] =>
+        // `mapValues` is lazy and we need to force it to materialize
+        m.mapValues(mapChild).view.force
+      case arg: TreeNode[_] if containsChild(arg) => mapTreeNode(arg)
       case nonChild: AnyRef => nonChild
       case null => null
     }
@@ -301,6 +290,37 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
   def mapChildren(f: BaseType => BaseType): BaseType = {
     if (children.nonEmpty) {
       var changed = false
+      def mapChild(child: Any): Any = child match {
+        case arg: TreeNode[_] if containsChild(arg) =>
+          val newChild = f(arg.asInstanceOf[BaseType])
+          if (!(newChild fastEquals arg)) {
+            changed = true
+            newChild
+          } else {
+            arg
+          }
+        case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
+          val newChild1 = if (containsChild(arg1)) {
+            f(arg1.asInstanceOf[BaseType])
+          } else {
+            arg1.asInstanceOf[BaseType]
+          }
+
+          val newChild2 = if (containsChild(arg2)) {
+            f(arg2.asInstanceOf[BaseType])
+          } else {
+            arg2.asInstanceOf[BaseType]
+          }
+
+          if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
+            changed = true
+            (newChild1, newChild2)
+          } else {
+            tuple
+          }
+        case other => other
+      }
+
       val newArgs = mapProductIterator {
         case arg: TreeNode[_] if containsChild(arg) =>
           val newChild = f(arg.asInstanceOf[BaseType])
@@ -330,36 +350,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
           case other => other
        }.view.force // `mapValues` is lazy and we need to force it to materialize
         case d: DataType => d // Avoid unpacking Structs
-        case args: Traversable[_] => args.map {
-          case arg: TreeNode[_] if containsChild(arg) =>
-            val newChild = f(arg.asInstanceOf[BaseType])
-            if (!(newChild fastEquals arg)) {
-              changed = true
-              newChild
-            } else {
-              arg
-            }
-          case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
-            val newChild1 = if (containsChild(arg1)) {
-              f(arg1.asInstanceOf[BaseType])
-            } else {
-              arg1.asInstanceOf[BaseType]
-            }
-
-            val newChild2 = if (containsChild(arg2)) {
-              f(arg2.asInstanceOf[BaseType])
-            } else {
-              arg2.asInstanceOf[BaseType]
-            }
-
-            if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
-              changed = true
-              (newChild1, newChild2)
-            } else {
-              tuple
-            }
-          case other => other
-        }
+        case args: Stream[_] => args.map(mapChild).force // Force materialization on stream
+        case args: Traversable[_] => args.map(mapChild)
         case nonChild: AnyRef => nonChild
         case null => null
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/299d297e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
index 1404174..bf569cb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.plans
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Coalesce, Literal, NamedExpression}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types.IntegerType
 
@@ -101,4 +101,22 @@ class LogicalPlanSuite extends SparkFunSuite {
     assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true)
     assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming)
   }
+
+  test("transformExpressions works with a Stream") {
+    val id1 = NamedExpression.newExprId
+    val id2 = NamedExpression.newExprId
+    val plan = Project(Stream(
+      Alias(Literal(1), "a")(exprId = id1),
+      Alias(Literal(2), "b")(exprId = id2)),
+      OneRowRelation())
+    val result = plan.transformExpressions {
+      case Literal(v: Int, IntegerType) if v != 1 =>
+        Literal(v + 1, IntegerType)
+    }
+    val expected = Project(Stream(
+      Alias(Literal(1), "a")(exprId = id1),
+      Alias(Literal(3), "b")(exprId = id2)),
+      OneRowRelation())
+    assert(result.sameResult(expected))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/299d297e/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 84d0ba7..b7092f4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -29,14 +29,14 @@ import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier}
-import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, FunctionResource, JarResource}
+import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.dsl.expressions.DslString
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.plans.{LeftOuter, NaturalJoin}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Union}
 import org.apache.spark.sql.catalyst.plans.physical.{IdentityBroadcastMode, RoundRobinPartitioning, SinglePartition}
-import org.apache.spark.sql.types.{BooleanType, DoubleType, FloatType, IntegerType, Metadata, NullType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 
 case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback {
@@ -574,4 +574,25 @@ class TreeNodeSuite extends SparkFunSuite {
     val right = JsonMethods.parse(rightJson)
     assert(left == right)
   }
+
+  test("transform works on stream of children") {
+    val before = Coalesce(Stream(Literal(1), Literal(2)))
+    // Note it is a bit tricky to exhibit the broken behavior. Basically we want to create the
+    // situation in which the TreeNode.mapChildren function's change detection is not triggered. A
+    // stream's first element is typically materialized, so in order to not trip the TreeNode change
+    // detection logic, we should not change the first element in the sequence.
+    val result = before.transform {
+      case Literal(v: Int, IntegerType) if v != 1 =>
+        Literal(v + 1, IntegerType)
+    }
+    val expected = Coalesce(Stream(Literal(1), Literal(3)))
+    assert(result === expected)
+  }
+
+  test("withNewChildren on stream of children") {
+    val before = Coalesce(Stream(Literal(1), Literal(2)))
+    val result = before.withNewChildren(Stream(Literal(1), Literal(3)))
+    val expected = Coalesce(Stream(Literal(1), Literal(3)))
+    assert(result === expected)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/299d297e/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 98a50fb..ed0ff1b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -21,8 +21,8 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{execution, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort}
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range, Repartition, Sort, Union}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec}
@@ -679,6 +679,13 @@ class PlannerSuite extends SharedSQLContext {
     }
     assert(rangeExecInZeroPartition.head.outputPartitioning == UnknownPartitioning(0))
   }
+
+  test("SPARK-24500: create union with stream of children") {
+    val df = Union(Stream(
+      Range(1, 1, 1, 1),
+      Range(1, 2, 1, 1)))
+    df.queryExecution.executedPlan.execute()
+  }
 }
 
 // Used for unit-testing EnsureRequirements

