Repository: spark
Updated Branches:
  refs/heads/master 529a1d338 -> 27029bc8f


[SPARK-11639][STREAMING][FLAKY-TEST] Implement BlockingWriteAheadLog for testing the BatchedWriteAheadLog

Several elements could be drained from the queue if the main thread is not fast
enough. zsxwing warned me about a similar problem, but I missed it here :(
Submitting the fix using a waiter.

cc tdas

Author: Burak Yavuz <brk...@gmail.com>

Closes #9605 from brkyvz/fix-flaky-test.
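
For context, the "waiter" here is a flag-plus-polling handshake rather than a
mocked Promise: the writer thread announces that it has entered write() and
then spins until the test releases it, so the test can assert on intermediate
state while the write is provably in flight. A minimal standalone sketch of
the pattern (illustrative names; the real implementation is the
BlockingWriteAheadLog added in the diff below):

    import org.scalatest.concurrent.Eventually._
    import org.scalatest.time.SpanSugar._

    // Simplified waiter, assuming ScalaTest's Eventually for the polling loop.
    class WriteWaiter {
      @volatile private var arrived = false   // set on the writer thread
      @volatile private var released = false  // set on the test thread

      def await(): Unit = {                   // called on the writer thread
        arrived = true
        eventually(timeout(2.seconds)) { assert(released) }
      }
      def hasArrived: Boolean = arrived       // polled by the test thread
      def release(): Unit = released = true
    }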


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/27029bc8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/27029bc8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/27029bc8

Branch: refs/heads/master
Commit: 27029bc8f6246514bd0947500c94cf38dc8616c3
Parents: 529a1d3
Author: Burak Yavuz <brk...@gmail.com>
Authored: Wed Nov 11 11:24:55 2015 -0800
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Wed Nov 11 11:24:55 2015 -0800

----------------------------------------------------------------------
 .../streaming/util/BatchedWriteAheadLog.scala   |   3 +
 .../streaming/util/WriteAheadLogSuite.scala     | 124 ++++++++++++-------
 2 files changed, 80 insertions(+), 47 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/27029bc8/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
----------------------------------------------------------------------
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
index 9727ed2..6e6ed8d 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala
@@ -182,6 +182,9 @@ private[util] class BatchedWriteAheadLog(val wrappedLog: WriteAheadLog, conf: Sp
       buffer.clear()
     }
   }
+
+  /** Method for querying the queue length. Should only be used in tests. */
+  private def getQueueLength(): Int = walWriteQueue.size()
 }
 
 /** Static methods for aggregating and de-aggregating records. */
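
Aside: the test suite in the second file below reads this private method
through ScalaTest's PrivateMethodTester. A minimal standalone sketch of that
mechanism (the Counter class is illustrative, not part of this commit):

    import org.scalatest.PrivateMethodTester

    class Counter {
      private def secret(): Int = 42               // stand-in for getQueueLength()
    }

    object PrivateMethodDemo extends PrivateMethodTester {
      def main(args: Array[String]): Unit = {
        val secret = PrivateMethod[Int]('secret)       // symbol names the private method
        println((new Counter).invokePrivate(secret())) // prints 42
      }
    }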

http://git-wip-us.apache.org/repos/asf/spark/blob/27029bc8/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index e96f4c2..9e13f25 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -18,15 +18,14 @@ package org.apache.spark.streaming.util
 
 import java.io._
 import java.nio.ByteBuffer
-import java.util.concurrent.{ExecutionException, ThreadPoolExecutor}
-import java.util.concurrent.atomic.AtomicInteger
+import java.util.{Iterator => JIterator}
+import java.util.concurrent.ThreadPoolExecutor
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 import scala.concurrent._
 import scala.concurrent.duration._
 import scala.language.{implicitConversions, postfixOps}
-import scala.util.{Failure, Success}
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
@@ -37,12 +36,12 @@ import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
 import org.scalatest.concurrent.Eventually
 import org.scalatest.concurrent.Eventually._
-import org.scalatest.{BeforeAndAfterEach, BeforeAndAfter}
+import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach, BeforeAndAfter}
 import org.scalatest.mock.MockitoSugar
 
 import org.apache.spark.streaming.scheduler._
 import org.apache.spark.util.{ThreadUtils, ManualClock, Utils}
-import org.apache.spark.{SparkException, SparkConf, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkFunSuite}
 
 /** Common tests for WriteAheadLogs that we would like to test with different configurations. */
 abstract class CommonWriteAheadLogTests(
@@ -315,7 +314,11 @@ class FileBasedWriteAheadLogWithFileCloseAfterWriteSuite
 class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
     allowBatching = true,
     closeFileAfterWrite = false,
-    "BatchedWriteAheadLog") with MockitoSugar with BeforeAndAfterEach with 
Eventually {
+    "BatchedWriteAheadLog")
+  with MockitoSugar
+  with BeforeAndAfterEach
+  with Eventually
+  with PrivateMethodTester {
 
   import BatchedWriteAheadLog._
   import WriteAheadLogSuite._
@@ -326,6 +329,8 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
   private var walBatchingExecutionContext: ExecutionContextExecutorService = _
   private val sparkConf = new SparkConf()
 
+  private val queueLength = PrivateMethod[Int]('getQueueLength)
+
   override def beforeEach(): Unit = {
     wal = mock[WriteAheadLog]
     walHandle = mock[WriteAheadLogRecordHandle]
@@ -366,7 +371,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
   }
 
   // we make the write requests in separate threads so that we don't block the test thread
-  private def promiseWriteEvent(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = {
+  private def writeAsync(wal: WriteAheadLog, event: String, time: Long): Promise[Unit] = {
     val p = Promise[Unit]()
     p.completeWith(Future {
       val v = wal.write(event, time)
@@ -375,28 +380,9 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
     p
   }
 
-  /**
-   * In order to block the writes on the writer thread, we mock the write method, and block it
-   * for some time with a promise.
-   */
-  private def writeBlockingPromise(wal: WriteAheadLog): Promise[Any] = {
-    // we would like to block the write so that we can queue requests
-    val promise = Promise[Any]()
-    when(wal.write(any[ByteBuffer], any[Long])).thenAnswer(
-      new Answer[WriteAheadLogRecordHandle] {
-        override def answer(invocation: InvocationOnMock): WriteAheadLogRecordHandle = {
-          Await.ready(promise.future, 4.seconds)
-          walHandle
-        }
-      }
-    )
-    promise
-  }
-
   test("BatchedWriteAheadLog - name log with aggregated entries with the 
timestamp of last entry") {
-    val batchedWal = new BatchedWriteAheadLog(wal, sparkConf)
-    // block the write so that we can batch some records
-    val promise = writeBlockingPromise(wal)
+    val blockingWal = new BlockingWriteAheadLog(wal, walHandle)
+    val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf)
 
     val event1 = "hello"
     val event2 = "world"
@@ -406,21 +392,27 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
 
     // The queue.take() immediately takes the 3, and there is nothing left in the queue at that
     // moment. Then the promise blocks the writing of 3. The rest get queued.
-    promiseWriteEvent(batchedWal, event1, 3L)
-    // rest of the records will be batched while it takes 3 to get written
-    promiseWriteEvent(batchedWal, event2, 5L)
-    promiseWriteEvent(batchedWal, event3, 8L)
-    promiseWriteEvent(batchedWal, event4, 12L)
-    promiseWriteEvent(batchedWal, event5, 10L)
+    writeAsync(batchedWal, event1, 3L)
+    eventually(timeout(1 second)) {
+      assert(blockingWal.isBlocked)
+      assert(batchedWal.invokePrivate(queueLength()) === 0)
+    }
+    // rest of the records will be batched while it takes time for 3 to get written
+    writeAsync(batchedWal, event2, 5L)
+    writeAsync(batchedWal, event3, 8L)
+    writeAsync(batchedWal, event4, 12L)
+    writeAsync(batchedWal, event5, 10L)
     eventually(timeout(1 second)) {
       assert(walBatchingThreadPool.getActiveCount === 5)
+      assert(batchedWal.invokePrivate(queueLength()) === 4)
     }
-    promise.success(true)
+    blockingWal.allowWrite()
 
     val buffer1 = wrapArrayArrayByte(Array(event1))
     val buffer2 = wrapArrayArrayByte(Array(event2, event3, event4, event5))
 
     eventually(timeout(1 second)) {
+      assert(batchedWal.invokePrivate(queueLength()) === 0)
       verify(wal, times(1)).write(meq(buffer1), meq(3L))
       // the file name should be the timestamp of the last record, as events should be naturally
       // in order of timestamp, and we need the last element.
@@ -437,27 +429,32 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests(
   }
 
   test("BatchedWriteAheadLog - fail everything in queue during shutdown") {
-    val batchedWal = new BatchedWriteAheadLog(wal, sparkConf)
+    val blockingWal = new BlockingWriteAheadLog(wal, walHandle)
+    val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf)
 
-    // block the write so that we can batch some records
-    writeBlockingPromise(wal)
-
-    val event1 = ("hello", 3L)
-    val event2 = ("world", 5L)
-    val event3 = ("this", 8L)
-    val event4 = ("is", 9L)
-    val event5 = ("doge", 10L)
+    val event1 = "hello"
+    val event2 = "world"
+    val event3 = "this"
 
     // The queue.take() immediately takes the 3, and there is nothing left in the queue at that
     // moment. Then the promise blocks the writing of 3. The rest get queued.
-    val writePromises = Seq(event1, event2, event3, event4, event5).map { event =>
-      promiseWriteEvent(batchedWal, event._1, event._2)
+    val promise1 = writeAsync(batchedWal, event1, 3L)
+    eventually(timeout(1 second)) {
+      assert(blockingWal.isBlocked)
+      assert(batchedWal.invokePrivate(queueLength()) === 0)
     }
+    // rest of the records will be batched while it takes time for 3 to get written
+    val promise2 = writeAsync(batchedWal, event2, 5L)
+    val promise3 = writeAsync(batchedWal, event3, 8L)
 
     eventually(timeout(1 second)) {
-      assert(walBatchingThreadPool.getActiveCount === 5)
+      assert(walBatchingThreadPool.getActiveCount === 3)
+      assert(blockingWal.isBlocked)
+      assert(batchedWal.invokePrivate(queueLength()) === 2) // event1 is being written
     }
 
+    val writePromises = Seq(promise1, promise2, promise3)
+
     batchedWal.close()
     eventually(timeout(1 second)) {
       assert(writePromises.forall(_.isCompleted))
@@ -641,4 +638,37 @@ object WriteAheadLogSuite {
   def wrapArrayArrayByte[T](records: Array[T]): ByteBuffer = {
     ByteBuffer.wrap(Utils.serialize[Array[Array[Byte]]](records.map(Utils.serialize[T])))
   }
+
+  /**
+   * A wrapper WriteAheadLog that blocks the write function to allow batching with the
+   * BatchedWriteAheadLog.
+   */
+  class BlockingWriteAheadLog(
+      wal: WriteAheadLog,
+      handle: WriteAheadLogRecordHandle) extends WriteAheadLog {
+    @volatile private var isWriteCalled: Boolean = false
+    @volatile private var blockWrite: Boolean = true
+
+    override def write(record: ByteBuffer, time: Long): WriteAheadLogRecordHandle = {
+      isWriteCalled = true
+      eventually(Eventually.timeout(2 second)) {
+        assert(!blockWrite)
+      }
+      wal.write(record, time)
+      isWriteCalled = false
+      handle
+    }
+    override def read(segment: WriteAheadLogRecordHandle): ByteBuffer = wal.read(segment)
+    override def readAll(): JIterator[ByteBuffer] = wal.readAll()
+    override def clean(threshTime: Long, waitForCompletion: Boolean): Unit = {
+      wal.clean(threshTime, waitForCompletion)
+    }
+    override def close(): Unit = wal.close()
+
+    def allowWrite(): Unit = {
+      blockWrite = false
+    }
+
+    def isBlocked: Boolean = isWriteCalled
+  }
 }
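
Taken together, the handshake the new tests rely on looks roughly like this (a
condensed sketch of the pattern used in the two tests above, not additional
code from the commit):

    val blockingWal = new BlockingWriteAheadLog(wal, walHandle)
    val batchedWal = new BatchedWriteAheadLog(blockingWal, sparkConf)

    writeAsync(batchedWal, "first", 3L)    // writer thread takes "first" and blocks
    eventually(timeout(1 second)) {
      assert(blockingWal.isBlocked)        // wait until write() has actually started
    }
    writeAsync(batchedWal, "second", 5L)   // now guaranteed to queue behind "first"
    blockingWal.allowWrite()               // release the writer; queued records flush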

