Repository: spark
Updated Branches:
  refs/heads/master 2224861f2 -> f6e6899a8


[SPARK-24386][SS] coalesce(1) aggregates in continuous processing

## What changes were proposed in this pull request?

Provide a continuous processing implementation of coalesce(1), and allow 
aggregates to run on top of it.

The changes in ContinuousQueuedDataReader and related classes make them use 
split.index (the index of the partition within the RDD currently being 
compute()d) rather than context.partitionId() (the partition ID of the scheduled 
task within the Spark job - that is, the post-coalesce writer). Before this 
change there was no narrow dependency between continuous RDDs, so these two 
values were always the same and there was no need to distinguish them.

## How was this patch tested?

New unit tests in ContinuousAggregationSuite.

Author: Jose Torres <torres.joseph.f+git...@gmail.com>

Closes #21560 from jose-torres/coalesce.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f6e6899a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f6e6899a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f6e6899a

Branch: refs/heads/master
Commit: f6e6899a8b8af99cd06e84cae7c69e0fc35bc60a
Parents: 2224861
Author: Jose Torres <torres.joseph.f+git...@gmail.com>
Authored: Thu Jun 28 16:25:40 2018 -0700
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Thu Jun 28 16:25:40 2018 -0700

----------------------------------------------------------------------
 .../analysis/UnsupportedOperationChecker.scala  |  11 ++
 .../datasources/v2/DataSourceV2Strategy.scala   |  16 ++-
 .../continuous/ContinuousCoalesceExec.scala     |  51 +++++++
 .../continuous/ContinuousCoalesceRDD.scala      | 136 +++++++++++++++++++
 .../continuous/ContinuousDataSourceRDD.scala    |   7 +-
 .../continuous/ContinuousExecution.scala        |   4 +
 .../continuous/ContinuousQueuedDataReader.scala |   6 +-
 .../shuffle/ContinuousShuffleReadRDD.scala      |  10 +-
 .../shuffle/RPCContinuousShuffleReader.scala    |   4 +-
 .../sources/ContinuousMemoryStream.scala        |  11 +-
 .../continuous/ContinuousAggregationSuite.scala |  63 ++++++++-
 .../ContinuousQueuedDataReaderSuite.scala       |   2 +-
 .../shuffle/ContinuousShuffleSuite.scala        |   7 +-
 13 files changed, 310 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f6e6899a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 2bed416..5ced1ca 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -349,6 +349,17 @@ object UnsupportedOperationChecker {
               _: DeserializeToObject | _: SerializeFromObject | _: 
SubqueryAlias |
               _: TypedFilter) =>
         case node if node.nodeName == "StreamingRelationV2" =>
+        case Repartition(1, false, _) =>
+        case node: Aggregate =>
+          val aboveSinglePartitionCoalesce = node.find {
+            case Repartition(1, false, _) => true
+            case _ => false
+          }.isDefined
+
+          if (!aboveSinglePartitionCoalesce) {
+            throwError(s"In continuous processing mode, coalesce(1) must be 
called before " +
+              s"aggregate operation ${node.nodeName}.")
+          }
         case node =>
           throwError(s"Continuous processing does not support ${node.nodeName} 
operations.")
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/f6e6899a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 182aa29..2a7f1de 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -22,11 +22,12 @@ import scala.collection.mutable
 import org.apache.spark.sql.{sources, Strategy}
 import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, 
AttributeSet, Expression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
 import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
-import 
org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource,
 WriteToContinuousDataSourceExec}
+import 
org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, 
WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
 import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, 
