Repository: spark
Updated Branches:
  refs/heads/master d23dc5b8e -> 8941a4abc


http://git-wip-us.apache.org/repos/asf/spark/blob/8941a4ab/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
index 94c5dd6..972248d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala
@@ -25,6 +25,8 @@ import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update}
 import org.apache.spark.sql.execution.streaming.Sink
 import org.apache.spark.sql.sources.v2.{ContinuousWriteSupport, DataSourceV2, DataSourceV2Options, MicroBatchWriteSupport}
@@ -177,3 +179,16 @@ class MemoryDataWriter(partition: Int, outputMode: OutputMode)
 
   override def abort(): Unit = {}
 }
+
+
+/**
+ * Used to query the data that has been written into a [[MemorySinkV2]].
+ */
+case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode {
+  private val sizePerRow = output.map(_.dataType.defaultSize).sum
+
+  // Size estimate: summed default size of the output columns times the number of rows in
+  // `sink.allData`. Widen to BigInt before multiplying so large sinks cannot overflow Int.
+  override def computeStats(): Statistics = Statistics(BigInt(sizePerRow) * sink.allData.size)
+}
+

http://git-wip-us.apache.org/repos/asf/spark/blob/8941a4ab/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 41aa02c..f17935e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -26,8 +26,10 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.execution.streaming.StreamingRelation
+import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2}
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2Options, MicroBatchReadSupport}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
 
 /**
  * Interface used to load a streaming `Dataset` from external storage systems (e.g. file systems,
@@ -153,13 +155,33 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
         "read files of Hive data source directly.")
     }
 
-    val dataSource =
-      DataSource(
-        sparkSession,
-        userSpecifiedSchema = userSpecifiedSchema,
-        className = source,
-        options = extraOptions.toMap)
-    Dataset.ofRows(sparkSession, StreamingRelation(dataSource))
+    val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf).newInstance()
+    val options = new DataSourceV2Options(extraOptions.asJava)
+    // We need to generate the V1 data source so we can pass it to the V2 relation as a shim.
+    // We can't be sure at this point whether we'll actually want to use V2, since we don't know the
+    // writer or whether the query is continuous.
+    val v1DataSource = DataSource(
+      sparkSession,
+      userSpecifiedSchema = userSpecifiedSchema,
+      className = source,
+      options = extraOptions.toMap)
+    ds match {
+      case s: ContinuousReadSupport =>
+        val tempReader = s.createContinuousReader(
+          java.util.Optional.ofNullable(userSpecifiedSchema.orNull),
+          Utils.createTempDir(namePrefix = "temporaryReader").getCanonicalPath,
+          options)
+        // Generate the V1 node to catch errors thrown within generation.
+        StreamingRelation(v1DataSource)
+        Dataset.ofRows(
+          sparkSession,
+          StreamingRelationV2(
+            s, source, extraOptions.toMap,
+            tempReader.readSchema().toAttributes, v1DataSource)(sparkSession))
+      case _ =>
+        // Code path for data source v1.
+        Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource))
+    }
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/8941a4ab/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 0be69b9..db588ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -26,7 +26,9 @@ import org.apache.spark.sql.{AnalysisException, Dataset, ForeachWriter}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.execution.streaming.{ForeachSink, MemoryPlan, MemorySink}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
+import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2}
 
 /**
  * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -240,14 +242,23 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
       if (extraOptions.get("queryName").isEmpty) {
         throw new AnalysisException("queryName must be specified for memory sink")
       }
-      val sink = new MemorySink(df.schema, outputMode)
-      val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink))
+      val (sink, resultDf) = trigger match {
+        case _: ContinuousTrigger =>
+          val s = new MemorySinkV2()
+          val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes))
+          (s, r)
+        case _ =>
+          val s = new MemorySink(df.schema, outputMode)
+          val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s))
+          (s, r)
+      }
       val chkpointLoc = extraOptions.get("checkpointLocation")
       val recoverFromChkpoint = outputMode == OutputMode.Complete()
       val query = df.sparkSession.sessionState.streamingQueryManager.startQuery(
         extraOptions.get("queryName"),
         chkpointLoc,
         df,
+        extraOptions.toMap,
         sink,
         outputMode,
         useTempCheckpointLocation = true,
@@ -262,6 +273,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
         extraOptions.get("queryName"),
         extraOptions.get("checkpointLocation"),
         df,
+        extraOptions.toMap,
         sink,
         outputMode,
         useTempCheckpointLocation = true,
@@ -277,6 +289,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
         extraOptions.get("queryName"),
         extraOptions.get("checkpointLocation"),
         df,
+        extraOptions.toMap,
         dataSource.createSink(outputMode),
         outputMode,
         useTempCheckpointLocation = source == "console",

http://git-wip-us.apache.org/repos/asf/spark/blob/8941a4ab/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index 555d6e2..e808ffa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -29,8 +29,10 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{AnalysisException, DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
 import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.v2.ContinuousWriteSupport
 import org.apache.spark.util.{Clock, SystemClock, Utils}
 
 /**
@@ -188,7 +190,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
       userSpecifiedName: Option[String],
       userSpecifiedCheckpointLocation: Option[String],
       df: DataFrame,
-      sink: Sink,
+      extraOptions: Map[String, String],
+      sink: BaseStreamingSink,
       outputMode: OutputMode,
       useTempCheckpointLocation: Boolean,
       recoverFromCheckpointLocation: Boolean,
@@ -237,16 +240,32 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
           "is not supported in streaming DataFrames/Datasets and will be disabled.")
     }
 
-    new StreamingQueryWrapper(new MicroBatchExecution(
-      sparkSession,
-      userSpecifiedName.orNull,
-      checkpointLocation,
-      analyzedPlan,
-      sink,
-      trigger,
-      triggerClock,
-      outputMode,
-      deleteCheckpointOnStop))
+    sink match {
+      case v1Sink: Sink =>
+        new StreamingQueryWrapper(new MicroBatchExecution(
+          sparkSession,
+          userSpecifiedName.orNull,
+          checkpointLocation,
+          analyzedPlan,
+          v1Sink,
+          trigger,
+          triggerClock,
+          outputMode,
+          deleteCheckpointOnStop))
+      case v2Sink: ContinuousWriteSupport =>
+        UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode)
+        new StreamingQueryWrapper(new ContinuousExecution(
+          sparkSession,
+          userSpecifiedName.orNull,
+          checkpointLocation,
+          analyzedPlan,
+          v2Sink,
+          trigger,
+          triggerClock,
+          outputMode,
+          extraOptions,
+          deleteCheckpointOnStop))
+    }
   }
 
   /**
@@ -269,7 +288,8 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
       userSpecifiedName: Option[String],
       userSpecifiedCheckpointLocation: Option[String],
       df: DataFrame,
-      sink: Sink,
+      extraOptions: Map[String, String],
+      sink: BaseStreamingSink,
       outputMode: OutputMode,
       useTempCheckpointLocation: Boolean = false,
       recoverFromCheckpointLocation: Boolean = true,
@@ -279,6 +299,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
       userSpecifiedName,
       userSpecifiedCheckpointLocation,
       df,
+      extraOptions,
       sink,
       outputMode,
       useTempCheckpointLocation,

http://git-wip-us.apache.org/repos/asf/spark/blob/8941a4ab/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index fcaca3d..9fb8be4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -297,31 +297,51 @@ object QueryTest {
     })
   }
 
+  /** Renders expected vs. actual rows side by side for a test failure message. */
+  private def genError(
+      expectedAnswer: Seq[Row],
+      sparkAnswer: Seq[Row],
+      isSorted: Boolean = false): String = {
+    val getRowType: Option[Row] => String = row =>
+      row.map(row =>
+        if (row.schema == null) {
+          "struct<>"
+        } else {
+          s"${row.schema.catalogString}"
+        }).getOrElse("struct<>")
+
+    s"""
+       |== Results ==
+       |${
+      sideBySide(
+        s"== Correct Answer - ${expectedAnswer.size} ==" +:
+          getRowType(expectedAnswer.headOption) +:
+          prepareAnswer(expectedAnswer, isSorted).map(_.toString()),
+        s"== Spark Answer - ${sparkAnswer.size} ==" +:
+          getRowType(sparkAnswer.headOption) +:
+          prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")
+    }
+    """.stripMargin
+  }
+
+  /**
+   * Returns an error message if some row of `expectedRows` does not appear in `sparkAnswer`
+   * (rows are compared as sets, so ordering and duplicates are ignored), or None otherwise.
+   */
+  def includesRows(
+      expectedRows: Seq[Row],
+      sparkAnswer: Seq[Row]): Option[String] = {
+    val expected = prepareAnswer(expectedRows, true).toSet
+    val actual = prepareAnswer(sparkAnswer, true).toSet
+    if (expected.subsetOf(actual)) None else Some(genError(expectedRows, sparkAnswer, true))
+  }
+
   def sameRows(
       expectedAnswer: Seq[Row],
       sparkAnswer: Seq[Row],
       isSorted: Boolean = false): Option[String] = {
     if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) {
-      val getRowType: Option[Row] => String = row =>
-        row.map(row =>
-            if (row.schema == null) {
-              "struct<>"
-            } else {
-                s"${row.schema.catalogString}"
-            }).getOrElse("struct<>")
-
-      val errorMessage =
-        s"""
-         |== Results ==
-         |${sideBySide(
-        s"== Correct Answer - ${expectedAnswer.size} ==" +:
-         getRowType(expectedAnswer.headOption) +:
-         prepareAnswer(expectedAnswer, isSorted).map(_.toString()),
-        s"== Spark Answer - ${sparkAnswer.size} ==" +:
-         getRowType(sparkAnswer.headOption) +:
-         prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")}
-        """.stripMargin
-      return Some(errorMessage)
+      return Some(genError(expectedAnswer, sparkAnswer, isSorted))
     }
     None
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/8941a4ab/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala
index 6514c5f..dc833b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/RateSourceV2Suite.scala
@@ -29,16 +29,6 @@ import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceV2Optio
 import org.apache.spark.sql.streaming.StreamTest
 
 class RateSourceV2Suite extends StreamTest {
-  test("microbatch in registry") {
-    DataSource.lookupDataSource("rate", spark.sqlContext.conf).newInstance() match {
-      case ds: MicroBatchReadSupport =>
-        val reader = ds.createMicroBatchReader(Optional.empty(), "", DataSourceV2Options.empty())
-        assert(reader.isInstanceOf[RateStreamV2Reader])
-      case _ =>
-        throw new IllegalStateException("Could not find v2 read support for rate")
-    }
-  }
-
   test("microbatch - numPartitions propagated") {
     val reader = new RateStreamV2Reader(
       new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava))
@@ -49,8 +39,8 @@ class RateSourceV2Suite extends StreamTest {
 
   test("microbatch - set offset") {
     val reader = new RateStreamV2Reader(DataSourceV2Options.empty())
-    val startOffset = RateStreamOffset(Map((0, (0, 1000))))
-    val endOffset = RateStreamOffset(Map((0, (0, 2000))))
+    val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000))))
+    val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 2000))))
     reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
     assert(reader.getStartOffset() == startOffset)
     assert(reader.getEndOffset() == endOffset)
