Repository: spark
Updated Branches:
  refs/heads/branch-2.3 b94debd2b -> f891ee324


http://git-wip-us.apache.org/repos/asf/spark/blob/f891ee32/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 3304f36..97f12ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -255,17 +255,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
             }
           }
 
-        case _ => throw new AnalysisException(s"$cls does not support data writing.")
+        // Streaming also uses the data source V2 API. So it may be that the data source implements
+        // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving
+        // as though it's a V1 source.
+        case _ => saveToV1Source()
       }
     } else {
-      // Code path for data source v1.
-      runCommand(df.sparkSession, "save") {
-        DataSource(
-          sparkSession = df.sparkSession,
-          className = source,
-          partitionColumns = partitioningColumns.getOrElse(Nil),
-          options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
-      }
+      saveToV1Source()
+    }
+  }
+
+  private def saveToV1Source(): Unit = {
+    // Code path for data source v1.
+    runCommand(df.sparkSession, "save") {
+      DataSource(
+        sparkSession = df.sparkSession,
+        className = source,
+        partitionColumns = partitioningColumns.getOrElse(Nil),
+        options = extraOptions.toMap).planForWriting(mode, AnalysisBarrier(df.logicalPlan))
     }
   }
 

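The comment in the hunk above describes the new fallback: a class can implement the data source V2 API for streaming while providing no V2 batch writer, so save() routes such sources to the V1 path instead of failing. A minimal, self-contained sketch of that dispatch pattern follows; the trait, object and method names are illustrative stand-ins, not Spark's actual API.

    // Illustrative only: a source may implement the V2 API for streaming but offer no
    // batch writer, in which case the batch save path falls back to the V1 code path.
    trait BatchWriteSupport { def writeBatch(rows: Seq[String]): Unit }

    object FallbackSketch {
      def save(source: AnyRef, rows: Seq[String])(saveToV1Source: Seq[String] => Unit): Unit =
        source match {
          case w: BatchWriteSupport => w.writeBatch(rows)   // V2 batch write path
          case _                    => saveToV1Source(rows) // no V2 batch support: fall back to V1
        }
    }
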
http://git-wip-us.apache.org/repos/asf/spark/blob/f891ee32/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
index f0bdf84..a4a857f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala
@@ -81,9 +81,11 @@ case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan)
         (index, message: WriterCommitMessage) => messages(index) = message
       )
 
-      logInfo(s"Data source writer $writer is committing.")
-      writer.commit(messages)
-      logInfo(s"Data source writer $writer committed.")
+      if (!writer.isInstanceOf[ContinuousWriter]) {
+        logInfo(s"Data source writer $writer is committing.")
+        writer.commit(messages)
+        logInfo(s"Data source writer $writer committed.")
+      }
     } catch {
       case _: InterruptedException if writer.isInstanceOf[ContinuousWriter] =>
        // Interruption is how continuous queries are ended, so accept and ignore the exception.

http://git-wip-us.apache.org/repos/asf/spark/blob/f891ee32/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 24a8b00..cf27e1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -142,7 +142,8 @@ abstract class StreamExecution(
 
   override val id: UUID = UUID.fromString(streamMetadata.id)
 
-  override val runId: UUID = UUID.randomUUID
+  override def runId: UUID = currentRunId
+  protected var currentRunId = UUID.randomUUID
 
   /**
    * Pretty identified string of printing in logs. Format is
@@ -418,11 +419,17 @@ abstract class StreamExecution(
    * Blocks the current thread until processing for data from the given `source` has reached at
    * least the given `Offset`. This method is intended for use primarily when writing tests.
    */
-  private[sql] def awaitOffset(source: BaseStreamingSource, newOffset: Offset): Unit = {
+  private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset): Unit = {
     assertAwaitThread()
     def notDone = {
       val localCommittedOffsets = committedOffsets
-      !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset
+      if (sources == null) {
+        // sources might not be initialized yet
+        false
+      } else {
+        val source = sources(sourceIndex)
+        !localCommittedOffsets.contains(source) || localCommittedOffsets(source) != newOffset
+      }
     }
 
     while (notDone) {
@@ -436,7 +443,7 @@ abstract class StreamExecution(
         awaitProgressLock.unlock()
       }
     }
-    logDebug(s"Unblocked at $newOffset for $source")
+    logDebug(s"Unblocked at $newOffset for ${sources(sourceIndex)}")
   }
 
   /** A flag to indicate that a batch has completed with no new data available. */

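The reworked awaitOffset above takes a source index instead of a source object, tolerates a not-yet-initialized sources list, and keeps waiting on the progress lock until the committed offset catches up. Below is a rough standalone sketch of that wait-and-signal pattern; the names and types are simplified stand-ins, not the StreamExecution internals.

    import java.util.concurrent.TimeUnit
    import java.util.concurrent.locks.ReentrantLock

    // Illustrative wait-until-offset pattern: re-check a predicate under a lock and park on a
    // Condition until a committing thread signals progress.
    object AwaitOffsetSketch {
      private val lock = new ReentrantLock()
      private val progress = lock.newCondition()
      @volatile private var committed: Map[Int, Long] = Map.empty

      def awaitOffset(sourceIndex: Int, target: Long): Unit = {
        def notDone = !committed.get(sourceIndex).contains(target)
        lock.lock()
        try {
          while (notDone) progress.await(100, TimeUnit.MILLISECONDS)
        } finally lock.unlock()
      }

      def commit(sourceIndex: Int, offset: Long): Unit = {
        lock.lock()
        try {
          committed += sourceIndex -> offset
          progress.signalAll()
        } finally lock.unlock()
      }
    }
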
http://git-wip-us.apache.org/repos/asf/spark/blob/f891ee32/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
index d79e4bd..e700aa4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousDataSourceRDDIter.scala
@@ -77,7 +77,6 @@ class ContinuousDataSourceRDD(
     dataReaderThread.start()
 
     context.addTaskCompletionListener(_ => {
-      reader.close()
       dataReaderThread.interrupt()
       epochPollExecutor.shutdown()
     })
@@ -201,6 +200,8 @@ class DataReaderThread(
         failedFlag.set(true)
        // Don't rethrow the exception in this thread. It's not needed, and the default Spark
        // exception handler will kill the executor.
+    } finally {
+      reader.close()
     }
   }
 }

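The change above moves reader.close() from the task completion listener into the reader thread's finally block, so the thread that consumes the reader is also the one that releases it, whether it finishes, fails, or is interrupted. A generic sketch of that cleanup pattern in plain Scala, with no Spark types; the class and parameter names are hypothetical.

    // Illustrative only: close the resource in finally on the thread that uses it, so cleanup
    // runs exactly once regardless of how the read loop ends.
    class ReaderThreadSketch(reader: java.io.Closeable, readOne: () => Boolean) extends Thread {
      override def run(): Unit = {
        try {
          while (!Thread.currentThread().isInterrupted && readOne()) {}
        } catch {
          case _: InterruptedException => // interruption is the expected shutdown signal here
        } finally {
          reader.close()
        }
      }
    }
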
http://git-wip-us.apache.org/repos/asf/spark/blob/f891ee32/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index 9657b5e..667410e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.execution.streaming.continuous
 
+import java.util.UUID
 import java.util.concurrent.TimeUnit
+import java.util.function.UnaryOperator
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, Map => MutableMap}
@@ -52,7 +54,7 @@ class ContinuousExecution(
     sparkSession, name, checkpointRoot, analyzedPlan, sink,
     trigger, triggerClock, outputMode, deleteCheckpointOnStop) {
 
-  @volatile protected var continuousSources: Seq[ContinuousReader] = Seq.empty
+  @volatile protected var continuousSources: Seq[ContinuousReader] = _
   override protected def sources: Seq[BaseStreamingSource] = continuousSources
 
   override lazy val logicalPlan: LogicalPlan = {
@@ -78,15 +80,17 @@ class ContinuousExecution(
   }
 
   override protected def runActivatedStream(sparkSessionForStream: SparkSession): Unit = {
-    do {
-      try {
-        runContinuous(sparkSessionForStream)
-      } catch {
-        case _: InterruptedException if state.get().equals(RECONFIGURING) =>
-          // swallow exception and run again
-          state.set(ACTIVE)
+    val stateUpdate = new UnaryOperator[State] {
+      override def apply(s: State) = s match {
+        // If we ended the query to reconfigure, reset the state to active.
+        case RECONFIGURING => ACTIVE
+        case _ => s
       }
-    } while (state.get() == ACTIVE)
+    }
+
+    do {
+      runContinuous(sparkSessionForStream)
+    } while (state.updateAndGet(stateUpdate) == ACTIVE)
   }
 
   /**
@@ -120,12 +124,16 @@ class ContinuousExecution(
         }
         committedOffsets = nextOffsets.toStreamProgress(sources)
 
-        // Forcibly align commit and offset logs by slicing off any spurious offset logs from
-        // a previous run. We can't allow commits to an epoch that a previous run reached but
-        // this run has not.
-        offsetLog.purgeAfter(latestEpochId)
+        // Get to an epoch ID that has definitely never been sent to a sink before. Since sink
+        // commit happens between offset log write and commit log write, this means an epoch ID
+        // which is not in the offset log.
+        val (latestOffsetEpoch, _) = offsetLog.getLatest().getOrElse {
+          throw new IllegalStateException(
+            s"Offset log had no latest element. This shouldn't be possible because nextOffsets is" +
+              s"an element.")
+        }
+        currentBatchId = latestOffsetEpoch + 1
 
-        currentBatchId = latestEpochId + 1
         logDebug(s"Resuming at epoch $currentBatchId with committed offsets $committedOffsets")
         nextOffsets
       case None =>
@@ -141,6 +149,7 @@ class ContinuousExecution(
    * @param sparkSessionForQuery Isolated [[SparkSession]] to run the continuous query with.
    */
   private def runContinuous(sparkSessionForQuery: SparkSession): Unit = {
+    currentRunId = UUID.randomUUID
     // A list of attributes that will need to be updated.
     val replacements = new ArrayBuffer[(Attribute, Attribute)]
     // Translate from continuous relation to the underlying data source.
@@ -225,13 +234,11 @@ class ContinuousExecution(
           triggerExecutor.execute(() => {
             startTrigger()
 
-            if (reader.needsReconfiguration()) {
-              state.set(RECONFIGURING)
+            if (reader.needsReconfiguration() && state.compareAndSet(ACTIVE, RECONFIGURING)) {
               stopSources()
               if (queryExecutionThread.isAlive) {
                 sparkSession.sparkContext.cancelJobGroup(runId.toString)
                 queryExecutionThread.interrupt()
-                // No need to join - this thread is about to end anyway.
               }
               false
             } else if (isActive) {
@@ -259,6 +266,7 @@ class ContinuousExecution(
           sparkSessionForQuery, lastExecution)(lastExecution.toRdd)
       }
     } finally {
+      epochEndpoint.askSync[Unit](StopContinuousExecutionWrites)
       SparkEnv.get.rpcEnv.stop(epochEndpoint)
 
       epochUpdateThread.interrupt()
@@ -273,17 +281,22 @@ class ContinuousExecution(
      epoch: Long, reader: ContinuousReader, partitionOffsets: Seq[PartitionOffset]): Unit = {
    assert(continuousSources.length == 1, "only one continuous source supported currently")
 
-    if (partitionOffsets.contains(null)) {
-      // If any offset is null, that means the corresponding partition hasn't seen any data yet, so
-      // there's nothing meaningful to add to the offset log.
-    }
     val globalOffset = reader.mergeOffsets(partitionOffsets.toArray)
-    synchronized {
-      if (queryExecutionThread.isAlive) {
-        offsetLog.add(epoch, OffsetSeq.fill(globalOffset))
-      } else {
-        return
-      }
+    val oldOffset = synchronized {
+      offsetLog.add(epoch, OffsetSeq.fill(globalOffset))
+      offsetLog.get(epoch - 1)
+    }
+
+    // If offset hasn't changed since last epoch, there's been no new data.
+    if (oldOffset.contains(OffsetSeq.fill(globalOffset))) {
+      noNewData = true
+    }
+
+    awaitProgressLock.lock()
+    try {
+      awaitProgressLockCondition.signalAll()
+    } finally {
+      awaitProgressLock.unlock()
     }
   }
 

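The restart loop above replaces a set()/swallow-exception sequence with a single atomic transition: state.updateAndGet maps RECONFIGURING back to ACTIVE and leaves any other state untouched, so the query keeps looping only when a run ended for reconfiguration. A toy model of that pattern follows; the State values and object name are stand-ins.

    import java.util.concurrent.atomic.AtomicReference
    import java.util.function.UnaryOperator

    // Toy model of the atomic state transition: a run that stopped to reconfigure flips back
    // to ACTIVE and loops again; any other state ends the loop.
    object RestartLoopSketch {
      sealed trait State
      case object ACTIVE extends State
      case object RECONFIGURING extends State
      case object TERMINATED extends State

      val state = new AtomicReference[State](ACTIVE)

      private val stateUpdate = new UnaryOperator[State] {
        override def apply(s: State): State = s match {
          case RECONFIGURING => ACTIVE // reconfiguration ends one run, not the query
          case other => other
        }
      }

      def runLoop(runOnce: () => Unit): Unit = {
        do {
          runOnce()
        } while (state.updateAndGet(stateUpdate) == ACTIVE)
      }
    }
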
http://git-wip-us.apache.org/repos/asf/spark/blob/f891ee32/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
index 98017c3..40dcbec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala
@@ -39,6 +39,15 @@ private[continuous] sealed trait EpochCoordinatorMessage extends Serializable
  */
 private[sql] case object IncrementAndGetEpoch extends EpochCoordinatorMessage
 
+/**
+ * The RpcEndpoint stop() will wait to clear out the message queue before terminating the
+ * object. This can lead to a race condition where the query restarts at epoch n, a new
+ * EpochCoordinator starts at epoch n, and then the old epoch coordinator commits epoch n + 1.
+ * The framework doesn't provide a handle to wait on the message queue, so we use a synchronous
+ * message to stop any writes to the ContinuousExecution object.
+ */
+private[sql] case object StopContinuousExecutionWrites extends EpochCoordinatorMessage
+
 // Init messages
 /**
  * Set the reader and writer partition counts. Tasks may not be started until the coordinator
@@ -116,6 +125,8 @@ private[continuous] class EpochCoordinator(
     override val rpcEnv: RpcEnv)
   extends ThreadSafeRpcEndpoint with Logging {
 
+  private var queryWritesStopped: Boolean = false
+
   private var numReaderPartitions: Int = _
   private var numWriterPartitions: Int = _
 
@@ -147,12 +158,16 @@ private[continuous] class EpochCoordinator(
         partitionCommits.remove(k)
       }
       for (k <- partitionOffsets.keys.filter { case (e, _) => e < epoch }) {
-        partitionCommits.remove(k)
+        partitionOffsets.remove(k)
       }
     }
   }
 
   override def receive: PartialFunction[Any, Unit] = {
+    // If we just drop these messages, we won't do any writes to the query. The lame duck tasks
+    // won't shed errors or anything.
+    case _ if queryWritesStopped => ()
+
     case CommitPartitionEpoch(partitionId, epoch, message) =>
       logDebug(s"Got commit from partition $partitionId at epoch $epoch: $message")
       if (!partitionCommits.isDefinedAt((epoch, partitionId))) {
@@ -188,5 +203,9 @@ private[continuous] class EpochCoordinator(
     case SetWriterPartitions(numPartitions) =>
       numWriterPartitions = numPartitions
       context.reply(())
+
+    case StopContinuousExecutionWrites =>
+      queryWritesStopped = true
+      context.reply(())
   }
 }

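The StopContinuousExecutionWrites handling above amounts to a guard flag that is set by a synchronous message and then turns every later message into a no-op, so a lame-duck coordinator from a previous run can no longer write to the query. A stripped-down model of that idea without Spark's RPC layer; all names here are illustrative.

    // Once StopWrites has been handled (synchronously, via askSync in the real code), every
    // later message is dropped instead of mutating state.
    sealed trait CoordinatorMessage
    case class CommitEpoch(epoch: Long) extends CoordinatorMessage
    case object StopWrites extends CoordinatorMessage

    class EpochCoordinatorSketch {
      private var writesStopped = false
      private var committedEpochs = Vector.empty[Long]

      def handle(msg: CoordinatorMessage): Unit = synchronized {
        msg match {
          case _ if writesStopped => ()                   // drop messages from lame-duck tasks
          case StopWrites         => writesStopped = true // acknowledged before the endpoint stops
          case CommitEpoch(epoch) => committedEpochs :+= epoch
        }
      }
    }
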
http://git-wip-us.apache.org/repos/asf/spark/blob/f891ee32/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index db588ae..b5b4a05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger
 import org.apache.spark.sql.execution.streaming.sources.{MemoryPlanV2, MemorySinkV2}
+import org.apache.spark.sql.sources.v2.streaming.ContinuousWriteSupport
 
 /**
  * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems,
@@ -279,18 +280,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
         useTempCheckpointLocation = true,
         trigger = trigger)
     } else {
-      val dataSource =
-        DataSource(
-          df.sparkSession,
-          className = source,
-          options = extraOptions.toMap,
-          partitionColumns = normalizedParCols.getOrElse(Nil))
+      val sink = trigger match {
+        case _: ContinuousTrigger =>
+          val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
+          ds.newInstance() match {
+            case w: ContinuousWriteSupport => w
+            case _ => throw new AnalysisException(
+              s"Data source $source does not support continuous writing")
+          }
+        case _ =>
+          val ds = DataSource(
+            df.sparkSession,
+            className = source,
+            options = extraOptions.toMap,
+            partitionColumns = normalizedParCols.getOrElse(Nil))
+          ds.createSink(outputMode)
+      }
+
       df.sparkSession.sessionState.streamingQueryManager.startQuery(
         extraOptions.get("queryName"),
         extraOptions.get("checkpointLocation"),
         df,
         extraOptions.toMap,
-        dataSource.createSink(outputMode),
+        sink,
         outputMode,
         useTempCheckpointLocation = source == "console",
         recoverFromCheckpointLocation = true,

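The DataStreamWriter change above picks the sink from the trigger: a ContinuousTrigger requires a source that implements ContinuousWriteSupport, while any other trigger keeps the existing V1 createSink path. A condensed sketch of that dispatch; the trigger and trait definitions below are simplified stand-ins, not Spark's classes.

    object SinkSelectionSketch {
      sealed trait Trigger
      case class ProcessingTime(intervalMs: Long) extends Trigger
      case class ContinuousTrigger(checkpointIntervalMs: Long) extends Trigger

      trait ContinuousWriteSupport

      def chooseSink(trigger: Trigger, source: AnyRef, v1Sink: => AnyRef): AnyRef = trigger match {
        case _: ContinuousTrigger =>
          source match {
            case w: ContinuousWriteSupport => w // continuous mode requires a v2 continuous writer
            case _ => throw new IllegalArgumentException(
              "Data source does not support continuous writing")
          }
        case _ => v1Sink // other triggers keep the existing V1 sink path
      }
    }
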
http://git-wip-us.apache.org/repos/asf/spark/blob/f891ee32/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index d46461f..0762895 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -38,8 +38,9 @@ import org.apache.spark.sql.{Dataset, Encoder, QueryTest, Row}
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.execution.streaming._
-import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch}
+import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, ContinuousTrigger, EpochCoordinatorRef, IncrementAndGetEpoch}
 import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2
 import org.apache.spark.sql.execution.streaming.state.StateStore
 import org.apache.spark.sql.streaming.StreamingQueryListener._
@@ -80,6 +81,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     StateStore.stop() // stop the state store maintenance thread and unload store providers
   }
 
+  protected val defaultTrigger = Trigger.ProcessingTime(0)
+  protected val defaultUseV2Sink = false
+
   /** How long to wait for an active stream to catch up when checking a result. */
   val streamingTimeout = 10.seconds
 
@@ -189,7 +193,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
 
   /** Starts the stream, resuming if data has already been processed. It must not be running. */
   case class StartStream(
-      trigger: Trigger = Trigger.ProcessingTime(0),
+      trigger: Trigger = defaultTrigger,
       triggerClock: Clock = new SystemClock,
       additionalConfs: Map[String, String] = Map.empty,
       checkpointLocation: String = null)
@@ -276,7 +280,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
   def testStream(
       _stream: Dataset[_],
       outputMode: OutputMode = OutputMode.Append,
-      useV2Sink: Boolean = false)(actions: StreamAction*): Unit = synchronized {
+      useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized {
     import org.apache.spark.sql.streaming.util.StreamManualClock
 
     // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently
@@ -403,18 +407,11 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
 
     def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = {
       verify(currentStream != null, "stream not running")
-      // Get the map of source index to the current source objects
-      val indexToSource = currentStream
-        .logicalPlan
-        .collect { case StreamingExecutionRelation(s, _) => s }
-        .zipWithIndex
-        .map(_.swap)
-        .toMap
 
       // Block until all data added has been processed for all the source
       awaiting.foreach { case (sourceIndex, offset) =>
         failAfter(streamingTimeout) {
-          currentStream.awaitOffset(indexToSource(sourceIndex), offset)
+          currentStream.awaitOffset(sourceIndex, offset)
         }
       }
 
@@ -473,6 +470,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
             // after starting the query.
             try {
               currentStream.awaitInitialization(streamingTimeout.toMillis)
+              currentStream match {
+                case s: ContinuousExecution => eventually("IncrementalExecution was not created") {
+                    s.lastExecution.executedPlan // will fail if lastExecution is null
+                  }
+                case _ =>
+              }
             } catch {
               case _: StreamingQueryException =>
                // Ignore the exception. `StopStream` or `ExpectFailure` will catch it as well.
@@ -600,7 +603,10 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
 
               def findSourceIndex(plan: LogicalPlan): Option[Int] = {
                 plan
-                  .collect { case StreamingExecutionRelation(s, _) => s }
+                  .collect {
+                    case StreamingExecutionRelation(s, _) => s
+                    case DataSourceV2Relation(_, r) => r
+                  }
                   .zipWithIndex
                   .find(_._1 == source)
                   .map(_._2)
@@ -613,9 +619,13 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
                   findSourceIndex(query.logicalPlan)
                 }.orElse {
                   findSourceIndex(stream.logicalPlan)
+                }.orElse {
+                  queryToUse.flatMap { q =>
+                    findSourceIndex(q.lastExecution.logical)
+                  }
                 }.getOrElse {
                   throw new IllegalArgumentException(
-                    "Could find index of the source to which data was added")
+                    "Could not find index of the source to which data was added")
                 }
 
               // Store the expected offset of added data to wait for it later

