Repository: spark
Updated Branches:
  refs/heads/branch-2.3 3f5955aa0 -> 507cff246


[SPARK-23827][SS] StreamingJoinExec should ensure that input data is partitioned into a specific number of partitions

## What changes were proposed in this pull request?

Currently, the requiredChildDistribution does not specify the number of
partitions. This can cause weird corner cases where the child's distribution is
`SinglePartition`, which satisfies the required distribution of
`ClusteredDistribution(no-num-partition-requirement)`, thus eliminating the
shuffle needed to repartition the input data into the required number of
partitions (i.e. the same as the state stores). That can lead to "file not
found" errors on the state store delta files, as the
micro-batch-with-no-shuffle will not run certain tasks and therefore not
generate the expected state store delta files.

This PR adds the required constraint on the number of partitions.
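
A minimal sketch of the satisfaction check at the heart of the bug (assumes
Spark 2.3's catalyst partitioning API as touched by this diff; the attribute
`key` and the partition count `200` are made up for illustration):

```scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, SinglePartition}
import org.apache.spark.sql.types.StringType

val key = AttributeReference("key", StringType)()  // hypothetical join key

// With no partition-count requirement, a single-partition child already
// satisfies the clustering requirement, so no shuffle gets planned.
SinglePartition.satisfies(ClusteredDistribution(key :: Nil))             // true

// Requiring the state stores' partition count (as this patch does) makes the
// single-partition child unsatisfactory, forcing the expected shuffle.
SinglePartition.satisfies(ClusteredDistribution(key :: Nil, Some(200)))  // false
```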

## How was this patch tested?
Modified the test harness to always check that ANY stateful operator has a
constraint on the number of partitions. As part of that, the existing opt-in
checks on child output partitioning were removed, as they are now redundant; a
condensed sketch of the new harness check follows.
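
The essence of the check (a condensed sketch, not the verbatim code; the full
version in the StreamTest.scala hunk below also validates the operator's
stateInfo and exempts `AllTuples` distributions):

```scala
// `lastExecution` is assumed in scope, as in StreamTest's processing loop.
// Every stateful operator in the executed plan must pin its required child
// distributions to the checkpointed state store partition count.
lastExecution.executedPlan.collect { case s: StatefulOperator => s }.foreach { s =>
  s.requiredChildDistribution.foreach { d =>
    assert(d.requiredNumPartitions.contains(s.stateInfo.get.numPartitions))
  }
}
```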

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #20941 from tdas/SPARK-23827.

(cherry picked from commit 15298b99ac8944e781328423289586176cf824d7)
Signed-off-by: Tathagata Das <tathagata.das1...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/507cff24
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/507cff24
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/507cff24

Branch: refs/heads/branch-2.3
Commit: 507cff246cd9e15a418d67b66bf762be4ae71c67
Parents: 3f5955a
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Fri Mar 30 16:48:26 2018 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Fri Mar 30 16:48:55 2018 -0700

----------------------------------------------------------------------
 .../streaming/IncrementalExecution.scala        |  2 +-
 .../StreamingSymmetricHashJoinExec.scala        |  3 +-
 .../spark/sql/streaming/DeduplicateSuite.scala  |  8 +---
 .../streaming/FlatMapGroupsWithStateSuite.scala |  5 +-
 .../sql/streaming/StatefulOperatorTest.scala    | 49 --------------------
 .../apache/spark/sql/streaming/StreamTest.scala | 19 ++++++++
 .../streaming/StreamingAggregationSuite.scala   |  4 +-
 7 files changed, 25 insertions(+), 65 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/507cff24/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index a10ed5f..1a83c88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -62,7 +62,7 @@ class IncrementalExecution(
       StreamingDeduplicationStrategy :: Nil
   }
 
-  private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
+  private[sql] val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key)
     .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter)
     .getOrElse(sparkSession.sessionState.conf.numShufflePartitions)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/507cff24/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index c351f65..fa7c8ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -167,7 +167,8 @@ case class StreamingSymmetricHashJoinExec(
   val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length)
 
   override def requiredChildDistribution: Seq[Distribution] =
-    ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+    ClusteredDistribution(leftKeys, stateInfo.map(_.numPartitions)) ::
+      ClusteredDistribution(rightKeys, stateInfo.map(_.numPartitions)) :: Nil
 
   override def output: Seq[Attribute] = joinType match {
     case _: InnerLike => left.output ++ right.output

http://git-wip-us.apache.org/repos/asf/spark/blob/507cff24/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
index caf2bab..0088b64 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala
@@ -25,9 +25,7 @@ import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplic
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.functions._
 
-class DeduplicateSuite extends StateStoreMetricsTest
-    with BeforeAndAfterAll
-    with StatefulOperatorTest {
+class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll {
 
   import testImplicits._
 
@@ -44,8 +42,6 @@ class DeduplicateSuite extends StateStoreMetricsTest
       AddData(inputData, "a"),
       CheckLastBatch("a"),
       assertNumStateRows(total = 1, updated = 1),
-      AssertOnQuery(sq =>
        checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))),
       AddData(inputData, "a"),
       CheckLastBatch(),
       assertNumStateRows(total = 1, updated = 0),
@@ -63,8 +59,6 @@ class DeduplicateSuite extends StateStoreMetricsTest
       AddData(inputData, "a" -> 1),
       CheckLastBatch("a" -> 1),
       assertNumStateRows(total = 1, updated = 1),
-      AssertOnQuery(sq =>
        checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))),
       AddData(inputData, "a" -> 2), // Dropped
       CheckLastBatch(),
       assertNumStateRows(total = 1, updated = 0),

http://git-wip-us.apache.org/repos/asf/spark/blob/507cff24/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
index de2b516..b1416bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala
@@ -42,8 +42,7 @@ case class RunningCount(count: Long)
 case class Result(key: Long, count: Int)
 
 class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
-    with BeforeAndAfterAll
-    with StatefulOperatorTest {
+    with BeforeAndAfterAll {
 
   import testImplicits._
   import GroupStateImpl._
@@ -618,8 +617,6 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest
       AddData(inputData, "a"),
       CheckLastBatch(("a", "1")),
       assertNumStateRows(total = 1, updated = 1),
-      AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec](
-        sq, Seq("value"))),
       AddData(inputData, "a", "b"),
       CheckLastBatch(("a", "2"), ("b", "1")),
       assertNumStateRows(total = 2, updated = 2),

http://git-wip-us.apache.org/repos/asf/spark/blob/507cff24/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala
deleted file mode 100644
index 4514227..0000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.streaming
-
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.streaming._
-
-trait StatefulOperatorTest {
-  /**
-   * Check that the output partitioning of a child operator of a Stateful operator satisfies the
-   * distribution that we expect for our Stateful operator.
-   */
-  protected def checkChildOutputHashPartitioning[T <: StatefulOperator](
-      sq: StreamingQuery,
-      colNames: Seq[String]): Boolean = {
-    val attr = sq.asInstanceOf[StreamExecution].lastExecution.analyzed.output
-    val partitions = sq.sparkSession.sessionState.conf.numShufflePartitions
-    val groupingAttr = attr.filter(a => colNames.contains(a.name))
-    checkChildOutputPartitioning(sq, HashPartitioning(groupingAttr, partitions))
-  }
-
-  /**
-   * Check that the output partitioning of a child operator of a Stateful operator satisfies the
-   * distribution that we expect for our Stateful operator.
-   */
-  protected def checkChildOutputPartitioning[T <: StatefulOperator](
-      sq: StreamingQuery,
-      expectedPartitioning: Partitioning): Boolean = {
-    val operator = sq.asInstanceOf[StreamExecution].lastExecution
-      .executedPlan.collect { case p: T => p }
-    operator.head.children.forall(
-      _.outputPartitioning.numPartitions == expectedPartitioning.numPartitions)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/507cff24/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index d643356..c2620d1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -37,6 +37,7 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row}
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.physical.AllTuples
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.execution.streaming._
@@ -415,6 +416,24 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
         }
       }
 
+      val lastExecution = currentStream.lastExecution
+      if (currentStream.isInstanceOf[MicroBatchExecution] && lastExecution != null) {
+        // Verify if stateful operators have correct metadata and distribution
+        // This can often catch hard to debug errors when developing stateful operators
+        lastExecution.executedPlan.collect { case s: StatefulOperator => s }.foreach { s =>
+          assert(s.stateInfo.map(_.numPartitions).contains(lastExecution.numStateStores))
+          s.requiredChildDistribution.foreach { d =>
+            withClue(s"$s specifies incorrect # partitions in requiredChildDistribution $d") {
+              assert(d.requiredNumPartitions.isDefined)
+              assert(d.requiredNumPartitions.get >= 1)
+              if (d != AllTuples) {
+                assert(d.requiredNumPartitions.get == s.stateInfo.get.numPartitions)
+              }
+            }
+          }
+        }
+      }
+
       val (latestBatchData, allData) = sink match {
         case s: MemorySink => (s.latestBatchData, s.allData)
         case s: MemorySinkV2 => (s.latestBatchData, s.allData)

http://git-wip-us.apache.org/repos/asf/spark/blob/507cff24/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
index 97e0651..1cae8cb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala
@@ -44,7 +44,7 @@ object FailureSingleton {
 }
 
 class StreamingAggregationSuite extends StateStoreMetricsTest
-    with BeforeAndAfterAll with Assertions with StatefulOperatorTest {
+    with BeforeAndAfterAll with Assertions {
 
   override def afterAll(): Unit = {
     super.afterAll()
@@ -281,8 +281,6 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
       AddData(inputData, 0L, 5L, 5L, 10L),
       AdvanceManualClock(10 * 1000),
       CheckLastBatch((0L, 1), (5L, 2), (10L, 1)),
-      AssertOnQuery(sq =>
-        checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))),
 
       // advance clock to 20 seconds, should retain keys >= 10
       AddData(inputData, 15L, 15L, 20L),