@@ -63,15 +53,15 @@ class RateSourceV2Suite extends StreamTest {
     reader.setOffsetRange(Optional.empty(), Optional.empty())
     reader.getStartOffset() match {
       case r: RateStreamOffset =>
-        assert(r.partitionToValueAndRunTimeMs(0)._2 == reader.creationTimeMs)
+        assert(r.partitionToValueAndRunTimeMs(0).runTimeMs == reader.creationTimeMs)
       case _ => throw new IllegalStateException("unexpected offset type")
     }
     reader.getEndOffset() match {
       case r: RateStreamOffset =>
         // End offset may be a bit beyond 100 ms/9 rows after creation if the wait lasted
         // longer than 100ms. It should never be early.
-        assert(r.partitionToValueAndRunTimeMs(0)._1 >= 9)
-        assert(r.partitionToValueAndRunTimeMs(0)._2 >= reader.creationTimeMs + 100)
+        assert(r.partitionToValueAndRunTimeMs(0).value >= 9)
+        assert(r.partitionToValueAndRunTimeMs(0).runTimeMs >= reader.creationTimeMs + 100)
 
       case _ => throw new IllegalStateException("unexpected offset type")
     }
@@ -80,8 +70,8 @@ class RateSourceV2Suite extends StreamTest {
   test("microbatch - predetermined batch size") {
     val reader = new RateStreamV2Reader(
       new DataSourceV2Options(Map("numPartitions" -> "1", "rowsPerSecond" -> "20").asJava))
-    val startOffset = RateStreamOffset(Map((0, (0, 1000))))
-    val endOffset = RateStreamOffset(Map((0, (20, 2000))))
+    val startOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(0, 1000))))
+    val endOffset = RateStreamOffset(Map((0, ValueRunTimeMsPair(20, 2000))))
     reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
     val tasks = reader.createReadTasks()
     assert(tasks.size == 1)