SupportsPushDownCatalystFilters, SupportsPushDownFilters, 
SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousReader
 
 object DataSourceV2Strategy extends Strategy {
 
@@ -141,6 +142,17 @@ object DataSourceV2Strategy extends Strategy {
     case WriteToContinuousDataSource(writer, query) =>
       WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil
 
+    case Repartition(1, false, child) =>
+      val isContinuous = child.collectFirst {
+        case StreamingDataSourceV2Relation(_, _, _, r: ContinuousReader) => r
+      }.isDefined
+
+      if (isContinuous) {
+        ContinuousCoalesceExec(1, planLater(child)) :: Nil
+      } else {
+        Nil
+      }
+
     case _ => Nil
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f6e6899a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala
new file mode 100644
index 0000000..5f60343
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceExec.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.UUID
+
+import org.apache.spark.{HashPartitioner, SparkEnv}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, 
SinglePartition}
+import org.apache.spark.sql.execution.SparkPlan
+import 
org.apache.spark.sql.execution.streaming.continuous.shuffle.{ContinuousShuffleReadPartition,
 ContinuousShuffleReadRDD}
+
+/**
+ * Physical plan node that coalesces a continuous processing plan.
+ *
+ * Currently, only coalescing to a single partition is supported: `numPartitions` must be 1.
+ * Executing this node requires the continuous epoch interval to be available as the
+ * SparkContext local property [[ContinuousExecution.EPOCH_INTERVAL_KEY]]. ContinuousExecution
+ * publishes it for every continuous query.
+ *
+ * @param numPartitions number of output partitions; must be 1
+ * @param child the continuous plan whose partitions are merged into the output
+ */
+case class ContinuousCoalesceExec(numPartitions: Int, child: SparkPlan) extends SparkPlan {
+  override def output: Seq[Attribute] = child.output
+
+  override def children: Seq[SparkPlan] = child :: Nil
+
+  override def outputPartitioning: Partitioning = SinglePartition
+
+  override def doExecute(): RDD[InternalRow] = {
+    // Argument validation: use `require` (with a message) rather than `assert`.
+    require(numPartitions == 1,
+      s"ContinuousCoalesceExec only supports coalescing to 1 partition, but got $numPartitions")
+
+    // Fail with a descriptive error rather than inside String.toLong when the property is
+    // missing, e.g. when this node is executed outside a continuous streaming query.
+    val epochIntervalMs =
+      sparkContext.getLocalProperty(ContinuousExecution.EPOCH_INTERVAL_KEY) match {
+        case null =>
+          throw new IllegalStateException(
+            s"Local property ${ContinuousExecution.EPOCH_INTERVAL_KEY} is not set; " +
+              "ContinuousCoalesceExec must be executed by a continuous streaming query.")
+        case interval => interval.toLong
+      }
+
+    new ContinuousCoalesceRDD(
+      sparkContext,
+      numPartitions,
+      conf.continuousStreamingExecutorQueueSize,
+      epochIntervalMs,
+      child.execute())
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6e6899a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala
new file mode 100644
index 0000000..ba85b35
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousCoalesceRDD.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.streaming.continuous
+
+import java.util.UUID
+
+import org.apache.spark._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.streaming.continuous.shuffle._
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Partition of [[ContinuousCoalesceRDD]].
+ *
+ * Besides the static configuration below, it holds executor-side state: the shuffle
+ * reader and its RPC endpoint, and a flag recording whether the writer threads have
+ * been started.
+ *
+ * @param index index of this partition within the coalesce RDD
+ * @param endpointName name under which the reader's RPC endpoint is registered; writer threads
+ *                     look the reader up by this name
+ * @param queueSize queue size passed to the RPCContinuousShuffleReader
+ * @param numShuffleWriters number of writers feeding the reader (ContinuousCoalesceRDD passes
+ *                          the parent RDD's partition count)
+ * @param epochIntervalMs epoch interval passed to the RPCContinuousShuffleReader
+ */
+case class ContinuousCoalesceRDDPartition(
+    index: Int,
+    endpointName: String,
+    queueSize: Int,
+    numShuffleWriters: Int,
+    epochIntervalMs: Long)
+  extends Partition {
+  // Initialized only on the executor, and only once even as we call compute() multiple times.
+  // On first access this creates the reader, registers it as an RPC endpoint under
+  // `endpointName`, and arranges for that endpoint to be stopped when the task that performed
+  // the first access completes (TaskContext.get() is read at initialization time).
+  lazy val (reader: ContinuousShuffleReader, endpoint) = {
+    val env = SparkEnv.get.rpcEnv
+    val receiver = new RPCContinuousShuffleReader(
+      queueSize, numShuffleWriters, epochIntervalMs, env)
+    val endpoint = env.setupEndpoint(endpointName, receiver)
+
+    TaskContext.get().addTaskCompletionListener { ctx =>
+      env.stop(endpoint)
+    }
+    (receiver, endpoint)
+  }
+  // This flag will be flipped on the executors to indicate that the threads processing
+  // partitions of the write-side RDD have been started. These will run indefinitely
+  // asynchronously as epochs of the coalesce RDD complete on the read side.
+  private[continuous] var writersInitialized: Boolean = false
+}
+
+/**
+ * RDD for continuous coalescing. Asynchronously writes all partitions of `prev` into a local
+ * continuous shuffle, and then reads them in the task thread using the partition's `reader`.
+ *
+ * The RDD is declared with no dependencies. Background writer threads compute the parent
+ * partitions directly, inside the coalesced task, instead of separate tasks computing them.
+ * For the same reason it cannot be checkpointed.
+ *
+ * @param context the SparkContext this RDD belongs to
+ * @param numPartitions number of output partitions; only 1 is currently supported (the output
+ *                      partitioner is fixed to a single bucket)
+ * @param readerQueueSize queue size of each output partition's shuffle reader
+ * @param epochIntervalMs epoch interval forwarded to each output partition's shuffle reader
+ * @param prev the parent RDD whose partitions are all funneled into the output partition(s)
+ */
+class ContinuousCoalesceRDD(
+    context: SparkContext,
+    numPartitions: Int,
+    readerQueueSize: Int,
+    epochIntervalMs: Long,
+    prev: RDD[InternalRow])
+  extends RDD[InternalRow](context, Nil) {
+
+  // When we support more than 1 target partition, we'll need to figure out how to pass in the
+  // required partitioner.
+  private val outputPartitioner = new HashPartitioner(1)
+
+  // One reader endpoint name per output partition. The names are computed once, when the RDD
+  // is constructed, so the reader partitions and the writer threads agree on them. The random
+  // UUID keeps them unique across coalesce RDDs.
+  private val readerEndpointNames = (0 until numPartitions).map { i =>
+    s"ContinuousCoalesceRDD-part$i-${UUID.randomUUID()}"
+  }
+
+  override def getPartitions: Array[Partition] = {
+    (0 until numPartitions).map { partIndex =>
+      ContinuousCoalesceRDDPartition(
+        partIndex,
+        readerEndpointNames(partIndex),
+        readerQueueSize,
+        prev.getNumPartitions,
+        epochIntervalMs)
+    }.toArray
+  }
+
+  // One writer thread per parent partition, created lazily on first use in compute().
+  // The thread-name prefix is derived from the RDD id. `this.name` is unset unless setName()
+  // is called, and it is @transient, so it is always null on executors. That produced writer
+  // threads named "null-0", "null-1", ...
+  private lazy val threadPool = ThreadUtils.newDaemonFixedThreadPool(
+    prev.getNumPartitions,
+    s"ContinuousCoalesceRDD-$id-writer")
+
+  override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
+    val part = split.asInstanceOf[ContinuousCoalesceRDDPartition]
+
+    // compute() is called repeatedly for the same partition. The writer threads are started
+    // only on the first call and then keep running across epochs.
+    if (!part.writersInitialized) {
+      val rpcEnv = SparkEnv.get.rpcEnv
+
+      // Trigger lazy initialization, so the reader endpoint is registered locally before refs
+      // to it are created and the writers start.
+      part.endpoint
+      val endpointRefs = readerEndpointNames.map { endpointName =>
+        rpcEnv.setupEndpointRef(rpcEnv.address, endpointName)
+      }
+
+      val runnables = prev.partitions.map { prevSplit =>
+        new Runnable() {
+          override def run(): Unit = {
+            // The task context is not inherited by new threads, so install it explicitly.
+            TaskContext.setTaskContext(context)
+
+            val writer: ContinuousShuffleWriter = new RPCContinuousShuffleWriter(
+              prevSplit.index, outputPartitioner, endpointRefs.toArray)
+
+            EpochTracker.initializeCurrentEpoch(
+              context.getLocalProperty(ContinuousExecution.START_EPOCH_KEY).toLong)
+            // Each iteration writes the parent partition's output for the current epoch, then
+            // advances this thread's epoch.
+            while (!context.isInterrupted() && !context.isCompleted()) {
+              writer.write(prev.compute(prevSplit, context).asInstanceOf[Iterator[UnsafeRow]])
+              // Note that current epoch is a non-inheritable thread local, so each writer thread
+              // can properly increment its own epoch without affecting the main task thread.
+              EpochTracker.incrementCurrentEpoch()
+            }
+          }
+        }
+      }
+
+      // Interrupt and stop the writer threads once the coalesced task finishes.
+      context.addTaskCompletionListener { ctx =>
+        threadPool.shutdownNow()
+      }
+
+      part.writersInitialized = true
+
+      runnables.foreach(threadPool.execute)
+    }
+
+    part.reader.read()
+  }
+
+  override def clearDependencies(): Unit = {
+    throw new IllegalStateException("Continuous RDDs cannot be checkpointed")
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6e6899a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
index a7ccce1..73868d5 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDD.scala
@@ -51,11 +51,11 @@ class ContinuousDataSourceRDD(
     sc: SparkContext,
     dataQueueSize: Int,
     epochPollIntervalMs: Long,
-    @transient private val readerFactories: Seq[InputPartition[UnsafeRow]])
+    private val readerInputPartitions: Seq[InputPartition[UnsafeRow]])
   extends RDD[UnsafeRow](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
-    readerFactories.zipWithIndex.map {
+    readerInputPartitions.zipWithIndex.map {
       case (inputPartition, index) => new 
ContinuousDataSourceRDDPartition(index, inputPartition)
     }.toArray
   }
@@ -74,8 +74,7 @@ class ContinuousDataSourceRDD(
       val partition = split.asInstanceOf[ContinuousDataSourceRDDPartition]
       if (partition.queueReader == null) {
         partition.queueReader =
-          new ContinuousQueuedDataReader(
-            partition.inputPartition, context, dataQueueSize, 
epochPollIntervalMs)
+          new ContinuousQueuedDataReader(partition, context, dataQueueSize, 
epochPollIntervalMs)
       }
 
       partition.queueReader

http://git-wip-us.apache.org/repos/asf/spark/blob/f6e6899a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index e3d0cea..a0bb829 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -216,6 +216,9 @@ class ContinuousExecution(
     currentEpochCoordinatorId = epochCoordinatorId
     sparkSessionForQuery.sparkContext.setLocalProperty(
       ContinuousExecution.EPOCH_COORDINATOR_ID_KEY, epochCoordinatorId)
+    sparkSessionForQuery.sparkContext.setLocalProperty(
+      ContinuousExecution.EPOCH_INTERVAL_KEY,
+      trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString)
 
     // Use the parent Spark session for the endpoint since it's where this 
query ID is registered.
     val epochEndpoint =
@@ -382,4 +385,5 @@ class ContinuousExecution(
 object ContinuousExecution {
+  // Keys under which ContinuousExecution publishes per-query values as SparkContext local
+  // properties. Operators read them back on the driver (SparkContext.getLocalProperty) or in
+  // tasks (TaskContext.getLocalProperty).
   val START_EPOCH_KEY = "__continuous_start_epoch"
   val EPOCH_COORDINATOR_ID_KEY = "__epoch_coordinator_id"
+  // The ContinuousTrigger interval in milliseconds, stored as a decimal string.
+  // ContinuousCoalesceExec parses it with toLong.
+  val EPOCH_INTERVAL_KEY = "__continuous_epoch_interval"
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f6e6899a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
index f38577b..8c74b82 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousQueuedDataReader.scala
@@ -37,11 +37,11 @@ import org.apache.spark.util.ThreadUtils
  * offsets across epochs. Each compute() should call the next() method here 
until null is returned.
  */
 class ContinuousQueuedDataReader(
-    partition: InputPartition[UnsafeRow],
+    partition: ContinuousDataSourceRDDPartition,
     context: TaskContext,
     dataQueueSize: Int,
     epochPollIntervalMs: Long) extends Closeable {
-  private val reader = partition.createPartitionReader()
+  private val reader = partition.inputPartition.createPartitionReader()
 
   // Important sequencing - we must get our starting point before the provider 
threads start running
   private var currentOffset: PartitionOffset =
@@ -113,7 +113,7 @@ class ContinuousQueuedDataReader(
     currentEntry match {
       case EpochMarker =>
         epochCoordEndpoint.send(ReportPartitionOffset(
-          context.partitionId(), EpochTracker.getCurrentEpoch.get, 
currentOffset))
+          partition.index, EpochTracker.getCurrentEpoch.get, currentOffset))
         null
       case ContinuousRow(row, offset) =>
         currentOffset = offset

http://git-wip-us.apache.org/repos/asf/spark/blob/f6e6899a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
index cf6572d..518223f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/ContinuousShuffleReadRDD.scala
@@ -21,12 +21,14 @@ import java.util.UUID
 
 import org.apache.spark.{Partition, SparkContext, SparkEnv, TaskContext}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.rpc.RpcAddress
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.NextIterator
 
 case class ContinuousShuffleReadPartition(
       index: Int,
+      endpointName: String,
       queueSize: Int,
       numShuffleWriters: Int,
       epochIntervalMs: Long)
@@ -36,7 +38,7 @@ case class ContinuousShuffleReadPartition(
     val env = SparkEnv.get.rpcEnv
     val receiver = new RPCContinuousShuffleReader(
       queueSize, numShuffleWriters, epochIntervalMs, env)
-    val endpoint = 
env.setupEndpoint(s"RPCContinuousShuffleReader-${UUID.randomUUID()}", receiver)
+    val endpoint = env.setupEndpoint(endpointName, receiver)
 
     TaskContext.get().addTaskCompletionListener { ctx =>
       env.stop(endpoint)
@@ -61,12 +63,14 @@ class ContinuousShuffleReadRDD(
     numPartitions: Int,
     queueSize: Int = 1024,
     numShuffleWriters: Int = 1,
-    epochIntervalMs: Long = 1000)
+    epochIntervalMs: Long = 1000,
+    val endpointNames: Seq[String] = 
Seq(s"RPCContinuousShuffleReader-${UUID.randomUUID()}"))
   extends RDD[UnsafeRow](sc, Nil) {
 
   override protected def getPartitions: Array[Partition] = {
     (0 until numPartitions).map { partIndex =>
-      ContinuousShuffleReadPartition(partIndex, queueSize, numShuffleWriters, 
epochIntervalMs)
+      ContinuousShuffleReadPartition(
+        partIndex, endpointNames(partIndex), queueSize, numShuffleWriters, 
epochIntervalMs)
     }.toArray
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f6e6899a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
index 834e846..502ae0d 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/shuffle/RPCContinuousShuffleReader.scala
@@ -46,7 +46,7 @@ private[shuffle] case class ReceiverEpochMarker(writerId: 
Int) extends RPCContin
  * TODO: Support multiple source tasks. We need to output a single epoch 
marker once all
  * source tasks have sent one.
  */
-private[shuffle] class RPCContinuousShuffleReader(
+private[continuous] class RPCContinuousShuffleReader(
       queueSize: Int,
       numShuffleWriters: Int,
       epochIntervalMs: Long,
@@ -107,7 +107,7 @@ private[shuffle] class RPCContinuousShuffleReader(
               }
               logWarning(
                 s"Completion service failed to make progress after 
$epochIntervalMs ms. Waiting " +
-                  s"for writers $writerIdsUncommitted to send epoch markers.")
+                  s"for writers ${writerIdsUncommitted.mkString(",")} to send 
epoch markers.")
 
             // The completion service guarantees this future will be available 
immediately.
             case future => future.get() match {

http://git-wip-us.apache.org/repos/asf/spark/blob/f6e6899a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
index d1c3498..0bf90b8 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala
@@ -23,12 +23,13 @@ import java.util.concurrent.atomic.AtomicInteger
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.JavaConverters._
+import scala.collection.SortedMap
 import scala.collection.mutable.ListBuffer
 
 import org.json4s.NoTypeHints
 import org.json4s.jackson.Serialization
 
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkEnv, TaskContext}
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, 
ThreadSafeRpcEndpoint}
 import org.apache.spark.sql.{Encoder, Row, SQLContext}
 import org.apache.spark.sql.execution.streaming._
@@ -184,6 +185,14 @@ class ContinuousMemoryStreamInputPartitionReader(
   private var currentOffset = startOffset
   private var current: Option[Row] = None
 
+  // Defense-in-depth against failing to propagate the task context. Since 
it's not inheritable,
+  // we have to do a bit of error prone work to get it into every thread used 
by continuous
+  // processing. We hope that some unit test will end up instantiating a 
continuous memory stream
+  // in such cases.
+  if (TaskContext.get() == null) {
+    throw new IllegalStateException("Task context was not set!")
+  }
+
   override def next(): Boolean = {
     current = getRecord
     while (current.isEmpty) {

http://git-wip-us.apache.org/repos/asf/spark/blob/f6e6899a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala
index b7ef637..0223812 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousAggregationSuite.scala
@@ -31,7 +31,8 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase {
       testStream(input.toDF().agg(max('value)), OutputMode.Complete)()
     }
 
-    assert(ex.getMessage.contains("Continuous processing does not support 
Aggregate operations"))
+    assert(ex.getMessage.contains(
+      "In continuous processing mode, coalesce(1) must be called before 
aggregate operation"))
   }
 
   test("basic") {
@@ -50,6 +51,66 @@ class ContinuousAggregationSuite extends ContinuousSuiteBase 
{
     }
   }
 
+  // An aggregate over a multi-partition continuous source is allowed once coalesce(1) is
+  // placed below it. Checks the running max across a stop/restart and across later epochs.
+  test("multiple partitions with coalesce") {
+    // Unlike ContinuousMemoryStream.singlePartition (used by "repeated restart"), the default
+    // stream presumably has more than one partition -- TODO confirm the default count.
+    val input = ContinuousMemoryStream[Int]
+
+    val df = input.toDF().coalesce(1).agg(max('value))
+
+    testStream(df, OutputMode.Complete)(
+      AddData(input, 0, 1, 2),
+      CheckAnswer(2),
+      // Data added while the query is stopped is processed after the restart.
+      StopStream,
+      AddData(input, 3, 4, 5),
+      StartStream(),
+      CheckAnswer(5),
+      // Smaller values must not lower the result: Complete mode reports the max over all input.
+      AddData(input, -1, -2, -3),
+      CheckAnswer(5))
+  }
+
+  // Projections and filters on both sides of the coalesce must still work in continuous mode.
+  test("multiple partitions with coalesce - multiple transformations") {
+    val input = ContinuousMemoryStream[Int]
+
+    // We use a barrier to make sure predicates both before and after coalesce work.
+    // The predicate below the coalesce drops 1; the one above it drops 2.
+    val df = input.toDF()
+      .select('value as 'copy, 'value)
+      .where('copy =!= 1)
+      .planWithBarrier
+      .coalesce(1)
+      .where('copy =!= 2)
+      .agg(max('value))
+
+    testStream(df, OutputMode.Complete)(
+      // 1 and 2 are filtered out, so only 0 reaches the aggregate.
+      AddData(input, 0, 1, 2),
+      CheckAnswer(0),
+      StopStream,
+      AddData(input, 3, 4, 5),
+      StartStream(),
+      CheckAnswer(5),
+      AddData(input, -1, -2, -3),
+      CheckAnswer(5))
+  }
+
+  // Two stacked coalesce(1) operators below an aggregate must also be accepted and executed.
+  test("multiple partitions with multiple coalesce") {
+    val input = ContinuousMemoryStream[Int]
+
+    // The barrier between the two coalesce calls is presumably there to keep them from being
+    // collapsed into a single operator -- TODO confirm.
+    val df = input.toDF()
+      .coalesce(1)
+      .planWithBarrier
+      .coalesce(1)
+      .select('value as 'copy, 'value)
+      .agg(max('value))
+
+    testStream(df, OutputMode.Complete)(
+      AddData(input, 0, 1, 2),
+      CheckAnswer(2),
+      StopStream,
+      AddData(input, 3, 4, 5),
+      StartStream(),
+      CheckAnswer(5),
+      AddData(input, -1, -2, -3),
+      CheckAnswer(5))
+  }
+
   test("repeated restart") {
     withSQLConf(("spark.sql.streaming.unsupportedOperationCheck", "false")) {
       val input = ContinuousMemoryStream.singlePartition[Int]

http://git-wip-us.apache.org/repos/asf/spark/blob/f6e6899a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
index e663fa8..0e7e6fe 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala
@@ -92,7 +92,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with 
MockitoSugar {
       }
     }
     val reader = new ContinuousQueuedDataReader(
-      factory,
+      new ContinuousDataSourceRDDPartition(0, factory),
       mockContext,
       dataQueueSize = sqlContext.conf.continuousStreamingExecutorQueueSize,
       epochPollIntervalMs = 
sqlContext.conf.continuousStreamingExecutorPollIntervalMs)

http://git-wip-us.apache.org/repos/asf/spark/blob/f6e6899a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala
index a8e3611..f84f3d4 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/shuffle/ContinuousShuffleSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.streaming.continuous.shuffle
 
+import java.util.UUID
+
 import org.apache.spark.{HashPartitioner, Partition, TaskContext, 
TaskContextImpl}
 import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeProjection, UnsafeRow}
@@ -124,7 +126,10 @@ class ContinuousShuffleSuite extends StreamTest {
   }
 
   test("reader - multiple partitions") {
-    val rdd = new ContinuousShuffleReadRDD(sparkContext, numPartitions = 5)
+    val rdd = new ContinuousShuffleReadRDD(
+      sparkContext,
+      numPartitions = 5,
+      endpointNames = Seq.fill(5)(s"endpt-${UUID.randomUUID()}"))
     // Send all data before processing to ensure there's no crossover.
     for (p <- rdd.partitions) {
       val part = p.asInstanceOf[ContinuousShuffleReadPartition]


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to