This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch branch-2.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-2.3 by this push:
     new abce846  [SPARK-23408][SS][BRANCH-2.3] Synchronize successive AddData 
actions in Streaming*JoinSuite
abce846 is described below

commit abce8460a89ed1557d82100b68724c5ddac9b230
Author: Jungtaek Lim (HeartSaVioR) <kabh...@gmail.com>
AuthorDate: Tue Feb 12 15:17:47 2019 -0600

    [SPARK-23408][SS][BRANCH-2.3] Synchronize successive AddData actions in 
Streaming*JoinSuite
    
    ## What changes were proposed in this pull request?
    
    **The best way to review this PR is to ignore whitespace/indent changes. 
Use this link - https://github.com/apache/spark/pull/20650/files?w=1**
    
    The stream-stream join tests add data to multiple sources and expect it all 
to show up in the next batch. But there's a race condition; the new batch might 
trigger when only one of the AddData actions has been reached.
    
    A prior attempt to solve this issue, by jose-torres in #20646, tried to 
synchronize on all memory sources simultaneously whenever consecutive AddData 
actions were found. However, this carries the risk of deadlock as well as 
unintended modification of stress tests (see the above PR for a detailed 
explanation). Instead, this PR does the following.
    
    - A new action called `StreamProgressLockedActions` that allows multiple 
actions to be executed while the streaming query is blocked from making 
progress. This allows data to be added to multiple sources that are made 
visible simultaneously in the next batch.
    - A helper called `MultiAddData`, which wraps two `AddData` actions in a 
`StreamProgressLockedActions`, is explicitly used in the `Streaming*JoinSuites` 
to add data to two memory sources simultaneously.
    
    This should avoid unintentional modification of the stress tests (or any 
other test for that matter) while making sure that the flaky tests are 
deterministic.
    
    NOTE: This patch is modified slightly from the original PR (#20650) to 
account for a DSv2 incompatibility between Spark 2.3 and 2.4: 
StreamingDataSourceV2Relation is a plain class in 2.3, whereas it is a case class in 2.4.
    
    ## How was this patch tested?
    
    Modified test cases in `Streaming*JoinSuites` where there are consecutive 
`AddData` actions.
    
    Closes #23757 from HeartSaVioR/fix-streaming-join-test-flakiness-branch-2.3.
    
    Lead-authored-by: Jungtaek Lim (HeartSaVioR) <kabh...@gmail.com>
    Co-authored-by: Tathagata Das <tathagata.das1...@gmail.com>
    Signed-off-by: Sean Owen <sean.o...@databricks.com>
---
 .../execution/streaming/MicroBatchExecution.scala  |  10 +
 .../apache/spark/sql/streaming/StreamTest.scala    | 474 +++++++++++----------
 .../spark/sql/streaming/StreamingJoinSuite.scala   |  54 +--
 3 files changed, 285 insertions(+), 253 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index dbbb7e7..8bf1dd3 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -492,6 +492,16 @@ class MicroBatchExecution(
     }
   }
 
+  /** Execute a function while locking the stream from making any progress */
+  private[sql] def withProgressLocked(f: => Unit): Unit = {
+    awaitProgressLock.lock()
+    try {
+      f
+    } finally {
+      awaitProgressLock.unlock()
+    }
+  }
+
   private def toJava(scalaOption: Option[OffsetV2]): Optional[OffsetV2] = {
     Optional.ofNullable(scalaOption.orNull)
   }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index c2620d1..adf145e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, 
ExpressionEncoder, Ro
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.physical.AllTuples
 import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import 
org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
 import org.apache.spark.sql.execution.streaming._
 import 
org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, 
ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch}
 import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
@@ -103,6 +103,19 @@ trait StreamTest extends QueryTest with SharedSQLContext 
with TimeLimits with Be
       AddDataMemory(source, data)
   }
 