@@ -93,8 +83,8 @@ class RateSourceV2Suite extends StreamTest {
       new DataSourceV2Options(Map("numPartitions" -> "11", "rowsPerSecond" -> "33").asJava))
     val startOffset = RateStreamSourceV2.createInitialOffset(11, reader.creationTimeMs)
     val endOffset = RateStreamOffset(startOffset.partitionToValueAndRunTimeMs.toSeq.map {
-      case (part, (currentVal, currentReadTime)) =>
-        (part, (currentVal + 33, currentReadTime + 1000))
+      case (part, ValueRunTimeMsPair(currentVal, currentReadTime)) =>
+        (part, ValueRunTimeMsPair(currentVal + 33, currentReadTime + 1000))
     }.toMap)
 
     reader.setOffsetRange(Optional.of(startOffset), Optional.of(endOffset))
@@ -135,7 +125,7 @@ class RateSourceV2Suite extends StreamTest {
         val startTimeMs = reader.getStartOffset()
           .asInstanceOf[RateStreamOffset]
           .partitionToValueAndRunTimeMs(t.partitionIndex)
-          ._2
+          .runTimeMs
         val r = t.createDataReader().asInstanceOf[RateStreamDataReader]
         for (rowIndex <- 0 to 9) {
           r.next()

http://git-wip-us.apache.org/repos/asf/spark/blob/8941a4ab/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 7554903..c65e5d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -77,10 +77,23 @@ class StreamSuite extends StreamTest {
   }
 
   test("StreamingRelation.computeStats") {
+    withTempDir { dir =>
+      val df = spark.readStream.format("csv").schema(StructType(Seq())).load(dir.getCanonicalPath)
+      val streamingRelation = df.logicalPlan collect {
+        case s: StreamingRelation => s
+      }
+      assert(streamingRelation.nonEmpty, "cannot find StreamingRelation")
+      assert(
+        streamingRelation.head.computeStats.sizeInBytes ==
+          spark.sessionState.conf.defaultSizeInBytes)
+    }
+  }
+
+  test("StreamingRelationV2.computeStats") {
     val streamingRelation = spark.readStream.format("rate").load().logicalPlan collect {
-      case s: StreamingRelation => s
+      case s: StreamingRelationV2 => s
     }
-    assert(streamingRelation.nonEmpty, "cannot find StreamingRelation")
+    assert(streamingRelation.nonEmpty, "cannot find StreamingRelationV2")
     assert(
       streamingRelation.head.computeStats.sizeInBytes == spark.sessionState.conf.defaultSizeInBytes)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/8941a4ab/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 71a474e..fb9ebc8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -33,11 +33,14 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException
 import org.scalatest.time.Span
 import org.scalatest.time.SpanSugar._
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row}
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch}
+import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.streaming.StreamingQueryListener._
 import org.apache.spark.sql.test.SharedSQLContext
