Repository: spark
Updated Branches:
  refs/heads/master 14844a62c -> 1cc66a072


[SPARK-23687][SS] Add a memory source for continuous processing.

## What changes were proposed in this pull request?

Add a memory source for continuous processing.

Note that only one of the ContinuousSuite tests is migrated to minimize the 
diff here. I'll submit a second PR for SPARK-23688 to change the rest and get 
rid of waitForRateSourceTriggers.

## How was this patch tested?

Unit test: the ContinuousSuite "basic rate source" test is migrated to a "basic" test against the new memory source.
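
For reference, a minimal sketch of how the new source is driven from a StreamTest-based suite, distilled from the ContinuousSuite change in the diff below. The suite name here is illustrative, and an implicit SQLContext is assumed to be in scope (as SharedSQLContext provides); the explicit trigger and useV2Sink flag mirror the defaults ContinuousSuiteBase now sets.

    import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
    import org.apache.spark.sql.streaming.{StreamTest, Trigger}

    class ContinuousMemoryStreamExampleSuite extends StreamTest {  // hypothetical suite name
      import testImplicits._

      test("continuous memory source sketch") {
        // apply() assigns a fresh stream id and picks up the implicit SQLContext.
        val input = ContinuousMemoryStream[Int]

        testStream(input.toDF(), useV2Sink = true)(
          StartStream(Trigger.Continuous(100)),
          AddData(input, 0, 1, 2),   // records are spread across the stream's partitions
          CheckAnswer(0, 1, 2),
          StopStream)
      }
    }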

Author: Jose Torres <torres.joseph.f+git...@gmail.com>

Closes #20828 from jose-torres/continuousMemory.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1cc66a07
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1cc66a07
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1cc66a07

Branch: refs/heads/master
Commit: 1cc66a072b7fd3bf140fa41596f6b18f8d1bd7b9
Parents: 14844a6
Author: Jose Torres <torres.joseph.f+git...@gmail.com>
Authored: Tue Apr 17 01:59:38 2018 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Tue Apr 17 01:59:38 2018 -0700

----------------------------------------------------------------------
 .../continuous/ContinuousExecution.scala        |   5 +-
 .../spark/sql/execution/streaming/memory.scala  |  59 ++++--
 .../sources/ContinuousMemoryStream.scala        | 211 +++++++++++++++++++
 .../apache/spark/sql/streaming/StreamTest.scala |   4 +-
 .../streaming/continuous/ContinuousSuite.scala  |  31 ++-
 5 files changed, 266 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1cc66a07/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index 1758b38..951d694 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
 import org.apache.spark.sql.execution.streaming.{ContinuousExecutionRelation, StreamingRelationV2, _}
+import org.apache.spark.sql.sources.v2
 import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions, StreamWriteSupport}
 import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousReader, PartitionOffset}
 import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger}
@@ -317,8 +318,10 @@ class ContinuousExecution(
     synchronized {
       if (queryExecutionThread.isAlive) {
         commitLog.add(epoch)
-        val offset = offsetLog.get(epoch).get.offsets(0).get
+        val offset =
+          continuousSources(0).deserializeOffset(offsetLog.get(epoch).get.offsets(0).get.json)
         committedOffsets ++= Seq(continuousSources(0) -> offset)
+        continuousSources(0).commit(offset.asInstanceOf[v2.reader.streaming.Offset])
       } else {
         return
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc66a07/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 352d4ce..628923d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -24,17 +24,19 @@ import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
-import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
+import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
 import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
 import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
-import org.apache.spark.sql.streaming.OutputMode
+import org.apache.spark.sql.streaming.{OutputMode, Trigger}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.Utils
 
@@ -48,15 +50,42 @@ object MemoryStream {
 }
 
 /**
+ * A base class for memory stream implementations. Supports adding data and resetting.
+ */
+abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends BaseStreamingSource {
+  protected val encoder = encoderFor[A]
+  protected val attributes = encoder.schema.toAttributes
+
+  def toDS(): Dataset[A] = {
+    Dataset[A](sqlContext.sparkSession, logicalPlan)
+  }
+
+  def toDF(): DataFrame = {
+    Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
+  }
+
+  def addData(data: A*): Offset = {
+    addData(data.toTraversable)
+  }
+
+  def readSchema(): StructType = encoder.schema
+
+  protected def logicalPlan: LogicalPlan
+
+  def addData(data: TraversableOnce[A]): Offset
+}
+
+/**
  * A [[Source]] that produces value stored in memory as they are added by the user.  This [[Source]]
  * is intended for use in unit tests as it can only replay data when the object is still
  * available.
  */
 case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
-    extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
-  protected val encoder = encoderFor[A]
-  private val attributes = encoder.schema.toAttributes
-  protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
+    extends MemoryStreamBase[A](sqlContext)
+      with MicroBatchReader with SupportsScanUnsafeRow with Logging {
+
+  protected val logicalPlan: LogicalPlan =
+    StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
   protected val output = logicalPlan.output
 
   /**
@@ -70,7 +99,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
   protected var currentOffset: LongOffset = new LongOffset(-1)
 
   @GuardedBy("this")
-  private var startOffset = new LongOffset(-1)
+  protected var startOffset = new LongOffset(-1)
 
   @GuardedBy("this")
   private var endOffset = new LongOffset(-1)
@@ -82,18 +111,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
   @GuardedBy("this")
   protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)
 
-  def toDS(): Dataset[A] = {
-    Dataset(sqlContext.sparkSession, logicalPlan)
-  }
-
-  def toDF(): DataFrame = {
-    Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
-  }
-
-  def addData(data: A*): Offset = {
-    addData(data.toTraversable)
-  }
-
   def addData(data: TraversableOnce[A]): Offset = {
     val objects = data.toSeq
     val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray
@@ -114,8 +131,6 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
     }
   }
 
-  override def readSchema(): StructType = encoder.schema
-
   override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)
 
   override def getStartOffset: OffsetV2 = synchronized {

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc66a07/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
new file mode 100644
index 0000000..c28919b
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.sources
+
+import java.{util => ju}
+import java.util.Optional
+import java.util.concurrent.atomic.AtomicInteger
+import javax.annotation.concurrent.GuardedBy
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ListBuffer
+
+import org.json4s.NoTypeHints
+import org.json4s.jackson.Serialization
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.sql.{Encoder, Row, SQLContext}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream.GetRecord
+import org.apache.spark.sql.sources.v2.{ContinuousReadSupport, DataSourceOptions}
+import org.apache.spark.sql.sources.v2.reader.DataReaderFactory
+import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousDataReader, ContinuousReader, Offset, PartitionOffset}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.RpcUtils
+
+/**
+ * The overall strategy here is:
+ *  * ContinuousMemoryStream maintains a list of records for each partition. addData() will
+ *    distribute records evenly-ish across partitions.
+ *  * RecordEndpoint is set up as an endpoint for executor-side
+ *    ContinuousMemoryStreamDataReader instances to poll. It returns the record at the specified
+ *    offset within the list, or null if that offset doesn't yet have a record.
+ */
+class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
+  extends MemoryStreamBase[A](sqlContext) with ContinuousReader with ContinuousReadSupport {
+  private implicit val formats = Serialization.formats(NoTypeHints)
+  private val NUM_PARTITIONS = 2
+
+  protected val logicalPlan =
+    StreamingRelationV2(this, "memory", Map(), attributes, None)(sqlContext.sparkSession)
+
+  // ContinuousReader implementation
+
+  @GuardedBy("this")
+  private val records = Seq.fill(NUM_PARTITIONS)(new ListBuffer[A])
+
+  @GuardedBy("this")
+  private var startOffset: ContinuousMemoryStreamOffset = _
+
+  private val recordEndpoint = new RecordEndpoint()
+  @volatile private var endpointRef: RpcEndpointRef = _
+
+  def addData(data: TraversableOnce[A]): Offset = synchronized {
+    // Distribute data evenly among partition lists.
+    data.toSeq.zipWithIndex.map {
+      case (item, index) => records(index % NUM_PARTITIONS) += item
+    }
+
+    // The new target offset is the offset where all records in all partitions have been processed.
+    ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, records(i).size)).toMap)
+  }
+
+  override def setStartOffset(start: Optional[Offset]): Unit = synchronized {
+    // Inferred initial offset is position 0 in each partition.
+    startOffset = start.orElse {
+      ContinuousMemoryStreamOffset((0 until NUM_PARTITIONS).map(i => (i, 0)).toMap)
+    }.asInstanceOf[ContinuousMemoryStreamOffset]
+  }
+
+  override def getStartOffset: Offset = synchronized {
+    startOffset
+  }
+
+  override def deserializeOffset(json: String): ContinuousMemoryStreamOffset = {
+    ContinuousMemoryStreamOffset(Serialization.read[Map[Int, Int]](json))
+  }
+
+  override def mergeOffsets(offsets: Array[PartitionOffset]): ContinuousMemoryStreamOffset = {
+    ContinuousMemoryStreamOffset(
+      offsets.map {
+        case ContinuousMemoryStreamPartitionOffset(part, num) => (part, num)
+      }.toMap
+    )
+  }
+
+  override def createDataReaderFactories(): ju.List[DataReaderFactory[Row]] = {
+    synchronized {
+      val endpointName = s"ContinuousMemoryStreamRecordEndpoint-${java.util.UUID.randomUUID()}-$id"
+      endpointRef =
+        recordEndpoint.rpcEnv.setupEndpoint(endpointName, recordEndpoint)
+
+      startOffset.partitionNums.map {
+        case (part, index) =>
+          new ContinuousMemoryStreamDataReaderFactory(
+            endpointName, part, index): DataReaderFactory[Row]
+      }.toList.asJava
+    }
+  }
+
+  override def stop(): Unit = {
+    if (endpointRef != null) recordEndpoint.rpcEnv.stop(endpointRef)
+  }
+
+  override def commit(end: Offset): Unit = {}
+
+  // ContinuousReadSupport implementation
+  // This is necessary because of how StreamTest finds the source for AddDataMemory steps.
+  def createContinuousReader(
+      schema: Optional[StructType],
+      checkpointLocation: String,
+      options: DataSourceOptions): ContinuousReader = {
+    this
+  }
+
+  /**
+   * Endpoint for executors to poll for records.
+   */
+  private class RecordEndpoint extends ThreadSafeRpcEndpoint {
+    override val rpcEnv: RpcEnv = SparkEnv.get.rpcEnv
+
+    override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+      case GetRecord(ContinuousMemoryStreamPartitionOffset(part, index)) =>
+        ContinuousMemoryStream.this.synchronized {
+          val buf = records(part)
+          val record = if (buf.size <= index) None else Some(buf(index))
+
+          context.reply(record.map(Row(_)))
+        }
+    }
+  }
+}
+
+object ContinuousMemoryStream {
+  case class GetRecord(offset: ContinuousMemoryStreamPartitionOffset)
+  protected val memoryStreamId = new AtomicInteger(0)
+
+  def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
+    new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
+}
+
+/**
+ * Data reader factory for continuous memory stream.
+ */
+class ContinuousMemoryStreamDataReaderFactory(
+    driverEndpointName: String,
+    partition: Int,
+    startOffset: Int) extends DataReaderFactory[Row] {
+  override def createDataReader: ContinuousMemoryStreamDataReader =
+    new ContinuousMemoryStreamDataReader(driverEndpointName, partition, startOffset)
+}
+
+/**
+ * Data reader for continuous memory stream.
+ *
+ * Polls the driver endpoint for new records.
+ */
+class ContinuousMemoryStreamDataReader(
+    driverEndpointName: String,
+    partition: Int,
+    startOffset: Int) extends ContinuousDataReader[Row] {
+  private val endpoint = RpcUtils.makeDriverRef(
+    driverEndpointName,
+    SparkEnv.get.conf,
+    SparkEnv.get.rpcEnv)
+
+  private var currentOffset = startOffset
+  private var current: Option[Row] = None
+
+  override def next(): Boolean = {
+    current = None
+    while (current.isEmpty) {
+      Thread.sleep(10)
+      current = endpoint.askSync[Option[Row]](
+          GetRecord(ContinuousMemoryStreamPartitionOffset(partition, currentOffset)))
+    }
+    currentOffset += 1
+    true
+  }
+
+  override def get(): Row = current.get
+
+  override def close(): Unit = {}
+
+  override def getOffset: ContinuousMemoryStreamPartitionOffset =
+    ContinuousMemoryStreamPartitionOffset(partition, currentOffset)
+}
+
+case class ContinuousMemoryStreamOffset(partitionNums: Map[Int, Int])
+  extends Offset {
+  private implicit val formats = Serialization.formats(NoTypeHints)
+  override def json(): String = Serialization.write(partitionNums)
+}
+
+case class ContinuousMemoryStreamPartitionOffset(partition: Int, numProcessed: Int)
+  extends PartitionOffset

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc66a07/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index 00741d6..af0268f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -99,7 +99,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
    * been processed.
    */
   object AddData {
-    def apply[A](source: MemoryStream[A], data: A*): AddDataMemory[A] =
+    def apply[A](source: MemoryStreamBase[A], data: A*): AddDataMemory[A] =
       AddDataMemory(source, data)
   }
 
@@ -131,7 +131,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     def runAction(): Unit
   }
 
-  case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData {
+  case class AddDataMemory[A](source: MemoryStreamBase[A], data: Seq[A]) extends AddData {
     override def toString: String = s"AddData to $source: ${data.mkString(",")}"
 
     override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {

http://git-wip-us.apache.org/repos/asf/spark/blob/1cc66a07/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
index ef74efe..c318b95 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous._
+import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.streaming.{StreamTest, Trigger}
 import org.apache.spark.sql.test.TestSparkSession
@@ -53,32 +54,24 @@ class ContinuousSuiteBase extends StreamTest {
   // A continuous trigger that will only fire the initial time for the duration of a test.
   // This allows clean testing with manual epoch advancement.
   protected val longContinuousTrigger = Trigger.Continuous("1 hour")
+
+  override protected val defaultTrigger = Trigger.Continuous(100)
+  override protected val defaultUseV2Sink = true
 }
 
 class ContinuousSuite extends ContinuousSuiteBase {
   import testImplicits._
 
-  test("basic rate source") {
-    val df = spark.readStream
-      .format("rate")
-      .option("numPartitions", "5")
-      .option("rowsPerSecond", "5")
-      .load()
-      .select('value)
+  test("basic") {
+    val input = ContinuousMemoryStream[Int]
 
-    testStream(df, useV2Sink = true)(
-      StartStream(longContinuousTrigger),
-      AwaitEpoch(0),
-      Execute(waitForRateSourceTriggers(_, 2)),
-      IncrementEpoch(),
-      CheckAnswerRowsContains(scala.Range(0, 10).map(Row(_))),
+    testStream(input.toDF())(
+      AddData(input, 0, 1, 2),
+      CheckAnswer(0, 1, 2),
       StopStream,
-      StartStream(longContinuousTrigger),
-      AwaitEpoch(2),
-      Execute(waitForRateSourceTriggers(_, 2)),
-      IncrementEpoch(),
-      CheckAnswerRowsContains(scala.Range(0, 20).map(Row(_))),
-      StopStream)
+      AddData(input, 3, 4, 5),
+      StartStream(),
+      CheckAnswer(0, 1, 2, 3, 4, 5))
   }
 
   test("map") {