+  /**
+   * Adds data to multiple memory streams such that all the data will be made 
visible in the
+   * same batch. This is applicable only to MicroBatchExecution, as this 
coordination cannot be
+   * performed at the driver in ContinuousExecutions.
+   */
+  object MultiAddData {
+    def apply[A]
+      (source1: MemoryStream[A], data1: A*)(source2: MemoryStream[A], data2: 
A*): StreamAction = {
+      val actions = Seq(AddDataMemory(source1, data1), AddDataMemory(source2, 
data2))
+      StreamProgressLockedActions(actions, desc = actions.mkString("[ ", " | 
", " ]"))
+    }
+  }
+
   /** A trait that can be extended when testing a source. */
   trait AddData extends StreamAction {
     /**
@@ -218,6 +231,19 @@ trait StreamTest extends QueryTest with SharedSQLContext 
with TimeLimits with Be
       s"ExpectFailure[${causeClass.getName}, isFatalError: $isFatalError]"
   }
 
+  /**
+   * Performs multiple actions while locking the stream from progressing.
+   * This is applicable only to MicroBatchExecution, as progress of 
ContinuousExecution
+   * cannot be controlled from the driver.
+   */
+  case class StreamProgressLockedActions(actions: Seq[StreamAction], desc: 
String = null)
+    extends StreamAction {
+
+    override def toString(): String = {
+      if (desc != null) desc else super.toString
+    }
+  }
+
   /** Assert that a body is true */
   class Assert(condition: => Boolean, val message: String = "") extends 
StreamAction {
     def run(): Unit = { Assertions.assert(condition) }
@@ -296,6 +322,9 @@ trait StreamTest extends QueryTest with SharedSQLContext 
with TimeLimits with Be
     val awaiting = new mutable.HashMap[Int, Offset]() // source index -> 
offset to wait for
     val sink = if (useV2Sink) new MemorySinkV2 else new 
MemorySink(stream.schema, outputMode)
     val resetConfValues = mutable.Map[String, Option[String]]()
+    val defaultCheckpointLocation =
+      Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
+    var manualClockExpectedTime = -1L
 
     @volatile
     var streamThreadDeathCause: Throwable = null
@@ -444,243 +473,254 @@ trait StreamTest extends QueryTest with 
SharedSQLContext with TimeLimits with Be
       }
     }
 
-    var manualClockExpectedTime = -1L
-    val defaultCheckpointLocation =
-      Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
-    try {
-      startedTest.foreach { action =>
-        logInfo(s"Processing test stream action: $action")
-        action match {
-          case StartStream(trigger, triggerClock, additionalConfs, 
checkpointLocation) =>
-            verify(currentStream == null, "stream already running")
-            verify(triggerClock.isInstanceOf[SystemClock]
-              || triggerClock.isInstanceOf[StreamManualClock],
-              "Use either SystemClock or StreamManualClock to start the 
stream")
-            if (triggerClock.isInstanceOf[StreamManualClock]) {
-              manualClockExpectedTime = 
triggerClock.asInstanceOf[StreamManualClock].getTimeMillis()
+    def executeAction(action: StreamAction): Unit = {
+      logInfo(s"Processing test stream action: $action")
+      action match {
+        case StartStream(trigger, triggerClock, additionalConfs, 
checkpointLocation) =>
+          verify(currentStream == null, "stream already running")
+          verify(triggerClock.isInstanceOf[SystemClock]
+            || triggerClock.isInstanceOf[StreamManualClock],
+            "Use either SystemClock or StreamManualClock to start the stream")
+          if (triggerClock.isInstanceOf[StreamManualClock]) {
+            manualClockExpectedTime = 
triggerClock.asInstanceOf[StreamManualClock].getTimeMillis()
+          }
+          val metadataRoot = 
Option(checkpointLocation).getOrElse(defaultCheckpointLocation)
+
+          additionalConfs.foreach(pair => {
+            val value =
+              if (sparkSession.conf.contains(pair._1)) {
+                Some(sparkSession.conf.get(pair._1))
+              } else None
+            resetConfValues(pair._1) = value
+            sparkSession.conf.set(pair._1, pair._2)
+          })
+
+          lastStream = currentStream
+          currentStream =
+            sparkSession
+              .streams
+              .startQuery(
+                None,
+                Some(metadataRoot),
+                stream,
+                Map(),
+                sink,
+                outputMode,
+                trigger = trigger,
+                triggerClock = triggerClock)
+              .asInstanceOf[StreamingQueryWrapper]
+              .streamingQuery
+          // Wait until the initialization finishes, because some tests need 
to use `logicalPlan`
+          // after starting the query.
+          try {
+            currentStream.awaitInitialization(streamingTimeout.toMillis)
+            currentStream match {
+              case s: ContinuousExecution => eventually("IncrementalExecution 
was not created") {
+                assert(s.lastExecution != null)
+              }
+              case _ =>
             }
-            val metadataRoot = 
Option(checkpointLocation).getOrElse(defaultCheckpointLocation)
+          } catch {
+            case _: StreamingQueryException =>
+              // Ignore the exception. `StopStream` or `ExpectFailure` will 
catch it as well.
+          }
 
-            additionalConfs.foreach(pair => {
-              val value =
-                if (sparkSession.conf.contains(pair._1)) {
-                  Some(sparkSession.conf.get(pair._1))
-                } else None
-              resetConfValues(pair._1) = value
-              sparkSession.conf.set(pair._1, pair._2)
-            })
+        case AdvanceManualClock(timeToAdd) =>
+          verify(currentStream != null,
+                 "can not advance manual clock when a stream is not running")
+          verify(currentStream.triggerClock.isInstanceOf[StreamManualClock],
+                 s"can not advance clock of type 
${currentStream.triggerClock.getClass}")
+          val clock = 
currentStream.triggerClock.asInstanceOf[StreamManualClock]
+          assert(manualClockExpectedTime >= 0)
+
+          // Make sure we don't advance ManualClock too early. See SPARK-16002.
+          eventually("StreamManualClock has not yet entered the waiting 
state") {
+            assert(clock.isStreamWaitingAt(manualClockExpectedTime))
+          }
 
+          clock.advance(timeToAdd)
+          manualClockExpectedTime += timeToAdd
+          verify(clock.getTimeMillis() === manualClockExpectedTime,
+            s"Unexpected clock time after updating: " +
+              s"expecting $manualClockExpectedTime, current 
${clock.getTimeMillis()}")
+
+        case StopStream =>
+          verify(currentStream != null, "can not stop a stream that is not 
running")
+          try failAfter(streamingTimeout) {
+            currentStream.stop()
+            verify(!currentStream.queryExecutionThread.isAlive,
+              s"microbatch thread not stopped")
+            verify(!currentStream.isActive,
+              "query.isActive() is false even after stopping")
+            verify(currentStream.exception.isEmpty,
+              s"query.exception() is not empty after clean stop: " +
+                currentStream.exception.map(_.toString()).getOrElse(""))
+          } catch {
+            case _: InterruptedException =>
+            case e: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
+              failTest(
+                "Timed out while stopping and waiting for microbatchthread to 
terminate.", e)
+            case t: Throwable =>
+              failTest("Error while stopping stream", t)
+          } finally {
             lastStream = currentStream
-            currentStream =
-              sparkSession
-                .streams
-                .startQuery(
-                  None,
-                  Some(metadataRoot),
-                  stream,
-                  Map(),
-                  sink,
-                  outputMode,
-                  trigger = trigger,
-                  triggerClock = triggerClock)
-                .asInstanceOf[StreamingQueryWrapper]
-                .streamingQuery
-            // Wait until the initialization finishes, because some tests need 
to use `logicalPlan`
-            // after starting the query.
-            try {
-              currentStream.awaitInitialization(streamingTimeout.toMillis)
-              currentStream match {
-                case s: ContinuousExecution => 
eventually("IncrementalExecution was not created") {
-                  assert(s.lastExecution != null)
-                }
-                case _ =>
-              }
-            } catch {
-              case _: StreamingQueryException =>
-                // Ignore the exception. `StopStream` or `ExpectFailure` will 
catch it as well.
-            }
+            currentStream = null
+          }
 
-          case AdvanceManualClock(timeToAdd) =>
-            verify(currentStream != null,
-                   "can not advance manual clock when a stream is not running")
-            verify(currentStream.triggerClock.isInstanceOf[StreamManualClock],
-                   s"can not advance clock of type 
${currentStream.triggerClock.getClass}")
-            val clock = 
currentStream.triggerClock.asInstanceOf[StreamManualClock]
-            assert(manualClockExpectedTime >= 0)
-
-            // Make sure we don't advance ManualClock too early. See 
SPARK-16002.
-            eventually("StreamManualClock has not yet entered the waiting 
state") {
-              assert(clock.isStreamWaitingAt(manualClockExpectedTime))
+        case ef: ExpectFailure[_] =>
+          verify(currentStream != null, "can not expect failure when stream is 
not running")
+          try failAfter(streamingTimeout) {
+            val thrownException = intercept[StreamingQueryException] {
+              currentStream.awaitTermination()
             }
-
-            clock.advance(timeToAdd)
-            manualClockExpectedTime += timeToAdd
-            verify(clock.getTimeMillis() === manualClockExpectedTime,
-              s"Unexpected clock time after updating: " +
-                s"expecting $manualClockExpectedTime, current 
${clock.getTimeMillis()}")
-
-          case StopStream =>
-            verify(currentStream != null, "can not stop a stream that is not 
running")
-            try failAfter(streamingTimeout) {
-              currentStream.stop()
-              verify(!currentStream.queryExecutionThread.isAlive,
-                s"microbatch thread not stopped")
-              verify(!currentStream.isActive,
-                "query.isActive() is false even after stopping")
-              verify(currentStream.exception.isEmpty,
-                s"query.exception() is not empty after clean stop: " +
-                  currentStream.exception.map(_.toString()).getOrElse(""))
-            } catch {
-              case _: InterruptedException =>
-              case e: org.scalatest.exceptions.TestFailedDueToTimeoutException 
=>
-                failTest(
-                  "Timed out while stopping and waiting for microbatchthread 
to terminate.", e)
-              case t: Throwable =>
-                failTest("Error while stopping stream", t)
-            } finally {
-              lastStream = currentStream
-              currentStream = null
+            eventually("microbatch thread not stopped after termination with 
failure") {
+              assert(!currentStream.queryExecutionThread.isAlive)
             }
+            verify(currentStream.exception === Some(thrownException),
+              s"incorrect exception returned by query.exception()")
+
+            val exception = currentStream.exception.get
+            verify(exception.cause.getClass === ef.causeClass,
+              "incorrect cause in exception returned by query.exception()\n" +
+                s"\tExpected: ${ef.causeClass}\n\tReturned: 
${exception.cause.getClass}")
+            if (ef.isFatalError) {
+              // This is a fatal error, `streamThreadDeathCause` should be set 
to this error in
+              // UncaughtExceptionHandler.
+              verify(streamThreadDeathCause != null &&
+                streamThreadDeathCause.getClass === ef.causeClass,
+                "UncaughtExceptionHandler didn't receive the correct error\n" +
+                  s"\tExpected: ${ef.causeClass}\n\tReturned: 
$streamThreadDeathCause")
+              streamThreadDeathCause = null
+            }
+            ef.assertFailure(exception.getCause)
+          } catch {
+            case _: InterruptedException =>
+            case e: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
+              failTest("Timed out while waiting for failure", e)
+            case t: Throwable =>
+              failTest("Error while checking stream failure", t)
+          } finally {
+            lastStream = currentStream
+            currentStream = null
+          }
 
-          case ef: ExpectFailure[_] =>
-            verify(currentStream != null, "can not expect failure when stream 
is not running")
-            try failAfter(streamingTimeout) {
-              val thrownException = intercept[StreamingQueryException] {
-                currentStream.awaitTermination()
-              }
-              eventually("microbatch thread not stopped after termination with 
failure") {
-                assert(!currentStream.queryExecutionThread.isAlive)
+        case a: AssertOnQuery =>
+          verify(currentStream != null || lastStream != null,
+            "cannot assert when no stream has been started")
+          val streamToAssert = Option(currentStream).getOrElse(lastStream)
+          try {
+            verify(a.condition(streamToAssert), s"Assert on query failed: 
${a.message}")
+          } catch {
+            case NonFatal(e) =>
+              failTest(s"Assert on query failed: ${a.message}", e)
+          }
+
+        case a: Assert =>
+          val streamToAssert = Option(currentStream).getOrElse(lastStream)
+          verify({ a.run(); true }, s"Assert failed: ${a.message}")
+
+        case a: AddData =>
+          try {
+
+            // If the query is running with manual clock, then wait for the 
stream execution
+            // thread to start waiting for the clock to increment. This is 
needed so that we
+            // are adding data when there is no trigger that is active. This 
would ensure that
+            // the data gets deterministically added to the next batch 
triggered after the manual
+            // clock is incremented in following AdvanceManualClock. This 
avoid race conditions
+            // between the test thread and the stream execution thread in 
tests using manual
+            // clock.
+            if (currentStream != null &&
+                currentStream.triggerClock.isInstanceOf[StreamManualClock]) {
+              val clock = 
currentStream.triggerClock.asInstanceOf[StreamManualClock]
+              eventually("Error while synchronizing with manual clock before 
adding data") {
+                if (currentStream.isActive) {
+                  assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
+                }
               }
-              verify(currentStream.exception === Some(thrownException),
-                s"incorrect exception returned by query.exception()")
-
-              val exception = currentStream.exception.get
-              verify(exception.cause.getClass === ef.causeClass,
-                "incorrect cause in exception returned by query.exception()\n" 
+
-                  s"\tExpected: ${ef.causeClass}\n\tReturned: 
${exception.cause.getClass}")
-              if (ef.isFatalError) {
-                // This is a fatal error, `streamThreadDeathCause` should be 
set to this error in
-                // UncaughtExceptionHandler.
-                verify(streamThreadDeathCause != null &&
-                  streamThreadDeathCause.getClass === ef.causeClass,
-                  "UncaughtExceptionHandler didn't receive the correct 
error\n" +
-                    s"\tExpected: ${ef.causeClass}\n\tReturned: 
$streamThreadDeathCause")
-                streamThreadDeathCause = null
+              if (!currentStream.isActive) {
+                failTest("Query terminated while synchronizing with manual 
clock")
               }
-              ef.assertFailure(exception.getCause)
-            } catch {
-              case _: InterruptedException =>
-              case e: org.scalatest.exceptions.TestFailedDueToTimeoutException 
=>
-                failTest("Timed out while waiting for failure", e)
-              case t: Throwable =>
-                failTest("Error while checking stream failure", t)
-            } finally {
-              lastStream = currentStream
-              currentStream = null
             }
-
-          case a: AssertOnQuery =>
-            verify(currentStream != null || lastStream != null,
-              "cannot assert when no stream has been started")
-            val streamToAssert = Option(currentStream).getOrElse(lastStream)
-            try {
-              verify(a.condition(streamToAssert), s"Assert on query failed: 
${a.message}")
-            } catch {
-              case NonFatal(e) =>
-                failTest(s"Assert on query failed: ${a.message}", e)
+            // Add data
+            val queryToUse = Option(currentStream).orElse(Option(lastStream))
+            val (source, offset) = a.addData(queryToUse)
+
+            def findSourceIndex(plan: LogicalPlan): Option[Int] = {
+              plan
+                .collect {
+                  case StreamingExecutionRelation(s, _) => s
+                  case rel: StreamingDataSourceV2Relation => rel.reader
+                }
+                .zipWithIndex
+                .find(_._1 == source)
+                .map(_._2)
             }
 
-          case a: Assert =>
-            val streamToAssert = Option(currentStream).getOrElse(lastStream)
-            verify({ a.run(); true }, s"Assert failed: ${a.message}")
-
-          case a: AddData =>
-            try {
-
-              // If the query is running with manual clock, then wait for the 
stream execution
-              // thread to start waiting for the clock to increment. This is 
needed so that we
-              // are adding data when there is no trigger that is active. This 
would ensure that
-              // the data gets deterministically added to the next batch 
triggered after the manual
-              // clock is incremented in following AdvanceManualClock. This 
avoid race conditions
-              // between the test thread and the stream execution thread in 
tests using manual
-              // clock.
-              if (currentStream != null &&
-                  currentStream.triggerClock.isInstanceOf[StreamManualClock]) {
-                val clock = 
currentStream.triggerClock.asInstanceOf[StreamManualClock]
-                eventually("Error while synchronizing with manual clock before 
adding data") {
-                  if (currentStream.isActive) {
-                    assert(clock.isStreamWaitingAt(clock.getTimeMillis()))
-                  }
+            // Try to find the index of the source to which data was added. 
Either get the index
+            // from the current active query or the original input logical 
plan.
+            val sourceIndex =
+              queryToUse.flatMap { query =>
+                findSourceIndex(query.logicalPlan)
+              }.orElse {
+                findSourceIndex(stream.logicalPlan)
+              }.orElse {
+                queryToUse.flatMap { q =>
+                  findSourceIndex(q.lastExecution.logical)
                 }
-                if (!currentStream.isActive) {
-                  failTest("Query terminated while synchronizing with manual 
clock")
-                }
-              }
-              // Add data
-              val queryToUse = Option(currentStream).orElse(Option(lastStream))
-              val (source, offset) = a.addData(queryToUse)
-
-              def findSourceIndex(plan: LogicalPlan): Option[Int] = {
-                plan
-                  .collect {
-                    case StreamingExecutionRelation(s, _) => s
-                    case DataSourceV2Relation(_, r) => r
-                  }
-                  .zipWithIndex
-                  .find(_._1 == source)
-                  .map(_._2)
+              }.getOrElse {
+                throw new IllegalArgumentException(
+                  "Could not find index of the source to which data was added")
               }
 
-              // Try to find the index of the source to which data was added. 
Either get the index
-              // from the current active query or the original input logical 
plan.
-              val sourceIndex =
-                queryToUse.flatMap { query =>
-                  findSourceIndex(query.logicalPlan)
-                }.orElse {
-                  findSourceIndex(stream.logicalPlan)
-                }.orElse {
-                  queryToUse.flatMap { q =>
-                    findSourceIndex(q.lastExecution.logical)
-                  }
-                }.getOrElse {
-                  throw new IllegalArgumentException(
-                    "Could not find index of the source to which data was 
added")
-                }
+            // Store the expected offset of added data to wait for it later
+            awaiting.put(sourceIndex, offset)
+          } catch {
+            case NonFatal(e) =>
+              failTest("Error adding data", e)
+          }
 
-              // Store the expected offset of added data to wait for it later
-              awaiting.put(sourceIndex, offset)
-            } catch {
-              case NonFatal(e) =>
-                failTest("Error adding data", e)
-            }
+        case e: ExternalAction =>
+          e.runAction()
 
-          case e: ExternalAction =>
-            e.runAction()
+        case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) =>
+          val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
+          QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach {
+            error => failTest(error)
+          }
 
-          case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) =>
-            val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
-            QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach {
-              error => failTest(error)
-            }
+        case CheckAnswerRowsContains(expectedAnswer, lastOnly) =>
+          val sparkAnswer = currentStream match {
+            case null => fetchStreamAnswer(lastStream, lastOnly)
+            case s => fetchStreamAnswer(s, lastOnly)
+          }
+          QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach {
+            error => failTest(error)
+          }
 
-          case CheckAnswerRowsContains(expectedAnswer, lastOnly) =>
-            val sparkAnswer = currentStream match {
-              case null => fetchStreamAnswer(lastStream, lastOnly)
-              case s => fetchStreamAnswer(s, lastOnly)
-            }
-            QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach {
-              error => failTest(error)
-            }
+        case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) =>
+          val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
+          try {
+            globalCheckFunction(sparkAnswer)
+          } catch {
+            case e: Throwable => failTest(e.toString)
+          }
+      }
+      pos += 1
+    }
 
-          case CheckAnswerRowsByFunc(globalCheckFunction, lastOnly) =>
-            val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
-            try {
-              globalCheckFunction(sparkAnswer)
-            } catch {
-              case e: Throwable => failTest(e.toString)
-            }
-        }
-        pos += 1
+    try {
+      startedTest.foreach {
+        case StreamProgressLockedActions(actns, _) =>
+          // Perform actions while holding the stream from progressing
+          assert(currentStream != null,
+            s"Cannot perform stream-progress-locked actions $actns when query 
is not active")
+          assert(currentStream.isInstanceOf[MicroBatchExecution],
+            s"Cannot perform stream-progress-locked actions on non-microbatch 
queries")
+          currentStream.asInstanceOf[MicroBatchExecution].withProgressLocked {
+            actns.foreach(executeAction)
+          }
+
+        case action: StreamAction => executeAction(action)
       }
       if (streamThreadDeathCause != null) {
         failTest("Stream Thread Died", streamThreadDeathCause)
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index 657a363..ffcb417 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -476,15 +476,13 @@ class StreamingOuterJoinSuite extends StreamTest with 
StateStoreMetricsTest with
         .select(left("key"), left("window.end").cast("long"), 'leftValue, 
'rightValue)
 
     testStream(joined)(
-      AddData(leftInput, 1, 2, 3),
-      AddData(rightInput, 3, 4, 5),
+      MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5),
       // The left rows with leftValue <= 4 should generate their outer join 
row now and
       // not get added to the state.
       CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, 
null)),
       assertNumStateRows(total = 4, updated = 4),
       // We shouldn't get more outer join rows when the watermark advances.
-      AddData(leftInput, 20),
-      AddData(rightInput, 21),
+      MultiAddData(leftInput, 20)(rightInput, 21),
       CheckLastBatch(),
       AddData(rightInput, 20),
       CheckLastBatch((20, 30, 40, "60"))
@@ -507,15 +505,13 @@ class StreamingOuterJoinSuite extends StreamTest with 
StateStoreMetricsTest with
       .select(left("key"), left("window.end").cast("long"), 'leftValue, 
'rightValue)
 
     testStream(joined)(
-      AddData(leftInput, 3, 4, 5),
-      AddData(rightInput, 1, 2, 3),
+      MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3),
       // The right rows with value <= 7 should never be added to the state.
       CheckLastBatch(Row(3, 10, 6, "9")),
       assertNumStateRows(total = 4, updated = 4),
       // When the watermark advances, we get the outer join rows just as we 
would if they
       // were added but didn't match the full join condition.
-      AddData(leftInput, 20),
-      AddData(rightInput, 21),
+      MultiAddData(leftInput, 20)(rightInput, 21),
       CheckLastBatch(),
       AddData(rightInput, 20),
       CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 
10, null))
@@ -538,15 +534,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
       .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue)
 
     testStream(joined)(
-      AddData(leftInput, 1, 2, 3),
-      AddData(rightInput, 3, 4, 5),
+      MultiAddData(leftInput, 1, 2, 3)(rightInput, 3, 4, 5),
       // The left rows with value <= 4 should never be added to the state.
       CheckLastBatch(Row(3, 10, 6, "9")),
       assertNumStateRows(total = 4, updated = 4),
       // When the watermark advances, we get the outer join rows just as we would if they
       // were added but didn't match the full join condition.
-      AddData(leftInput, 20),
-      AddData(rightInput, 21),
+      MultiAddData(leftInput, 20)(rightInput, 21),
       CheckLastBatch(),
       AddData(rightInput, 20),
       CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15"))
@@ -569,15 +563,13 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
       .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue)
 
     testStream(joined)(
-      AddData(leftInput, 3, 4, 5),
-      AddData(rightInput, 1, 2, 3),
+      MultiAddData(leftInput, 3, 4, 5)(rightInput, 1, 2, 3),
       // The right rows with rightValue <= 7 should generate their outer join row now and
       // not get added to the state.
       CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")),
       assertNumStateRows(total = 4, updated = 4),
       // We shouldn't get more outer join rows when the watermark advances.
-      AddData(leftInput, 20),
-      AddData(rightInput, 21),
+      MultiAddData(leftInput, 20)(rightInput, 21),
       CheckLastBatch(),
       AddData(rightInput, 20),
       CheckLastBatch((20, 30, 40, "60"))
@@ -589,13 +581,11 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
 
     testStream(joined)(
       // Test inner part of the join.
-      AddData(leftInput, 1, 2, 3, 4, 5),
-      AddData(rightInput, 3, 4, 5, 6, 7),
+      MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7),
       CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
       // Old state doesn't get dropped until the batch *after* it gets introduced, so the
       // nulls won't show up until the next batch after the watermark advances.
-      AddData(leftInput, 21),
-      AddData(rightInput, 22),
+      MultiAddData(leftInput, 21)(rightInput, 22),
       CheckLastBatch(),
       assertNumStateRows(total = 12, updated = 2),
       AddData(leftInput, 22),
@@ -609,13 +599,11 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
 
     testStream(joined)(
       // Test inner part of the join.
-      AddData(leftInput, 1, 2, 3, 4, 5),
-      AddData(rightInput, 3, 4, 5, 6, 7),
+      MultiAddData(leftInput, 1, 2, 3, 4, 5)(rightInput, 3, 4, 5, 6, 7),
       CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)),
       // Old state doesn't get dropped until the batch *after* it gets introduced, so the
       // nulls won't show up until the next batch after the watermark advances.
-      AddData(leftInput, 21),
-      AddData(rightInput, 22),
+      MultiAddData(leftInput, 21)(rightInput, 22),
       CheckLastBatch(),
       assertNumStateRows(total = 12, updated = 2),
       AddData(leftInput, 22),
@@ -690,11 +678,9 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
 
     testStream(joined)(
       // leftValue <= 10 should generate outer join rows even though it matches right keys
-      AddData(leftInput, 1, 2, 3),
-      AddData(rightInput, 1, 2, 3),
+      MultiAddData(leftInput, 1, 2, 3)(rightInput, 1, 2, 3),
       CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)),
-      AddData(leftInput, 20),
-      AddData(rightInput, 21),
+      MultiAddData(leftInput, 20)(rightInput, 21),
       CheckLastBatch(),
       assertNumStateRows(total = 5, updated = 2),
       AddData(rightInput, 20),
@@ -702,22 +688,18 @@ class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with
         Row(20, 30, 40, 60)),
       assertNumStateRows(total = 3, updated = 1),
       // leftValue and rightValue both satisfying condition should not generate outer join rows
-      AddData(leftInput, 40, 41),
-      AddData(rightInput, 40, 41),
+      MultiAddData(leftInput, 40, 41)(rightInput, 40, 41),
       CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)),
-      AddData(leftInput, 70),
-      AddData(rightInput, 71),
+      MultiAddData(leftInput, 70)(rightInput, 71),
       CheckLastBatch(),
       assertNumStateRows(total = 6, updated = 2),
       AddData(rightInput, 70),
       CheckLastBatch((70, 80, 140, 210)),
       assertNumStateRows(total = 3, updated = 1),
       // rightValue between 300 and 1000 should generate outer join rows even though it matches left
-      AddData(leftInput, 101, 102, 103),
-      AddData(rightInput, 101, 102, 103),
+      MultiAddData(leftInput, 101, 102, 103)(rightInput, 101, 102, 103),
       CheckLastBatch(),
-      AddData(leftInput, 1000),
-      AddData(rightInput, 1001),
+      MultiAddData(leftInput, 1000)(rightInput, 1001),
       CheckLastBatch(),
       assertNumStateRows(total = 8, updated = 2),
       AddData(rightInput, 1000),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to