@@ -168,6 +171,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer"
   }
 
+  case class CheckAnswerRowsContains(expectedAnswer: Seq[Row], lastOnly: Boolean = false)
+    extends StreamAction with StreamMustBeRunning {
+    override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}"
+    private def operatorName = if (lastOnly) "CheckLastBatchContains" else "CheckAnswerContains"
+  }
+
   case class CheckAnswerRowsByFunc(checkFunction: Row => Unit, lastOnly: Boolean)
       extends StreamAction with StreamMustBeRunning {
     override def toString: String = s"$operatorName: ${checkFunction.toString()}"
@@ -237,6 +246,25 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
       AssertOnQuery(query => { func(query); true })
   }
 
+  object AwaitEpoch {
+    def apply(epoch: Long): AssertOnQuery =
+      Execute {
+        case s: ContinuousExecution => s.awaitEpoch(epoch)
+        case _ => throw new IllegalStateException("microbatch cannot await epoch")
+      }
+  }
+
+  object IncrementEpoch {
+    def apply(): AssertOnQuery =
+      Execute {
+        case s: ContinuousExecution =>
+          val newEpoch = EpochCoordinatorRef.get(s.runId.toString, SparkEnv.get)
+            .askSync[Long](IncrementAndGetEpoch)
+          s.awaitEpoch(newEpoch - 1)
+        case _ => throw new IllegalStateException("microbatch cannot increment epoch")
+      }
+  }
+
   /**
    * Executes the specified actions on the given streaming DataFrame and provides helpful
    * error messages in the case of failures or incorrect answers.
@@ -246,7 +274,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
    */
   def testStream(
       _stream: Dataset[_],
-      outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = synchronized {
+      outputMode: OutputMode = OutputMode.Append,
+      useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized {
     import org.apache.spark.sql.streaming.util.StreamManualClock
 
     // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently
@@ -259,7 +288,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     var currentStream: StreamExecution = null
     var lastStream: StreamExecution = null
     val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for
-    val sink = new MemorySink(stream.schema, outputMode)
+    val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode)
     val resetConfValues = mutable.Map[String, Option[String]]()
 
     @volatile
@@ -308,7 +337,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
         ""
       }
 
-    def testState =
+    def testState = {
+      val sinkDebugString = sink match {
+        case s: MemorySink => s.toDebugString
+        case s: MemorySinkV2 => s.toDebugString
+      }
       s"""
          |== Progress ==
          |$testActions
@@ -321,12 +354,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
          |${if (streamThreadDeathCause != null) stackTraceToString(streamThreadDeathCause) else ""}
          |
          |== Sink ==
-         |${sink.toDebugString}
+         |$sinkDebugString
          |
          |
          |== Plan ==
          |${if (currentStream != null) currentStream.lastExecution else ""}
          """.stripMargin
+    }
 
     def verify(condition: => Boolean, message: String): Unit = {
       if (!condition) {
@@ -383,7 +417,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
         }
       }
 
-      try if (lastOnly) sink.latestBatchData else sink.allData catch {
+      val (latestBatchData, allData) = sink match {
+        case s: MemorySink => (s.latestBatchData, s.allData)
+        case s: MemorySinkV2 => (s.latestBatchData, s.allData)
+      }
+      try if (lastOnly) latestBatchData else allData catch {
         case e: Exception =>
           failTest("Exception while getting data from sink", e)
       }
@@ -423,6 +461,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
                   None,
                   Some(metadataRoot),
                   stream,
+                  Map(),
                   sink,
                   outputMode,
                   trigger = trigger,
@@ -594,6 +633,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
               error => failTest(error)
             }
 
+          case CheckAnswerRowsContains(expectedAnswer, lastOnly) =>
+            val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
+            QueryTest.includesRows(expectedAnswer, sparkAnswer).foreach {
+              error => failTest(error)
+            }
+
           case CheckAnswerRowsByFunc(checkFunction, lastOnly) =>
             val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
             sparkAnswer.foreach { row =>

http://git-wip-us.apache.org/repos/asf/spark/blob/8941a4ab/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
new file mode 100644
index 0000000..eda0d8a
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming.continuous
+
+import java.io.{File, InterruptedIOException, IOException, 
UncheckedIOException}
+import java.nio.channels.ClosedByInterruptException
+import java.util.concurrent.{CountDownLatch, ExecutionException, 
TimeoutException, TimeUnit}
+
+import scala.reflect.ClassTag
+import scala.util.control.ControlThrowable
+
+import com.google.common.util.concurrent.UncheckedExecutionException
+import org.apache.commons.io.FileUtils
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.{SparkContext, SparkEnv}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.plans.logical.Range
+import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
+import org.apache.spark.sql.execution.command.ExplainCommand
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanExec, 
WriteToDataSourceV2Exec}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.continuous._
+import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
+import org.apache.spark.sql.execution.streaming.state.{StateStore, 
StateStoreConf, StateStoreId, StateStoreProvider}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.StreamSourceProvider
+import org.apache.spark.sql.streaming.{StreamTest, Trigger}
+import org.apache.spark.sql.streaming.util.StreamManualClock
+import org.apache.spark.sql.test.TestSparkSession
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+/**
+ * Shared base for continuous-processing tests: a local session with enough cores to schedule
+ * every partition of a continuous query at once, plus helpers for driving rate-source queries.
+ */
+class ContinuousSuiteBase extends StreamTest {
+  // We need more than the default local[2] to be able to schedule all partitions simultaneously.
+  override protected def createSparkSession = new TestSparkSession(
+    new SparkContext(
+      "local[10]",
+      "continuous-stream-test-sql-context",
+      sparkConf.set("spark.sql.testkey", "true")))
+
+  /**
+   * Blocks until `numTriggers * 1000 + 300` ms have elapsed since the creation time of the
+   * query's [[ContinuousRateStreamReader]], i.e. until the rate source has had time to emit
+   * `numTriggers` triggers' worth of rows (with a 300 ms margin).
+   *
+   * @param query the running query; must be a [[ContinuousExecution]] whose executed plan
+   *              contains a `DataSourceV2ScanExec` over a [[ContinuousRateStreamReader]].
+   * @param numTriggers number of triggers to wait for; must be at least 2.
+   */
+  protected def waitForRateSourceTriggers(query: StreamExecution, numTriggers: Int): Unit = {
+    query match {
+      case s: ContinuousExecution =>
+        assert(numTriggers >= 2,
+          "must wait for at least 2 triggers to ensure query is initialized")
+        val reader = s.lastExecution.executedPlan.collectFirst {
+          case DataSourceV2ScanExec(_, r: ContinuousRateStreamReader) => r
+        }.getOrElse(fail("executed plan contains no ContinuousRateStreamReader scan"))
+
+        // Read the clock once per iteration and sleep only for a strictly positive amount:
+        // re-reading it between the loop check and Thread.sleep could produce a negative
+        // duration, which Thread.sleep rejects with IllegalArgumentException.
+        val deadline = reader.creationTime + numTriggers * 1000L + 300
+        var remainingMs = deadline - System.currentTimeMillis
+        while (remainingMs > 0) {
+          Thread.sleep(remainingMs)
+          remainingMs = deadline - System.currentTimeMillis
+        }
+
+      case other =>
+        fail(s"waitForRateSourceTriggers requires a ContinuousExecution, got ${other.getClass}")
+    }
+  }
+
+  // A continuous trigger that will only fire the initial time for the duration of a test.
+  // This allows clean testing with manual epoch advancement.
+  protected val longContinuousTrigger = Trigger.Continuous("1 hour")
+}
+
+/**
+ * Functional tests for continuous processing over the rate source. Most tests run through the
+ * StreamTest harness with `longContinuousTrigger`, so epochs advance only via IncrementEpoch.
+ */
+class ContinuousSuite extends ContinuousSuiteBase {
+  import testImplicits._
+
+  test("basic rate source") {
+    val df = spark.readStream
+      .format("rate")
+      .option("numPartitions", "5")
+      .option("rowsPerSecond", "5")
+      .load()
+      .select('value)
+
+    testStream(df, useV2Sink = true)(
+      StartStream(longContinuousTrigger),
+      AwaitEpoch(0),
+      Execute(waitForRateSourceTriggers(_, 2)),
+      IncrementEpoch(),
+      CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))),
+      StopStream,
+      StartStream(longContinuousTrigger),
+      AwaitEpoch(2),
+      Execute(waitForRateSourceTriggers(_, 2)),
+      IncrementEpoch(),
+      CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))),
+      StopStream)
+  }
+
+  test("map") {
+    val df = spark.readStream
+      .format("rate")
+      .option("numPartitions", "5")
+      .option("rowsPerSecond", "5")
+      .load()
+      .select('value)
+      .map(r => r.getLong(0) * 2)
+
+    testStream(df, useV2Sink = true)(
+      StartStream(longContinuousTrigger),
+      AwaitEpoch(0),
+      Execute(waitForRateSourceTriggers(_, 2)),
+      IncrementEpoch(),
+      Execute(waitForRateSourceTriggers(_, 4)),
+      IncrementEpoch(),
+      CheckAnswerRowsContains(scala.Range(0, 40, 2).map(Row(_))))
+  }
+
+  test("flatMap") {
+    val df = spark.readStream
+      .format("rate")
+      .option("numPartitions", "5")
+      .option("rowsPerSecond", "5")
+      .load()
+      .select('value)
+      .flatMap(r => Seq(0, r.getLong(0), r.getLong(0) * 2))
+
+    testStream(df, useV2Sink = true)(
+      StartStream(longContinuousTrigger),
+      AwaitEpoch(0),
+      Execute(waitForRateSourceTriggers(_, 2)),
+      IncrementEpoch(),
+      Execute(waitForRateSourceTriggers(_, 4)),
+      IncrementEpoch(),
+      CheckAnswerRowsContains(scala.Range(0, 20).flatMap(n => Seq(0, n, n * 2)).map(Row(_))))
+  }
+
+  test("filter") {
+    val df = spark.readStream
+      .format("rate")
+      .option("numPartitions", "5")
+      .option("rowsPerSecond", "5")
+      .load()
+      .select('value)
+      .where('value > 5)
+
+    testStream(df, useV2Sink = true)(
+      StartStream(longContinuousTrigger),
+      AwaitEpoch(0),
+      Execute(waitForRateSourceTriggers(_, 2)),
+      IncrementEpoch(),
+      Execute(waitForRateSourceTriggers(_, 4)),
+      IncrementEpoch(),
+      CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_))))
+  }
+
+  // Unsupported operators must be rejected with an AnalysisException when the query starts.
+  test("deduplicate") {
+    val df = spark.readStream
+      .format("rate")
+      .option("numPartitions", "5")
+      .option("rowsPerSecond", "5")
+      .load()
+      .select('value)
+      .dropDuplicates()
+
+    val except = intercept[AnalysisException] {
+      testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger))
+    }
+
+    assert(except.message.contains(
+      "Continuous processing does not support Deduplicate operations."))
+  }
+
+  test("timestamp") {
+    val df = spark.readStream
+      .format("rate")
+      .option("numPartitions", "5")
+      .option("rowsPerSecond", "5")
+      .load()
+      .select(current_timestamp())
+
+    val except = intercept[AnalysisException] {
+      testStream(df, useV2Sink = true)(StartStream(longContinuousTrigger))
+    }
+
+    assert(except.message.contains(
+      "Continuous processing does not support current time operations."))
+  }
+
+  test("repeatedly restart") {
+    val df = spark.readStream
+      .format("rate")
+      .option("numPartitions", "5")
+      .option("rowsPerSecond", "5")
+      .load()
+      .select('value)
+
+    testStream(df, useV2Sink = true)(
+      StartStream(longContinuousTrigger),
+      AwaitEpoch(0),
+      Execute(waitForRateSourceTriggers(_, 2)),
+      IncrementEpoch(),
+      CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))),
+      StopStream,
+      StartStream(longContinuousTrigger),
+      StopStream,
+      StartStream(longContinuousTrigger),
+      StopStream,
+      StartStream(longContinuousTrigger),
+      AwaitEpoch(2),
+      Execute(waitForRateSourceTriggers(_, 2)),
+      IncrementEpoch(),
+      CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))),
+      StopStream)
+  }
+
+  test("query without test harness") {
+    val df = spark.readStream
+      .format("rate")
+      .option("numPartitions", "2")
+      .option("rowsPerSecond", "2")
+      .load()
+      .select('value)
+    val query = df.writeStream
+      .format("memory")
+      .queryName("noharness")
+      .trigger(Trigger.Continuous(100))
+      .start()
+    // Stop the query even if waiting fails, so a running query never leaks into later tests.
+    try {
+      val continuousExecution =
+        query.asInstanceOf[StreamingQueryWrapper].streamingQuery.asInstanceOf[ContinuousExecution]
+      continuousExecution.awaitEpoch(0)
+      waitForRateSourceTriggers(continuousExecution, 2)
+    } finally {
+      query.stop()
+    }
+
+    val results = spark.read.table("noharness").collect()
+    assert(Set(0, 1, 2, 3).map(Row(_)).subsetOf(results.toSet))
+  }
+}
+
+/**
+ * Higher-volume continuous tests (500 rows/sec over 5 partitions): one long epoch, automatic
+ * epoch advancement, and many stop/restart cycles.
+ */
+class ContinuousStressSuite extends ContinuousSuiteBase {
+  import testImplicits._
+
+  test("only one epoch") {
+    val df = spark.readStream
+      .format("rate")
+      .option("numPartitions", "5")
+      .option("rowsPerSecond", "500")
+      .load()
+      .select('value)
+
+    testStream(df, useV2Sink = true)(
+      StartStream(longContinuousTrigger),
+      AwaitEpoch(0),
+      Execute(waitForRateSourceTriggers(_, 201)),
+      IncrementEpoch(),
+      Execute { query =>
+        val data = query.sink match {
+          case s: MemorySinkV2 => s.allData
+          case other => fail(s"expected the query to write to a MemorySinkV2, got $other")
+        }
+        val vals = data.map(_.getLong(0)).toSet
+        // List the missing values rather than asserting a bare `forall`, so a failure shows
+        // which rows never reached the sink.
+        val missing = scala.Range(0, 25000).filterNot(i => vals.contains(i))
+        assert(missing.isEmpty,
+          s"${missing.size} expected values missing from the sink, e.g. ${missing.take(10)}")
+      })
+  }
+
+  test("automatic epoch advancement") {
+    val df = spark.readStream
+      .format("rate")
+      .option("numPartitions", "5")
+      .option("rowsPerSecond", "500")
+      .load()
+      .select('value)
+
+    testStream(df, useV2Sink = true)(
+      StartStream(Trigger.Continuous(2012)),
+      AwaitEpoch(0),
+      Execute(waitForRateSourceTriggers(_, 201)),
+      IncrementEpoch(),
+      CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_))))
+  }
+
+  test("restarts") {
+    val df = spark.readStream
+      .format("rate")
+      .option("numPartitions", "5")
+      .option("rowsPerSecond", "500")
+      .load()
+      .select('value)
+
+    testStream(df, useV2Sink = true)(
+      StartStream(Trigger.Continuous(2012)),
+      AwaitEpoch(10),
+      StopStream,
+      StartStream(Trigger.Continuous(2012)),
+      AwaitEpoch(20),
+      StopStream,
+      StartStream(Trigger.Continuous(2012)),
+      AwaitEpoch(21),
+      StopStream,
+      StartStream(Trigger.Continuous(2012)),
+      AwaitEpoch(22),
+      StopStream,
+      StartStream(Trigger.Continuous(2012)),
+      AwaitEpoch(25),
+      StopStream,
+      StartStream(Trigger.Continuous(2012)),
+      StopStream,
+      StartStream(Trigger.Continuous(2012)),
+      AwaitEpoch(50),
+      CheckAnswerRowsContains(scala.Range(0, 25000).map(Row(_))))
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to