Repository: spark
Updated Branches:
  refs/heads/master 5de26194a -> 2dacc81ec


[SPARK-14494][SQL] Fix the race conditions in MemoryStream and MemorySink

## What changes were proposed in this pull request?

Make sure that all accesses to mutable variables in MemoryStream and MemorySink
are protected by `synchronized`.
This race condition is probably why MemorySinkSuite failed here:
https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.2/650/testReport/junit/org.apache.spark.sql.streaming/MemorySinkSuite/registering_as_a_table/

## How was this patch tested?
Existing unit tests.

Author: Shixiong Zhu <shixi...@databricks.com>

Closes #12261 from zsxwing/memory-race-condition.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2dacc81e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2dacc81e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2dacc81e

Branch: refs/heads/master
Commit: 2dacc81ec31233e558855a26340ad4662d470387
Parents: 5de2619
Author: Shixiong Zhu <shixi...@databricks.com>
Authored: Mon Apr 11 10:42:51 2016 -0700
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Mon Apr 11 10:42:51 2016 -0700

----------------------------------------------------------------------
 .../spark/sql/execution/streaming/memory.scala  | 25 +++++++++++++-------
 1 file changed, 16 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2dacc81e/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index 351ef40..3820968 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.streaming
 
 import java.util.concurrent.atomic.AtomicInteger
+import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
@@ -47,8 +48,11 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
   protected val encoder = encoderFor[A]
   protected val logicalPlan = StreamingExecutionRelation(this)
   protected val output = logicalPlan.output
+
+  @GuardedBy("this")
   protected val batches = new ArrayBuffer[Dataset[A]]
 
+  @GuardedBy("this")
   protected var currentOffset: LongOffset = new LongOffset(-1)
 
   def schema: StructType = encoder.schema
@@ -67,10 +71,10 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 
   def addData(data: TraversableOnce[A]): Offset = {
     import sqlContext.implicits._
+    val ds = data.toVector.toDS()
+    logDebug(s"Adding ds: $ds")
     this.synchronized {
       currentOffset = currentOffset + 1
-      val ds = data.toVector.toDS()
-      logDebug(s"Adding ds: $ds")
       batches.append(ds)
       currentOffset
     }
@@ -78,10 +82,12 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
 
   override def toString: String = s"MemoryStream[${output.mkString(",")}]"
 
-  override def getOffset: Option[Offset] = if (batches.isEmpty) {
-    None
-  } else {
-    Some(currentOffset)
+  override def getOffset: Option[Offset] = synchronized {
+    if (batches.isEmpty) {
+      None
+    } else {
+      Some(currentOffset)
+    }
   }
 
   /**
@@ -91,7 +97,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
     val startOrdinal =
       
start.map(_.asInstanceOf[LongOffset]).getOrElse(LongOffset(-1)).offset.toInt + 1
     val endOrdinal = end.asInstanceOf[LongOffset].offset.toInt + 1
-    val newBlocks = batches.slice(startOrdinal, endOrdinal)
+    val newBlocks = synchronized { batches.slice(startOrdinal, endOrdinal) }
 
     logDebug(
       s"MemoryBatch [$startOrdinal, $endOrdinal]: 
${newBlocks.flatMap(_.collect()).mkString(", ")}")
@@ -110,6 +116,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
  */
 class MemorySink(val schema: StructType) extends Sink with Logging {
   /** An order list of batches that have been written to this [[Sink]]. */
+  @GuardedBy("this")
   private val batches = new ArrayBuffer[Array[Row]]()
 
   /** Returns all rows that are stored in this [[Sink]]. */
@@ -117,7 +124,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging {
     batches.flatten
   }
 
-  def lastBatch: Seq[Row] = batches.last
+  def lastBatch: Seq[Row] = synchronized { batches.last }
 
   def toDebugString: String = synchronized {
     batches.zipWithIndex.map { case (b, i) =>
@@ -128,7 +135,7 @@ class MemorySink(val schema: StructType) extends Sink with Logging {
     }.mkString("\n")
   }
 
-  override def addBatch(batchId: Long, data: DataFrame): Unit = {
+  override def addBatch(batchId: Long, data: DataFrame): Unit = synchronized {
     if (batchId == batches.size) {
       logDebug(s"Committing batch $batchId")
       batches.append(data.collect())


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to