Repository: spark
Updated Branches:
  refs/heads/master 03e90f65b -> a33dcf4a0


[SPARK-24234][SS] Reader for continuous processing shuffle

## What changes were proposed in this pull request?

This adds the read RDD for continuous processing shuffle, as well as the initial RPC-based row receiver it reads from.

https://docs.google.com/document/d/1IL4kJoKrZWeyIhklKUJqsW-yEN7V7aL05MmM65AYOfE/edit#heading=h.8t3ci57f7uii
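
For orientation, the sketch below condenses how the new pieces fit together, based on the code and tests in this commit. It is an illustration rather than part of the patch, and it assumes a local `sparkContext` plus a task context thread-local, which the test suite sets up. Upstream writers send rows and epoch markers to each partition's `UnsafeRowReceiver` RPC endpoint, and `ContinuousShuffleReadRDD.compute()` returns an iterator that drains the receiver until it hits an epoch marker.

```scala
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{DataType, IntegerType}

// One partition, backed by an UnsafeRowReceiver registered as an RPC endpoint.
val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint

// Build a single-column UnsafeRow, the unit of data moved through the shuffle.
val row: UnsafeRow =
  UnsafeProjection.create(Array(IntegerType: DataType))(new GenericInternalRow(Array(42: Any)))

// A writer sends rows followed by an epoch marker to the receiver endpoint.
endpoint.askSync[Unit](ReceiverRow(row))
endpoint.askSync[Unit](ReceiverEpochMarker())

// The reader drains the receiver's queue; the iterator ends at the epoch marker.
val epoch = rdd.compute(rdd.partitions(0), TaskContext.get())
assert(epoch.map(_.getInt(0)).toSeq == Seq(42))
```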

## How was this patch tested?

New unit tests in ContinuousShuffleReadSuite.
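
The suite drives the reader directly rather than running a real job: each test installs a task context thread-local so that `compute()` and the endpoint's task-completion cleanup behave as they would inside a task. Condensed from the suite's `beforeEach`/`afterEach` below:

```scala
// Condensed from ContinuousShuffleReadSuite: emulate the task thread so compute() and the
// task-completion listener (which stops the receiver's RPC endpoint) can run outside a job.
var ctx: TaskContextImpl = _

override def beforeEach(): Unit = {
  super.beforeEach()
  ctx = TaskContext.empty()          // a bare TaskContextImpl for testing
  TaskContext.setTaskContext(ctx)    // the partition's endpoint registers its cleanup here
}

override def afterEach(): Unit = {
  ctx.markTaskCompleted(None)        // fires completion listeners, stopping the endpoint
  TaskContext.unset()
  ctx = null
  super.afterEach()
}
```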

Author: Jose Torres <torres.joseph.f+git...@gmail.com>

Closes #21337 from jose-torres/readerRddMaster.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a33dcf4a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a33dcf4a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a33dcf4a

Branch: refs/heads/master
Commit: a33dcf4a0bbe20dce6f1e1e6c2e1c3828291fb3d
Parents: 03e90f6
Author: Jose Torres <torres.joseph.f+git...@gmail.com>
Authored: Mon May 21 12:58:05 2018 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Mon May 21 12:58:05 2018 -0700

----------------------------------------------------------------------
 .../shuffle/ContinuousShuffleReadRDD.scala      |  61 ++++++
 .../shuffle/ContinuousShuffleReader.scala       |  32 ++++
 .../continuous/shuffle/UnsafeRowReceiver.scala  |  75 ++++++++
 .../shuffle/ContinuousShuffleReadSuite.scala    | 184 +++++++++++++++++++
 4 files changed, 352 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a33dcf4a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
new file mode 100644
index 0000000..270b1a5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import java.util.UUID
+
+import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.util.NextIterator
+
+case class ContinuousShuffleReadPartition(index: Int, queueSize: Int) extends Partition {
+  // Initialized only on the executor, and only once even as we call compute() multiple times.
+  lazy val (reader: ContinuousShuffleReader, endpoint) = {
+    val env = SparkEnv.get.rpcEnv
+    val receiver = new UnsafeRowReceiver(queueSize, env)
+    val endpoint = env.setupEndpoint(s"UnsafeRowReceiver-${UUID.randomUUID()}", receiver)
+    TaskContext.get().addTaskCompletionListener { ctx =>
+      env.stop(endpoint)
+    }
+    (receiver, endpoint)
+  }
+}
+
+/**
+ * RDD at the map side of each continuous processing shuffle task. Upstream tasks send their
+ * shuffle output to the wrapped receivers in partitions of this RDD; each of the RDD's tasks
+ * polls from its receiver until an epoch marker is sent.
+ */
+class ContinuousShuffleReadRDD(
+    sc: SparkContext,
+    numPartitions: Int,
+    queueSize: Int = 1024)
+  extends RDD[UnsafeRow](sc, Nil) {
+
+  override protected def getPartitions: Array[Partition] = {
+    (0 until numPartitions).map { partIndex =>
+      ContinuousShuffleReadPartition(partIndex, queueSize)
+    }.toArray
+  }
+
+  override def compute(split: Partition, context: TaskContext): Iterator[UnsafeRow] = {
+    split.asInstanceOf[ContinuousShuffleReadPartition].reader.read()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a33dcf4a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala
new file mode 100644
index 0000000..42631c9
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReader.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+/**
+ * Trait for reading from a continuous processing shuffle.
+ */
+trait ContinuousShuffleReader {
+  /**
+   * Returns an iterator over the incoming rows in an epoch. Implementations should block waiting
+   * for new rows to arrive, and end the iterator once they've received epoch markers from all
+   * shuffle writers.
+   */
+  def read(): Iterator[UnsafeRow]
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a33dcf4a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala
new file mode 100644
index 0000000..b8adbb7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/UnsafeRowReceiver.scala
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import java.util.concurrent.{ArrayBlockingQueue, BlockingQueue}
+import java.util.concurrent.atomic.AtomicBoolean
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.util.NextIterator
+
+/**
+ * Messages for the UnsafeRowReceiver endpoint. Either an incoming row or an epoch marker.
+ */
+private[shuffle] sealed trait UnsafeRowReceiverMessage extends Serializable
+private[shuffle] case class ReceiverRow(row: UnsafeRow) extends UnsafeRowReceiverMessage
+private[shuffle] case class ReceiverEpochMarker() extends UnsafeRowReceiverMessage
+
+/**
+ * RPC endpoint for receiving rows into a continuous processing shuffle task. Continuous shuffle
+ * writers will send rows here, with continuous shuffle readers polling for new rows as needed.
+ *
+ * TODO: Support multiple source tasks. We need to output a single epoch marker once all
+ * source tasks have sent one.
+ */
+private[shuffle] class UnsafeRowReceiver(
+      queueSize: Int,
+      override val rpcEnv: RpcEnv)
+    extends ThreadSafeRpcEndpoint with ContinuousShuffleReader with Logging {
+  // Note that this queue will be drained from the main task thread and populated in the RPC
+  // response thread.
+  private val queue = new ArrayBlockingQueue[UnsafeRowReceiverMessage](queueSize)
+
+  // Exposed for testing to determine if the endpoint gets stopped on task end.
+  private[shuffle] val stopped = new AtomicBoolean(false)
+
+  override def onStop(): Unit = {
+    stopped.set(true)
+  }
+
+  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+    case r: UnsafeRowReceiverMessage =>
+      queue.put(r)
+      context.reply(())
+  }
+
+  override def read(): Iterator[UnsafeRow] = {
+    new NextIterator[UnsafeRow] {
+      override def getNext(): UnsafeRow = queue.take() match {
+        case ReceiverRow(r) => r
+        case ReceiverEpochMarker() =>
+          finished = true
+          null
+      }
+
+      override def close(): Unit = {}
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a33dcf4a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala
new file mode 100644
index 0000000..b25e75b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleReadSuite.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous.shuffle
+
+import org.apache.spark.{TaskContext, TaskContextImpl}
+import org.apache.spark.rpc.RpcEndpointRef
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.types.{DataType, IntegerType}
+
+class ContinuousShuffleReadSuite extends StreamTest {
+
+  private def unsafeRow(value: Int) = {
+    UnsafeProjection.create(Array(IntegerType : DataType))(
+      new GenericInternalRow(Array(value: Any)))
+  }
+
+  private def send(endpoint: RpcEndpointRef, messages: UnsafeRowReceiverMessage*) = {
+    messages.foreach(endpoint.askSync[Unit](_))
+  }
+
+  // In these tests, we emulate being in the task thread where
+  // ContinuousShuffleReadRDD.compute() will be evaluated. This requires a task context
+  // thread local to be set.
+  var ctx: TaskContextImpl = _
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    ctx = TaskContext.empty()
+    TaskContext.setTaskContext(ctx)
+  }
+
+  override def afterEach(): Unit = {
+    ctx.markTaskCompleted(None)
+    TaskContext.unset()
+    ctx = null
+    super.afterEach()
+  }
+
+  test("receiver stopped with row last") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverEpochMarker(),
+      ReceiverRow(unsafeRow(111))
+    )
+
+    ctx.markTaskCompleted(None)
+    val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
+    eventually(timeout(streamingTimeout)) {
+      assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get())
+    }
+  }
+
+  test("receiver stopped with marker last") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    endpoint.askSync[Unit](ReceiverRow(unsafeRow(111)))
+    endpoint.askSync[Unit](ReceiverEpochMarker())
+
+    ctx.markTaskCompleted(None)
+    val receiver = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].reader
+    eventually(timeout(streamingTimeout)) {
+      assert(receiver.asInstanceOf[UnsafeRowReceiver].stopped.get())
+    }
+  }
+
+  test("one epoch") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverRow(unsafeRow(111)),
+      ReceiverRow(unsafeRow(222)),
+      ReceiverRow(unsafeRow(333)),
+      ReceiverEpochMarker()
+    )
+
+    val iter = rdd.compute(rdd.partitions(0), ctx)
+    assert(iter.toSeq.map(_.getInt(0)) == Seq(111, 222, 333))
+  }
+
+  test("multiple epochs") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverRow(unsafeRow(111)),
+      ReceiverEpochMarker(),
+      ReceiverRow(unsafeRow(222)),
+      ReceiverRow(unsafeRow(333)),
+      ReceiverEpochMarker()
+    )
+
+    val firstEpoch = rdd.compute(rdd.partitions(0), ctx)
+    assert(firstEpoch.toSeq.map(_.getInt(0)) == Seq(111))
+
+    val secondEpoch = rdd.compute(rdd.partitions(0), ctx)
+    assert(secondEpoch.toSeq.map(_.getInt(0)) == Seq(222, 333))
+  }
+
+  test("empty epochs") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+    val endpoint = rdd.partitions(0).asInstanceOf[ContinuousShuffleReadPartition].endpoint
+    send(
+      endpoint,
+      ReceiverEpochMarker(),
+      ReceiverEpochMarker(),
+      ReceiverRow(unsafeRow(111)),
+      ReceiverEpochMarker(),
+      ReceiverEpochMarker(),
+      ReceiverEpochMarker()
+    )
+
+    assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
+    assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
+
+    val thirdEpoch = rdd.compute(rdd.partitions(0), ctx)
+    assert(thirdEpoch.toSeq.map(_.getInt(0)) == Seq(111))
+
+    assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
+    assert(rdd.compute(rdd.partitions(0), ctx).isEmpty)
+  }
+
+  test("multiple partitions") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5)
+    // Send all data before processing to ensure there's no crossover.
+    for (p <- rdd.partitions) {
+      val part = p.asInstanceOf[ContinuousShuffleReadPartition]
+      // Send index for identification.
+      send(
+        part.endpoint,
+        ReceiverRow(unsafeRow(part.index)),
+        ReceiverEpochMarker()
+      )
+    }
+
+    for (p <- rdd.partitions) {
+      val part = p.asInstanceOf[ContinuousShuffleReadPartition]
+      val iter = rdd.compute(part, ctx)
+      assert(iter.next().getInt(0) == part.index)
+      assert(!iter.hasNext)
+    }
+  }
+
+  test("blocks waiting for new rows") {
+    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 1)
+
+    val readRowThread = new Thread {
+      override def run(): Unit = {
+        // set the non-inheritable thread local
+        TaskContext.setTaskContext(ctx)
+        val epoch = rdd.compute(rdd.partitions(0), ctx)
+        epoch.next().getInt(0)
+      }
+    }
+
+    try {
+      readRowThread.start()
+      eventually(timeout(streamingTimeout)) {
+        assert(readRowThread.getState == Thread.State.WAITING)
+      }
+    } finally {
+      readRowThread.interrupt()
+      readRowThread.join()
+    }
+  }
+}

