Repository: spark
Updated Branches:
  refs/heads/master f77f11c67 -> 463bac001


[SPARK-14257][SQL] Allow multiple continuous queries to be started from the same DataFrame

## What changes were proposed in this pull request?

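Make StreamingRelation store the DataSource (a factory for the streaming Source)
instead of an already-created Source, and create a fresh Source each time a
continuous query is started, so that we can start multiple continuous queries
from the same DataFrame.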

## How was this patch tested?

`test("DataFrame reuse")`

Author: Shixiong Zhu <shixi...@databricks.com>

Closes #12049 from zsxwing/df-reuse.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/463bac00
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/463bac00
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/463bac00

Branch: refs/heads/master
Commit: 463bac001171622538fc93d2e31d1a617ab562e6
Parents: f77f11c
Author: Shixiong Zhu <shixi...@databricks.com>
Authored: Tue Apr 5 11:12:05 2016 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Tue Apr 5 11:12:05 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/ContinuousQueryManager.scala      | 12 +++-
 .../org/apache/spark/sql/DataFrameReader.scala  |  2 +-
 .../scala/org/apache/spark/sql/Dataset.scala    |  6 +-
 .../execution/streaming/StreamExecution.scala   |  8 +--
 .../execution/streaming/StreamingRelation.scala | 27 ++++++++-
 .../spark/sql/execution/streaming/memory.scala  |  6 +-
 .../scala/org/apache/spark/sql/StreamTest.scala |  5 +-
 .../streaming/ContinuousQueryManagerSuite.scala |  6 +-
 .../sql/streaming/FileStreamSourceSuite.scala   | 10 ++--
 .../spark/sql/streaming/StreamSuite.scala       | 62 +++++++++++++++++++-
 10 files changed, 118 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/463bac00/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
index 2306df0..d7f71bd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
 import scala.collection.mutable
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.execution.streaming.{ContinuousQueryListenerBus, 
Sink, StreamExecution}
+import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef
 import org.apache.spark.sql.util.ContinuousQueryListener
 
@@ -178,11 +178,19 @@ class ContinuousQueryManager(sqlContext: SQLContext) {
         throw new IllegalArgumentException(
           s"Cannot start query with name $name as a query with that name is 
already active")
       }
+      val logicalPlan = df.logicalPlan.transform {
+        case StreamingRelation(dataSource, _, output) =>
+          // Materialize source to avoid creating it in every batch
+          val source = dataSource.createSource()
+          // We still need to use the previous `output` instead of `source.schema` because the
+          // attributes in `df.logicalPlan` already refer to the previous `output`.
+          StreamingExecutionRelation(source, output)
+      }
       val query = new StreamExecution(
         sqlContext,
         name,
         checkpointLocation,
-        df.logicalPlan,
+        logicalPlan,
         sink,
         trigger)
       query.start()

http://git-wip-us.apache.org/repos/asf/spark/blob/463bac00/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index a5a6e01..15f2344 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -176,7 +176,7 @@ class DataFrameReader private[sql](sqlContext: SQLContext) 
extends Logging {
         userSpecifiedSchema = userSpecifiedSchema,
         className = source,
         options = extraOptions.toMap)
-    Dataset.ofRows(sqlContext, StreamingRelation(dataSource.createSource()))
+    Dataset.ofRows(sqlContext, StreamingRelation(dataSource))
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/463bac00/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index db2134b..f472a50 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -47,7 +47,7 @@ import org.apache.spark.sql.execution.command.ExplainCommand
 import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, 
LogicalRelation}
 import org.apache.spark.sql.execution.datasources.json.JacksonGenerator
 import org.apache.spark.sql.execution.python.EvaluatePython
-import org.apache.spark.sql.execution.streaming.StreamingRelation
+import org.apache.spark.sql.execution.streaming.{StreamingExecutionRelation, 
StreamingRelation}
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
@@ -462,7 +462,9 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   @Experimental
-  def isStreaming: Boolean = 
logicalPlan.find(_.isInstanceOf[StreamingRelation]).isDefined
+  def isStreaming: Boolean = logicalPlan.find { n =>
+      n.isInstanceOf[StreamingRelation] || 
n.isInstanceOf[StreamingExecutionRelation]
+    }.isDefined
 
   /**
    * Displays the [[Dataset]] in a tabular form. Strings more than 20 
characters will be truncated,

http://git-wip-us.apache.org/repos/asf/spark/blob/463bac00/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 64f8069..3e4acb7 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -43,9 +43,9 @@ import org.apache.spark.util.UninterruptibleThread
  * and the results are committed transactionally to the given [[Sink]].
  */
 class StreamExecution(
-    val sqlContext: SQLContext,
+    override val sqlContext: SQLContext,
     override val name: String,
-    val checkpointRoot: String,
+    checkpointRoot: String,
     private[sql] val logicalPlan: LogicalPlan,
     val sink: Sink,
     val trigger: Trigger) extends ContinuousQuery with Logging {
@@ -72,7 +72,7 @@ class StreamExecution(
 
   /** All stream sources present the query plan. */
   private val sources =
-    logicalPlan.collect { case s: StreamingRelation => s.source }
+    logicalPlan.collect { case s: StreamingExecutionRelation => s.source }
 
   /** A list of unique sources in the query plan. */
   private val uniqueSources = sources.distinct
@@ -295,7 +295,7 @@ class StreamExecution(
     var replacements = new ArrayBuffer[(Attribute, Attribute)]
     // Replace sources in the logical plan with data that has arrived since 
the last batch.
     val withNewSources = logicalPlan transform {
-      case StreamingRelation(source, output) =>
+      case StreamingExecutionRelation(source, output) =>
         newData.get(source).map { data =>
           val newPlan = data.logicalPlan
           assert(output.size == newPlan.output.size,

http://git-wip-us.apache.org/repos/asf/spark/blob/463bac00/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
index e35c444..f951dea 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala
@@ -19,16 +19,37 @@ package org.apache.spark.sql.execution.streaming
 
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LeafNode
+import org.apache.spark.sql.execution.datasources.DataSource
 
 object StreamingRelation {
-  def apply(source: Source): StreamingRelation =
-    StreamingRelation(source, source.schema.toAttributes)
+  def apply(dataSource: DataSource): StreamingRelation = {
+    val source = dataSource.createSource()
+    StreamingRelation(dataSource, source.toString, source.schema.toAttributes)
+  }
+}
+
+/**
+ * Used to link a streaming [[DataSource]] into a
+ * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. This is only used for creating
+ * a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]].
+ * It should be used to create a [[Source]] and be converted to a [[StreamingExecutionRelation]]
+ * when passed to [[StreamExecution]] to run a query.
+ */
+case class StreamingRelation(dataSource: DataSource, sourceName: String, 
output: Seq[Attribute])
+  extends LeafNode {
+  override def toString: String = sourceName
 }
 
 /**
  * Used to link a streaming [[Source]] of data into a
  * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]].
  */
-case class StreamingRelation(source: Source, output: Seq[Attribute]) extends 
LeafNode {
+case class StreamingExecutionRelation(source: Source, output: Seq[Attribute]) 
extends LeafNode {
   override def toString: String = source.toString
 }
+
+object StreamingExecutionRelation {
+  def apply(source: Source): StreamingExecutionRelation = {
+    StreamingExecutionRelation(source, source.schema.toAttributes)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/463bac00/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 7d97f81..b652530 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -22,11 +22,9 @@ import java.util.concurrent.atomic.AtomicInteger
 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
-import org.apache.spark.SparkEnv
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext}
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, RowEncoder}
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.encoderFor
 import org.apache.spark.sql.types.StructType
 
 object MemoryStream {
@@ -45,7 +43,7 @@ object MemoryStream {
 case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
     extends Source with Logging {
   protected val encoder = encoderFor[A]
-  protected val logicalPlan = StreamingRelation(this)
+  protected val logicalPlan = StreamingExecutionRelation(this)
   protected val output = logicalPlan.output
   protected val batches = new ArrayBuffer[Dataset[A]]
 

http://git-wip-us.apache.org/repos/asf/spark/blob/463bac00/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
index 3444e56..6ccc99f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala
@@ -36,6 +36,7 @@ import org.scalatest.time.SpanSugar._
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, 
RowEncoder}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.util.Utils
 
@@ -66,9 +67,9 @@ import org.apache.spark.util.Utils
 trait StreamTest extends QueryTest with Timeouts {
 
   implicit class RichSource(s: Source) {
-    def toDF(): DataFrame = Dataset.ofRows(sqlContext, StreamingRelation(s))
+    def toDF(): DataFrame = Dataset.ofRows(sqlContext, 
StreamingExecutionRelation(s))
 
-    def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, 
StreamingRelation(s))
+    def toDS[A: Encoder](): Dataset[A] = Dataset(sqlContext, 
StreamingExecutionRelation(s))
   }
 
   /** How long to wait for an active stream to catch up when checking a 
result. */

http://git-wip-us.apache.org/repos/asf/spark/blob/463bac00/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
index 29bd3e0..33787de 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala
@@ -29,7 +29,7 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{ContinuousQuery, Dataset, StreamTest}
-import org.apache.spark.sql.execution.streaming.{MemorySink, MemoryStream, 
StreamExecution, StreamingRelation}
+import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.util.Utils
 
@@ -294,8 +294,8 @@ class ContinuousQueryManagerSuite extends StreamTest with 
SharedSQLContext with
       if (withError) {
         logDebug(s"Terminating query ${queryToStop.name} with error")
         queryToStop.asInstanceOf[StreamExecution].logicalPlan.collect {
-          case StreamingRelation(memoryStream, _) =>
-            memoryStream.asInstanceOf[MemoryStream[Int]].addData(0)
+          case StreamingExecutionRelation(source, _) =>
+            source.asInstanceOf[MemoryStream[Int]].addData(0)
         }
       } else {
         logDebug(s"Stopping query ${queryToStop.name}")

http://git-wip-us.apache.org/repos/asf/spark/blob/463bac00/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 054f5c9..09daa7f 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -71,8 +71,9 @@ class FileStreamSourceTest extends StreamTest with 
SharedSQLContext {
       }
     reader.stream(path)
       .queryExecution.analyzed
-      .collect { case StreamingRelation(s: FileStreamSource, _) => s }
-      .head
+      .collect { case StreamingRelation(dataSource, _, _) =>
+        dataSource.createSource().asInstanceOf[FileStreamSource]
+      }.head
   }
 
   val valueSchema = new StructType().add("value", StringType)
@@ -96,8 +97,9 @@ class FileStreamSourceSuite extends FileStreamSourceTest with 
SharedSQLContext {
         reader.stream()
       }
     df.queryExecution.analyzed
-      .collect { case StreamingRelation(s: FileStreamSource, _) => s }
-      .head
+      .collect { case StreamingRelation(dataSource, _, _) =>
+        dataSource.createSource().asInstanceOf[FileStreamSource]
+      }.head
       .schema
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/463bac00/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index fbb1792..e4ea555 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -17,9 +17,13 @@
 
 package org.apache.spark.sql.streaming
 
-import org.apache.spark.sql.{Row, StreamTest}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.sql.{DataFrame, Row, SQLContext, StreamTest}
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.sources.StreamSourceProvider
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
 
 class StreamSuite extends StreamTest with SharedSQLContext {
 
@@ -81,4 +85,60 @@ class StreamSuite extends StreamTest with SharedSQLContext {
       AddData(inputData, 1, 2, 3, 4),
       CheckAnswer(2, 4))
   }
+
+  test("DataFrame reuse") {
+    def assertDF(df: DataFrame) {
+      withTempDir { outputDir =>
+        withTempDir { checkpointDir =>
+          val query = df.write.format("parquet")
+            .option("checkpointLocation", checkpointDir.getAbsolutePath)
+            .startStream(outputDir.getAbsolutePath)
+          try {
+            query.processAllAvailable()
+            val outputDf = 
sqlContext.read.parquet(outputDir.getAbsolutePath).as[Long]
+            checkDataset[Long](outputDf, (0L to 10L).toArray: _*)
+          } finally {
+            query.stop()
+          }
+        }
+      }
+    }
+
+    val df = 
sqlContext.read.format(classOf[FakeDefaultSource].getName).stream()
+    assertDF(df)
+    assertDF(df)
+  }
+}
+
+/**
+ * A fake StreamSourceProvider that creates a fake Source that cannot be reused.
+ */
+class FakeDefaultSource extends StreamSourceProvider {
+
+  override def createSource(
+      sqlContext: SQLContext,
+      schema: Option[StructType],
+      providerName: String,
+      parameters: Map[String, String]): Source = {
+    // Create a fake Source that emits 0 to 10.
+    new Source {
+      private var offset = -1L
+
+      override def schema: StructType = StructType(StructField("a", 
IntegerType) :: Nil)
+
+      override def getOffset: Option[Offset] = {
+        if (offset >= 10) {
+          None
+        } else {
+          offset += 1
+          Some(LongOffset(offset))
+        }
+      }
+
+      override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
+        val startOffset = 
start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1
+        sqlContext.range(startOffset, end.asInstanceOf[LongOffset].offset + 
1).toDF("a")
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to