Repository: spark
Updated Branches:
  refs/heads/master 299d297e2 -> 1b46f41c5


[SPARK-24235][SS] Implement continuous shuffle writer for single reader 
partition.

## What changes were proposed in this pull request?

https://docs.google.com/document/d/1IL4kJoKrZWeyIhklKUJqsW-yEN7V7aL05MmM65AYOfE/edit

Implements the continuous shuffle writer for a single reader partition. (I don't 
believe any implementation changes are required to support multiple reader 
partitions, but this PR is already quite large, so I'm excluding them for now to 
keep the size down.)
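
A minimal usage sketch, modeled on the new `ContinuousShuffleSuite` tests in the
diff below (it assumes a local `sparkContext`, a `TaskContext` already set on the
current thread, and the suite's `unsafeRow` test helper for building `UnsafeRow`s):

```scala
// Single reader partition backed by an RPC endpoint.
val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
val endpoint = reader.partitions(0)
  .asInstanceOf[ContinuousShuffleReadPartition].endpoint

// One writer partition targeting that endpoint.
val writer = new RPCContinuousShuffleWriter(
  writerId = 0,
  outputPartitioner = new HashPartitioner(1),
  endpoints = Array(endpoint))

// Write one epoch of rows, then read that epoch back on the reader side.
writer.write(Iterator(unsafeRow(1), unsafeRow(2), unsafeRow(3)))
val epoch = reader.compute(reader.partitions(0), TaskContext.get())
assert(epoch.map(_.getInt(0)).toSeq == Seq(1, 2, 3))
```

Each row is delivered with a blocking `askSync` to the reader endpoint chosen by
the output partitioner; the epoch is then closed by sending a
`ReceiverEpochMarker` to every endpoint and awaiting all the replies.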

## How was this patch tested?

New unit tests. The existing ContinuousShuffleReadSuite is folded into a combined 
ContinuousShuffleSuite covering both the reader and the new writer.

Author: Jose Torres <torres.joseph.f+git...@gmail.com>

Closes #21428 from jose-torres/writerTask.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1b46f41c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1b46f41c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1b46f41c

Branch: refs/heads/master
Commit: 1b46f41c55f5cd29956e17d7da95a95580cf273f
Parents: 299d297
Author: Jose Torres <torres.joseph.f+git...@gmail.com>
Authored: Wed Jun 13 13:13:01 2018 -0700
Committer: Shixiong Zhu <zsxw...@gmail.com>
Committed: Wed Jun 13 13:13:01 2018 -0700

----------------------------------------------------------------------
 .../shuffle/ContinuousShuffleReadRDD.scala      |   6 +-
 .../shuffle/ContinuousShuffleWriter.scala       |  27 ++
 .../shuffle/RPCContinuousShuffleReader.scala    | 138 ++++++
 .../shuffle/RPCContinuousShuffleWriter.scala    |  60 +++
 .../continuous/shuffle/UnsafeRowReceiver.scala  | 138 ------
 .../shuffle/ContinuousShuffleReadSuite.scala    | 291 -------------
 .../shuffle/ContinuousShuffleSuite.scala        | 416 +++++++++++++++++++
 7 files changed, 645 insertions(+), 431 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1b46f41c/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
index 801b28b..cf6572d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
@@ -34,8 +34,10 @@ case class ContinuousShuffleReadPartition(
   // Initialized only on the executor, and only once even as we call compute() 
multiple times.
   lazy val (reader: ContinuousShuffleReader, endpoint) = {
     val env = SparkEnv.get.rpcEnv
-    val receiver = new UnsafeRowReceiver(queueSize, numShuffleWriters, 
epochIntervalMs, env)
-    val endpoint = 
env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver)
+    val receiver = new RPCContinuousShuffleReader(
+      queueSize, numShuffleWriters, epochIntervalMs, env)
+    val endpoint = 
env.setupEndpoint(s"RPCContinuousShuffleReader-${UUID.randomUUID()}", receiver)
+
     TaskContext.get().addTaskCompletionListener { ctx =>
       env.stop(endpoint)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/1b46f41c/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala
new file mode 100644
index 0000000..47b1f78
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleWriter.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+/**
+ * Trait for writing to a continuous processing shuffle.
+ */
+trait ContinuousShuffleWriter {
+  def write(epoch: Iterator[UnsafeRow]): Unit
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1b46f41c/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
new file mode 100644
index 0000000..834e846
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import java.util.concurrent._
+import java.util.concurrent.atomic.AtomicBoolean
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.util.NextIterator
+
+/**
+ * Messages for the RPCContinuousShuffleReader endpoint. Either an incoming 
row or an epoch marker.
+ *
+ * Each message comes tagged with writerId, identifying which writer the 
message is coming
+ * from. The receiver will only begin the next epoch once all writers have 
sent an epoch
+ * marker ending the current epoch.
+ */
+private[shuffle] sealed trait RPCContinuousShuffleMessage extends Serializable 
{
+  def writerId: Int
+}
+private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow)
+  extends RPCContinuousShuffleMessage
+private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends 
RPCContinuousShuffleMessage
+
+/**
+ * RPC endpoint for receiving rows into a continuous processing shuffle task. 
Continuous shuffle
+ * writers will send rows here, with continuous shuffle readers polling for 
new rows as needed.
+ *
+ * TODO: Support multiple source tasks. We need to output a single epoch 
marker once all
+ * source tasks have sent one.
+ */
+private[shuffle] class RPCContinuousShuffleReader(
+      queueSize: Int,
+      numShuffleWriters: Int,
+      epochIntervalMs: Long,
+      override val rpcEnv: RpcEnv)
+    extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging {
+  // Note that this queue will be drained from the main task thread and 
populated in the RPC
+  // response thread.
+  private val queues = Array.fill(numShuffleWriters) {
+    new ArrayBlockingQueue[RPCContinuousShuffleMessage](queueSize)
+  }
+
+  // Exposed for testing to determine if the endpoint gets stopped on task end.
+  private[shuffle] val stopped = new AtomicBoolean(false)
+
+  override def onStop(): Unit = {
+    stopped.set(true)
+  }
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, 
Unit] = {
+    case r: RPCContinuousShuffleMessage =>
+      // Note that this will block a thread in the shared RPC handler pool!
+      // The TCP based shuffle handler (SPARK-24541) will avoid this problem.
+      queues(r.writerId).put(r)
+      context.reply(())
+  }
+
+  override def read(): Iterator[UnsafeRow] = {
+    new NextIterator[UnsafeRow] {
+      // An array of flags for whether each writer ID has gotten an epoch 
marker.
+      private val writerEpochMarkersReceived = 
Array.fill(numShuffleWriters)(false)
+
+      private val executor = Executors.newFixedThreadPool(numShuffleWriters)
+      private val completion = new 
ExecutorCompletionService[RPCContinuousShuffleMessage](executor)
+
+      private def completionTask(writerId: Int) = new 
Callable[RPCContinuousShuffleMessage] {
+        override def call(): RPCContinuousShuffleMessage = 
queues(writerId).take()
+      }
+
+      // Initialize by submitting tasks to read the first row from each writer.
+      (0 until numShuffleWriters).foreach(writerId => 
completion.submit(completionTask(writerId)))
+
+      /**
+       * In each call to getNext(), we pull the next row available in the 
completion queue, and then
+       * submit another task to read the next row from the writer which 
returned it.
+       *
+       * When a writer sends an epoch marker, we note that it's finished and 
don't submit another
+       * task for it in this epoch. The iterator is over once all writers have 
sent an epoch marker.
+       */
+      override def getNext(): UnsafeRow = {
+        var nextRow: UnsafeRow = null
+        while (!finished && nextRow == null) {
+          completion.poll(epochIntervalMs, TimeUnit.MILLISECONDS) match {
+            case null =>
+              // Try again if the poll didn't wait long enough to get a real 
result.
+              // But we should be getting at least an epoch marker every 
checkpoint interval.
+              val writerIdsUncommitted = 
writerEpochMarkersReceived.zipWithIndex.collect {
+                case (flag, idx) if !flag => idx
+              }
+              logWarning(
+                s"Completion service failed to make progress after 
$epochIntervalMs ms. Waiting " +
+                  s"for writers $writerIdsUncommitted to send epoch markers.")
+
+            // The completion service guarantees this future will be available 
immediately.
+            case future => future.get() match {
+              case ReceiverRow(writerId, r) =>
+                // Start reading the next element in the queue we just took 
from.
+                completion.submit(completionTask(writerId))
+                nextRow = r
+              case ReceiverEpochMarker(writerId) =>
+                // Don't read any more from this queue. If all the writers 
have sent epoch markers,
+                // the epoch is over; otherwise we need to loop again to poll 
from the remaining
+                // writers.
+                writerEpochMarkersReceived(writerId) = true
+                if (writerEpochMarkersReceived.forall(_ == true)) {
+                  finished = true
+                }
+            }
+          }
+        }
+
+        nextRow
+      }
+
+      override def close(): Unit = {
+        executor.shutdownNow()
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1b46f41c/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala
new file mode 100644
index 0000000..1c6f3dd
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleWriter.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import scala.concurrent.Future
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.Partitioner
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * A [[ContinuousShuffleWriter]] sending data to 
[[RPCContinuousShuffleReader]] instances.
+ *
+ * @param writerId The partition ID of this writer.
+ * @param outputPartitioner The partitioner on the reader side of the shuffle.
+ * @param endpoints The [[RPCContinuousShuffleReader]] endpoints to write to. 
Indexed by
+ *                  partition ID within outputPartitioner.
+ */
+class RPCContinuousShuffleWriter(
+    writerId: Int,
+    outputPartitioner: Partitioner,
+    endpoints: Array[RpcEndpointRef]) extends ContinuousShuffleWriter {
+
+  if (outputPartitioner.numPartitions != 1) {
+    throw new IllegalArgumentException("multiple readers not yet supported")
+  }
+
+  if (outputPartitioner.numPartitions != endpoints.length) {
+    throw new IllegalArgumentException(s"partitioner size 
${outputPartitioner.numPartitions} did " +
+      s"not match endpoint count ${endpoints.length}")
+  }
+
+  def write(epoch: Iterator[UnsafeRow]): Unit = {
+    while (epoch.hasNext) {
+      val row = epoch.next()
+      
endpoints(outputPartitioner.getPartition(row)).askSync[Unit](ReceiverRow(writerId,
 row))
+    }
+
+    val futures = 
endpoints.map(_.ask[Unit](ReceiverEpochMarker(writerId))).toSeq
+    implicit val ec = ThreadUtils.sameThread
+    ThreadUtils.awaitResult(Future.sequence(futures), Duration.Inf)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1b46f41c/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala
deleted file mode 100644
index d81f552..0000000
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming.continuous.shuffle
-
-import java.util.concurrent._
-import java.util.concurrent.atomic.AtomicBoolean
-
-import scala.collection.mutable
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
-import org.apache.spark.util.NextIterator
-
-/**
- * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an 
epoch marker.
- *
- * Each message comes tagged with writerId, identifying which writer the 
message is coming
- * from. The receiver will only begin the next epoch once all writers have 
sent an epoch
- * marker ending the current epoch.
- */
-private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable {
-  def writerId: Int
-}
-private[shuffle] case class ReceiverRow(writerId: Int, row: UnsafeRow)
-  extends UnsafeRowReceiverMessage
-private[shuffle] case class ReceiverEpochMarker(writerId: Int) extends 
UnsafeRowReceiverMessage
-
-/**
- * RPC endpoint for receiving rows into a continuous processing shuffle task. 
Continuous shuffle
- * writers will send rows here, with continuous shuffle readers polling for 
new rows as needed.
- *
- * TODO: Support multiple source tasks. We need to output a single epoch 
marker once all
- * source tasks have sent one.
- */
-private[shuffle] class UnsafeRowReceiver(
-      queueSize: Int,
-      numShuffleWriters: Int,
-      epochIntervalMs: Long,
-      override val rpcEnv: RpcEnv)
-    extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging {
-  // Note that this queue will be drained from the main task thread and 
populated in the RPC
-  // response thread.
-  private val queues = Array.fill(numShuffleWriters) {
-    new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize)
-  }
-
-  // Exposed for testing to determine if the endpoint gets stopped on task end.
-  private[shuffle] val stopped = new AtomicBoolean(false)
-
-  override def onStop(): Unit = {
-    stopped.set(true)
-  }
-
-  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, 
Unit] = {
-    case r: UnsafeRowReceiverMessage =>
-      queues(r.writerId).put(r)
-      context.reply(())
-  }
-
-  override def read(): Iterator[UnsafeRow] = {
-    new NextIterator[UnsafeRow] {
-      // An array of flags for whether each writer ID has gotten an epoch 
marker.
-      private val writerEpochMarkersReceived = 
Array.fill(numShuffleWriters)(false)
-
-      private val executor = Executors.newFixedThreadPool(numShuffleWriters)
-      private val completion = new 
ExecutorCompletionService[UnsafeRowReceiverMessage](executor)
-
-      private def completionTask(writerId: Int) = new 
Callable[UnsafeRowReceiverMessage] {
-        override def call(): UnsafeRowReceiverMessage = queues(writerId).take()
-      }
-
-      // Initialize by submitting tasks to read the first row from each writer.
-      (0 until numShuffleWriters).foreach(writerId => 
completion.submit(completionTask(writerId)))
-
-      /**
-       * In each call to getNext(), we pull the next row available in the 
completion queue, and then
-       * submit another task to read the next row from the writer which 
returned it.
-       *
-       * When a writer sends an epoch marker, we note that it's finished and 
don't submit another
-       * task for it in this epoch. The iterator is over once all writers have 
sent an epoch marker.
-       */
-      override def getNext(): UnsafeRow = {
-        var nextRow: UnsafeRow = null
-        while (!finished && nextRow == null) {
-          completion.poll(epochIntervalMs, TimeUnit.MILLISECONDS) match {
-            case null =>
-              // Try again if the poll didn't wait long enough to get a real 
result.
-              // But we should be getting at least an epoch marker every 
checkpoint interval.
-              val writerIdsUncommitted = 
writerEpochMarkersReceived.zipWithIndex.collect {
-                case (flag, idx) if !flag => idx
-              }
-              logWarning(
-                s"Completion service failed to make progress after 
$epochIntervalMs ms. Waiting " +
-                  s"for writers $writerIdsUncommitted to send epoch markers.")
-
-            // The completion service guarantees this future will be available 
immediately.
-            case future => future.get() match {
-              case ReceiverRow(writerId, r) =>
-                // Start reading the next element in the queue we just took 
from.
-                completion.submit(completionTask(writerId))
-                nextRow = r
-              case ReceiverEpochMarker(writerId) =>
-                // Don't read any more from this queue. If all the writers 
have sent epoch markers,
-                // the epoch is over; otherwise we need to loop again to poll 
from the remaining
-                // writers.
-                writerEpochMarkersReceived(writerId) = true
-                if (writerEpochMarkersReceived.forall(_ == true)) {
-                  finished = true
-                }
-            }
-          }
-        }
-
-        nextRow
-      }
-
-      override def close(): Unit = {
-        executor.shutdownNow()
-      }
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/1b46f41c/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala
deleted file mode 100644
index 2e4d607..0000000
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala
+++ /dev/null
@@ -1,291 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.streaming.continuous.shuffle
-
-import org.apache.spark.{TaskContext, TaskContextImpl}
-import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeProjection}
-import org.apache.spark.sql.streaming.StreamTest
-import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
-import org.apache.spark.unsafe.types.UTF8String
-
-class ContinuousShuffleReadSuite extends StreamTest {
-
-  private def unsafeRow(value: Int) = {
-    UnsafeProjection.create(Array(IntegerType : DataType))(
-      new GenericInternalRow(Array(value: Any)))
-  }
-
-  private def unsafeRow(value: String) = {
-    UnsafeProjection.create(Array(StringType : DataType))(
-      new GenericInternalRow(Array(UTF8String.fromString(value): Any)))
-  }
-
-  private def send(endpoint: RpcEndpointRef, messages: 
UnsafeRowReceiverMessage*) = {
-    messages.foreach(endpoint.askSync[Unit](_))
-  }
-
-  // In this unit test, we emulate that we're in the task thread where
-  // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a 
task context
-  // thread local to be set.
-  var ctx: TaskContextImpl = _
-
-  override def beforeEach(): Unit = {
-    super.beforeEach()
-    ctx = TaskContext.empty()
-    TaskContext.setTaskContext(ctx)
-  }
-
-  override def afterEach(): Unit = {
-    ctx.markTaskCompleted(None)
-    TaskContext.unset()
-    ctx = null
-    super.afterEach()
-  }
-
-  test("receiver stopped with row last") {
-    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
-    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
-    send(
-      endpoint,
-      ReceiverEpochMarker(0),
-      ReceiverRow(0, unsafeRow(111))
-    )
-
-    ctx.markTaskCompleted(None)
-    val receiver = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
-    eventually(timeout(streamingTimeout)) {
-      assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get())
-    }
-  }
-
-  test("receiver stopped with marker last") {
-    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
-    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
-    send(
-      endpoint,
-      ReceiverRow(0, unsafeRow(111)),
-      ReceiverEpochMarker(0)
-    )
-
-    ctx.markTaskCompleted(None)
-    val receiver = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
-    eventually(timeout(streamingTimeout)) {
-      assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get())
-    }
-  }
-
-  test("one epoch") {
-    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
-    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
-    send(
-      endpoint,
-      ReceiverRow(0, unsafeRow(111)),
-      ReceiverRow(0, unsafeRow(222)),
-      ReceiverRow(0, unsafeRow(333)),
-      ReceiverEpochMarker(0)
-    )
-
-    val iter = rdd.compute(rdd.partitions(0), ctx)
-    assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333))
-  }
-
-  test("multiple epochs") {
-    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
-    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
-    send(
-      endpoint,
-      ReceiverRow(0, unsafeRow(111)),
-      ReceiverEpochMarker(0),
-      ReceiverRow(0, unsafeRow(222)),
-      ReceiverRow(0, unsafeRow(333)),
-      ReceiverEpochMarker(0)
-    )
-
-    val firstEpoch = rdd.compute(rdd.partitions(0), ctx)
-    assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111))
-
-    val secondEpoch = rdd.compute(rdd.partitions(0), ctx)
-    assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333))
-  }
-
-  test("empty epochs") {
-    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
-    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
-
-    send(
-      endpoint,
-      ReceiverEpochMarker(0),
-      ReceiverEpochMarker(0),
-      ReceiverRow(0, unsafeRow(111)),
-      ReceiverEpochMarker(0),
-      ReceiverEpochMarker(0),
-      ReceiverEpochMarker(0)
-    )
-
-    assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
-    assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
-
-    val thirdEpoch = rdd.compute(rdd.partitions(0), ctx)
-    assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111))
-
-    assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
-    assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
-  }
-
-  test("multiple partitions") {
-    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5)
-    // Send all data before processing to ensure there's no crossover.
-    for (p <- rdd.partitions) {
-      val part = p.asInstanceOf[ContinuousShuffleReadPartition]
-      // Send index for identification.
-      send(
-        part.endpoint,
-        ReceiverRow(0, unsafeRow(part.index)),
-        ReceiverEpochMarker(0)
-      )
-    }
-
-    for (p <- rdd.partitions) {
-      val part = p.asInstanceOf[ContinuousShuffleReadPartition]
-      val iter = rdd.compute(part, ctx)
-      assert(iter.next().getInt(0) == part.index)
-      assert(!iter.hasNext)
-    }
-  }
-
-  test("blocks waiting for new rows") {
-    val rdd = new ContinuousShuffleReadRDD(
-      sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue)
-    val epoch = rdd.compute(rdd.partitions(0), ctx)
-
-    val readRowThread = new Thread {
-      override def run(): Unit = {
-        try {
-          epoch.next().getInt(0)
-        } catch {
-          case _: InterruptedException => // do nothing - expected at test 
ending
-        }
-      }
-    }
-
-    try {
-      readRowThread.start()
-      eventually(timeout(streamingTimeout)) {
-        assert(readRowThread.getState == Thread.State.TIMED_WAITING)
-      }
-    } finally {
-      readRowThread.interrupt()
-      readRowThread.join()
-    }
-  }
-
-  test("multiple writers") {
-    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, 
numShuffleWriters = 3)
-    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
-    send(
-      endpoint,
-      ReceiverRow(0, unsafeRow("writer0-row0")),
-      ReceiverRow(1, unsafeRow("writer1-row0")),
-      ReceiverRow(2, unsafeRow("writer2-row0")),
-      ReceiverEpochMarker(0),
-      ReceiverEpochMarker(1),
-      ReceiverEpochMarker(2)
-    )
-
-    val firstEpoch = rdd.compute(rdd.partitions(0), ctx)
-    assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet ==
-      Set("writer0-row0", "writer1-row0", "writer2-row0"))
-  }
-
-  test("epoch only ends when all writers send markers") {
-    val rdd = new ContinuousShuffleReadRDD(
-      sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs 
= Long.MaxValue)
-    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
-    send(
-      endpoint,
-      ReceiverRow(0, unsafeRow("writer0-row0")),
-      ReceiverRow(1, unsafeRow("writer1-row0")),
-      ReceiverRow(2, unsafeRow("writer2-row0")),
-      ReceiverEpochMarker(0),
-      ReceiverEpochMarker(2)
-    )
-
-    val epoch = rdd.compute(rdd.partitions(0), ctx)
-    val rows = (0 until 3).map(_ => epoch.next()).toSet
-    assert(rows.map(_.getUTF8String(0).toString) ==
-      Set("writer0-row0", "writer1-row0", "writer2-row0"))
-
-    // After checking the right rows, block until we get an epoch marker 
indicating there's no next.
-    // (Also fail the assertion if for some reason we get a row.)
-    val readEpochMarkerThread = new Thread {
-      override def run(): Unit = {
-        assert(!epoch.hasNext)
-      }
-    }
-
-    readEpochMarkerThread.start()
-    eventually(timeout(streamingTimeout)) {
-      assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING)
-    }
-
-    // Send the last epoch marker - now the epoch should finish.
-    send(endpoint, ReceiverEpochMarker(1))
-    eventually(timeout(streamingTimeout)) {
-      !readEpochMarkerThread.isAlive
-    }
-
-    // Join to pick up assertion failures.
-    readEpochMarkerThread.join()
-  }
-
-  test("writer epochs non aligned") {
-    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, 
numShuffleWriters = 3)
-    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
-    // We send multiple epochs for 0, then multiple for 1, then multiple for 
2. The receiver should
-    // collate them as though the markers were aligned in the first place.
-    send(
-      endpoint,
-      ReceiverRow(0, unsafeRow("writer0-row0")),
-      ReceiverEpochMarker(0),
-      ReceiverRow(0, unsafeRow("writer0-row1")),
-      ReceiverEpochMarker(0),
-      ReceiverEpochMarker(0),
-
-      ReceiverEpochMarker(1),
-      ReceiverRow(1, unsafeRow("writer1-row0")),
-      ReceiverEpochMarker(1),
-      ReceiverRow(1, unsafeRow("writer1-row1")),
-      ReceiverEpochMarker(1),
-
-      ReceiverEpochMarker(2),
-      ReceiverEpochMarker(2),
-      ReceiverRow(2, unsafeRow("writer2-row0")),
-      ReceiverEpochMarker(2)
-    )
-
-    val firstEpoch = rdd.compute(rdd.partitions(0), 
ctx).map(_.getUTF8String(0).toString).toSet
-    assert(firstEpoch == Set("writer0-row0"))
-
-    val secondEpoch = rdd.compute(rdd.partitions(0), 
ctx).map(_.getUTF8String(0).toString).toSet
-    assert(secondEpoch == Set("writer0-row1", "writer1-row0"))
-
-    val thirdEpoch = rdd.compute(rdd.partitions(0), 
ctx).map(_.getUTF8String(0).toString).toSet
-    assert(thirdEpoch == Set("writer1-row1", "writer2-row0"))
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/1b46f41c/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala
new file mode 100644
index 0000000..a8e3611
--- /dev/null
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala
@@ -0,0 +1,416 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import org.apache.spark.{HashPartitioner, Partition, TaskContext, 
TaskContextImpl}
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+class ContinuousShuffleSuite extends StreamTest {
+  // In this unit test, we emulate that we're in the task thread where
+  // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a 
task context
+  // thread local to be set.
+  var ctx: TaskContextImpl = _
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    ctx = TaskContext.empty()
+    TaskContext.setTaskContext(ctx)
+  }
+
+  override def afterEach(): Unit = {
+    ctx.markTaskCompleted(None)
+    TaskContext.unset()
+    ctx = null
+    super.afterEach()
+  }
+
+  private implicit def unsafeRow(value: Int) = {
+    UnsafeProjection.create(Array(IntegerType : DataType))(
+      new GenericInternalRow(Array(value: Any)))
+  }
+
+  private def unsafeRow(value: String) = {
+    UnsafeProjection.create(Array(StringType : DataType))(
+      new GenericInternalRow(Array(UTF8String.fromString(value): Any)))
+  }
+
+  private def send(endpoint: RpcEndpointRef, messages: 
RPCContinuousShuffleMessage*) = {
+    messages.foreach(endpoint.askSync[Unit](_))
+  }
+
+  private def readRDDEndpoint(rdd: ContinuousShuffleReadRDD) = {
+    rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+  }
+
+  private def readEpoch(rdd: ContinuousShuffleReadRDD) = {
+    rdd.compute(rdd.partitions(0), ctx).toSeq.map(_.getInt(0))
+  }
+
+  test("reader - one epoch") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverRow(0, unsafeRow(111)),
+      ReceiverRow(0, unsafeRow(222)),
+      ReceiverRow(0, unsafeRow(333)),
+      ReceiverEpochMarker(0)
+    )
+
+    val iter = rdd.compute(rdd.partitions(0), ctx)
+    assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333))
+  }
+
+  test("reader - multiple epochs") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverRow(0, unsafeRow(111)),
+      ReceiverEpochMarker(0),
+      ReceiverRow(0, unsafeRow(222)),
+      ReceiverRow(0, unsafeRow(333)),
+      ReceiverEpochMarker(0)
+    )
+
+    val firstEpoch = rdd.compute(rdd.partitions(0), ctx)
+    assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111))
+
+    val secondEpoch = rdd.compute(rdd.partitions(0), ctx)
+    assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333))
+  }
+
+  test("reader - empty epochs") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+
+    send(
+      endpoint,
+      ReceiverEpochMarker(0),
+      ReceiverEpochMarker(0),
+      ReceiverRow(0, unsafeRow(111)),
+      ReceiverEpochMarker(0),
+      ReceiverEpochMarker(0),
+      ReceiverEpochMarker(0)
+    )
+
+    assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
+    assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
+
+    val thirdEpoch = rdd.compute(rdd.partitions(0), ctx)
+    assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111))
+
+    assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
+    assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
+  }
+
+  test("reader - multiple partitions") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5)
+    // Send all data before processing to ensure there's no crossover.
+    for (p <- rdd.partitions) {
+      val part = p.asInstanceOf[ContinuousShuffleReadPartition]
+      // Send index for identification.
+      send(
+        part.endpoint,
+        ReceiverRow(0, unsafeRow(part.index)),
+        ReceiverEpochMarker(0)
+      )
+    }
+
+    for (p <- rdd.partitions) {
+      val part = p.asInstanceOf[ContinuousShuffleReadPartition]
+      val iter = rdd.compute(part, ctx)
+      assert(iter.next().getInt(0) == part.index)
+      assert(!iter.hasNext)
+    }
+  }
+
+  test("reader - blocks waiting for new rows") {
+    val rdd = new ContinuousShuffleReadRDD(
+      sparkContext, numPartitions = 1, epochIntervalMs = Long.MaxValue)
+    val epoch = rdd.compute(rdd.partitions(0), ctx)
+
+    val readRowThread = new Thread {
+      override def run(): Unit = {
+        try {
+          epoch.next().getInt(0)
+        } catch {
+          case _: InterruptedException => // do nothing - expected at test 
ending
+        }
+      }
+    }
+
+    try {
+      readRowThread.start()
+      eventually(timeout(streamingTimeout)) {
+        assert(readRowThread.getState == Thread.State.TIMED_WAITING)
+      }
+    } finally {
+      readRowThread.interrupt()
+      readRowThread.join()
+    }
+  }
+
+  test("reader - multiple writers") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, 
numShuffleWriters = 3)
+    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverRow(0, unsafeRow("writer0-row0")),
+      ReceiverRow(1, unsafeRow("writer1-row0")),
+      ReceiverRow(2, unsafeRow("writer2-row0")),
+      ReceiverEpochMarker(0),
+      ReceiverEpochMarker(1),
+      ReceiverEpochMarker(2)
+    )
+
+    val firstEpoch = rdd.compute(rdd.partitions(0), ctx)
+    assert(firstEpoch.toSeq.map(_.getUTF8String(0).toString).toSet ==
+      Set("writer0-row0", "writer1-row0", "writer2-row0"))
+  }
+
+  test("reader - epoch only ends when all writers send markers") {
+    val rdd = new ContinuousShuffleReadRDD(
+      sparkContext, numPartitions = 1, numShuffleWriters = 3, epochIntervalMs 
= Long.MaxValue)
+    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverRow(0, unsafeRow("writer0-row0")),
+      ReceiverRow(1, unsafeRow("writer1-row0")),
+      ReceiverRow(2, unsafeRow("writer2-row0")),
+      ReceiverEpochMarker(0),
+      ReceiverEpochMarker(2)
+    )
+
+    val epoch = rdd.compute(rdd.partitions(0), ctx)
+    val rows = (0 until 3).map(_ => epoch.next()).toSet
+    assert(rows.map(_.getUTF8String(0).toString) ==
+      Set("writer0-row0", "writer1-row0", "writer2-row0"))
+
+    // After checking the right rows, block until we get an epoch marker 
indicating there's no next.
+    // (Also fail the assertion if for some reason we get a row.)
+
+    val readEpochMarkerThread = new Thread {
+      override def run(): Unit = {
+        assert(!epoch.hasNext)
+      }
+    }
+
+    readEpochMarkerThread.start()
+    eventually(timeout(streamingTimeout)) {
+      assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING)
+    }
+
+    // Send the last epoch marker - now the epoch should finish.
+    send(endpoint, ReceiverEpochMarker(1))
+    eventually(timeout(streamingTimeout)) {
+      !readEpochMarkerThread.isAlive
+    }
+
+    // Join to pick up assertion failures.
+    readEpochMarkerThread.join(streamingTimeout.toMillis)
+  }
+
+  test("reader - writer epochs non aligned") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1, 
numShuffleWriters = 3)
+    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    // We send multiple epochs for 0, then multiple for 1, then multiple for 
2. The receiver should
+    // collate them as though the markers were aligned in the first place.
+    send(
+      endpoint,
+      ReceiverRow(0, unsafeRow("writer0-row0")),
+      ReceiverEpochMarker(0),
+      ReceiverRow(0, unsafeRow("writer0-row1")),
+      ReceiverEpochMarker(0),
+      ReceiverEpochMarker(0),
+
+      ReceiverEpochMarker(1),
+      ReceiverRow(1, unsafeRow("writer1-row0")),
+      ReceiverEpochMarker(1),
+      ReceiverRow(1, unsafeRow("writer1-row1")),
+      ReceiverEpochMarker(1),
+
+      ReceiverEpochMarker(2),
+      ReceiverEpochMarker(2),
+      ReceiverRow(2, unsafeRow("writer2-row0")),
+      ReceiverEpochMarker(2)
+    )
+
+    val firstEpoch = rdd.compute(rdd.partitions(0), 
ctx).map(_.getUTF8String(0).toString).toSet
+    assert(firstEpoch == Set("writer0-row0"))
+
+    val secondEpoch = rdd.compute(rdd.partitions(0), 
ctx).map(_.getUTF8String(0).toString).toSet
+    assert(secondEpoch == Set("writer0-row1", "writer1-row0"))
+
+    val thirdEpoch = rdd.compute(rdd.partitions(0), 
ctx).map(_.getUTF8String(0).toString).toSet
+    assert(thirdEpoch == Set("writer1-row1", "writer2-row0"))
+  }
+
+  test("one epoch") {
+    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val writer = new RPCContinuousShuffleWriter(
+      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+
+    writer.write(Iterator(1, 2, 3))
+
+    assert(readEpoch(reader) == Seq(1, 2, 3))
+  }
+
+  test("multiple epochs") {
+    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val writer = new RPCContinuousShuffleWriter(
+      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+
+    writer.write(Iterator(1, 2, 3))
+    writer.write(Iterator(4, 5, 6))
+
+    assert(readEpoch(reader) == Seq(1, 2, 3))
+    assert(readEpoch(reader) == Seq(4, 5, 6))
+  }
+
+  test("empty epochs") {
+    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val writer = new RPCContinuousShuffleWriter(
+      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+
+    writer.write(Iterator())
+    writer.write(Iterator(1, 2))
+    writer.write(Iterator())
+    writer.write(Iterator())
+    writer.write(Iterator(3, 4))
+    writer.write(Iterator())
+
+    assert(readEpoch(reader) == Seq())
+    assert(readEpoch(reader) == Seq(1, 2))
+    assert(readEpoch(reader) == Seq())
+    assert(readEpoch(reader) == Seq())
+    assert(readEpoch(reader) == Seq(3, 4))
+    assert(readEpoch(reader) == Seq())
+  }
+
+  test("blocks waiting for writer") {
+    val reader = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val writer = new RPCContinuousShuffleWriter(
+      0, new HashPartitioner(1), Array(readRDDEndpoint(reader)))
+
+    val readerEpoch = reader.compute(reader.partitions(0), ctx)
+
+    val readRowThread = new Thread {
+      override def run(): Unit = {
+        assert(readerEpoch.toSeq.map(_.getInt(0)) == Seq(1))
+      }
+    }
+    readRowThread.start()
+
+    eventually(timeout(streamingTimeout)) {
+      assert(readRowThread.getState == Thread.State.TIMED_WAITING)
+    }
+
+    // Once we write the epoch the thread should stop waiting and succeed.
+    writer.write(Iterator(1))
+    readRowThread.join(streamingTimeout.toMillis)
+  }
+
+  test("multiple writer partitions") {
+    val numWriterPartitions = 3
+
+    val reader = new ContinuousShuffleReadRDD(
+      sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions)
+    val writers = (0 until 3).map { idx =>
+      new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), 
Array(readRDDEndpoint(reader)))
+    }
+
+    writers(0).write(Iterator(1, 4, 7))
+    writers(1).write(Iterator(2, 5))
+    writers(2).write(Iterator(3, 6))
+
+    writers(0).write(Iterator(4, 7, 10))
+    writers(1).write(Iterator(5, 8))
+    writers(2).write(Iterator(6, 9))
+
+    // Since there are multiple asynchronous writers, the original row 
sequencing is not guaranteed.
+    // The epochs should be deterministically preserved, however.
+    assert(readEpoch(reader).toSet == Seq(1, 2, 3, 4, 5, 6, 7).toSet)
+    assert(readEpoch(reader).toSet == Seq(4, 5, 6, 7, 8, 9, 10).toSet)
+  }
+
+  test("reader epoch only ends when all writer partitions write it") {
+    val numWriterPartitions = 3
+
+    val reader = new ContinuousShuffleReadRDD(
+      sparkContext, numPartitions = 1, numShuffleWriters = numWriterPartitions)
+    val writers = (0 until 3).map { idx =>
+      new RPCContinuousShuffleWriter(idx, new HashPartitioner(1), 
Array(readRDDEndpoint(reader)))
+    }
+
+    writers(1).write(Iterator())
+    writers(2).write(Iterator())
+
+    val readerEpoch = reader.compute(reader.partitions(0), ctx)
+
+    val readEpochMarkerThread = new Thread {
+      override def run(): Unit = {
+        assert(!readerEpoch.hasNext)
+      }
+    }
+
+    readEpochMarkerThread.start()
+    eventually(timeout(streamingTimeout)) {
+      assert(readEpochMarkerThread.getState == Thread.State.TIMED_WAITING)
+    }
+
+    writers(0).write(Iterator())
+    readEpochMarkerThread.join(streamingTimeout.toMillis)
+  }
+
+  test("receiver stopped with row last") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverEpochMarker(0),
+      ReceiverRow(0, unsafeRow(111))
+    )
+
+    ctx.markTaskCompleted(None)
+    val receiver = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
+    eventually(timeout(streamingTimeout)) {
+      assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get())
+    }
+  }
+
+  test("receiver stopped with marker last") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val endpoint = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverRow(0, unsafeRow(111)),
+      ReceiverEpochMarker(0)
+    )
+
+    ctx.markTaskCompleted(None)
+    val receiver = 
rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
+    eventually(timeout(streamingTimeout)) {
+      assert(receiver.asInstanceOf[RPCContinuousShuffleReader].stopped.get())
+    }
+  }
+}

