Repository: spark
Updated Branches:
  refs/heads/branch-1.6 bdd8a6bd4 -> 9e80db7c7


[SPARK-11359][STREAMING][KINESIS] Checkpoint to DynamoDB even when new data doesn't come in

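Currently, checkpoints to DynamoDB occur only when new data comes in, because the
checkpoint clock in `KinesisCheckpointState` is only checked and advanced while records
are being processed. This PR makes checkpointing a scheduled task that runs every
`checkpointInterval`, regardless of whether new data arrives.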

Author: Burak Yavuz <brk...@gmail.com>

Closes #9421 from brkyvz/kinesis-checkpoint.

(cherry picked from commit a3a7c9103e136035d65a5564f9eb0fa04727c4f3)
Signed-off-by: Tathagata Das <tathagata.das1...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9e80db7c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9e80db7c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9e80db7c

Branch: refs/heads/branch-1.6
Commit: 9e80db7c7d1600691a5c012610e3f28f35210d46
Parents: bdd8a6b
Author: Burak Yavuz <brk...@gmail.com>
Authored: Mon Nov 9 14:39:18 2015 -0800
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Mon Nov 9 14:39:30 2015 -0800

----------------------------------------------------------------------
 .../kinesis/KinesisCheckpointState.scala        |  54 -------
 .../streaming/kinesis/KinesisCheckpointer.scala | 133 ++++++++++++++++
 .../streaming/kinesis/KinesisReceiver.scala     |  38 ++++-
 .../kinesis/KinesisRecordProcessor.scala        |  59 ++-----
 .../kinesis/KinesisCheckpointerSuite.scala      | 152 +++++++++++++++++++
 .../kinesis/KinesisReceiverSuite.scala          |  96 +++---------
 6 files changed, 349 insertions(+), 183 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9e80db7c/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
----------------------------------------------------------------------
diff --git 
a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
 
b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
deleted file mode 100644
index 83a4537..0000000
--- 
a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
+++ /dev/null
@@ -1,54 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.streaming.kinesis
-
-import org.apache.spark.Logging
-import org.apache.spark.streaming.Duration
-import org.apache.spark.util.{Clock, ManualClock, SystemClock}
-
-/**
- * This is a helper class for managing checkpoint clocks.
- *
- * @param checkpointInterval
- * @param currentClock.  Default to current SystemClock if none is passed in 
(mocking purposes)
- */
-private[kinesis] class KinesisCheckpointState(
-    checkpointInterval: Duration,
-    currentClock: Clock = new SystemClock())
-  extends Logging {
-
-  /* Initialize the checkpoint clock using the given currentClock + 
checkpointInterval millis */
-  val checkpointClock = new ManualClock()
-  checkpointClock.setTime(currentClock.getTimeMillis() + 
checkpointInterval.milliseconds)
-
-  /**
-   * Check if it's time to checkpoint based on the current time and the 
derived time
-   *   for the next checkpoint
-   *
-   * @return true if it's time to checkpoint
-   */
-  def shouldCheckpoint(): Boolean = {
-    new SystemClock().getTimeMillis() > checkpointClock.getTimeMillis()
-  }
-
-  /**
-   * Advance the checkpoint clock by the checkpoint interval.
-   */
-  def advanceCheckpoint(): Unit = {
-    checkpointClock.advance(checkpointInterval.milliseconds)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/9e80db7c/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
----------------------------------------------------------------------
diff --git 
a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
 
b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
new file mode 100644
index 0000000..1ca6d43
--- /dev/null
+++ 
b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointer.scala
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.streaming.kinesis
+
+import java.util.concurrent._
+
+import scala.util.control.NonFatal
+
+import 
com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
+import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
+
+import org.apache.spark.Logging
+import org.apache.spark.streaming.Duration
+import org.apache.spark.streaming.util.RecurringTimer
+import org.apache.spark.util.{Clock, SystemClock, ThreadUtils}
+
+/**
+ * This is a helper class for managing Kinesis checkpointing.
+ *
+ * @param receiver The receiver that keeps track of which sequence numbers we 
can checkpoint
+ * @param checkpointInterval How frequently we will checkpoint to DynamoDB
+ * @param workerId Worker Id of KCL worker for logging purposes
+ * @param clock In order to use ManualClocks for the purpose of testing
+ */
+private[kinesis] class KinesisCheckpointer(
+    receiver: KinesisReceiver[_],
+    checkpointInterval: Duration,
+    workerId: String,
+    clock: Clock = new SystemClock) extends Logging {
+
+  // a map from shardId's to checkpointers
+  private val checkpointers = new ConcurrentHashMap[String, 
IRecordProcessorCheckpointer]()
+
+  private val lastCheckpointedSeqNums = new ConcurrentHashMap[String, String]()
+
+  private val checkpointerThread: RecurringTimer = startCheckpointerThread()
+
+  /** Update the checkpointer instance to the most recent one for the given 
shardId. */
+  def setCheckpointer(shardId: String, checkpointer: 
IRecordProcessorCheckpointer): Unit = {
+    checkpointers.put(shardId, checkpointer)
+  }
+
+  /**
+   * Stop tracking the specified shardId.
+   *
+   * If a checkpointer is provided, e.g. on IRecordProcessor.shutdown 
[[ShutdownReason.TERMINATE]],
+   * we will use that to make the final checkpoint. If `null` is provided, we 
will not make the
+   * checkpoint, e.g. in case of [[ShutdownReason.ZOMBIE]].
+   */
+  def removeCheckpointer(shardId: String, checkpointer: 
IRecordProcessorCheckpointer): Unit = {
+    synchronized {
+      checkpointers.remove(shardId)
+      checkpoint(shardId, checkpointer)
+    }
+  }
+
+  /** Perform the checkpoint. */
+  private def checkpoint(shardId: String, checkpointer: 
IRecordProcessorCheckpointer): Unit = {
+    try {
+      if (checkpointer != null) {
+        receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum =>
+          val lastSeqNum = lastCheckpointedSeqNums.get(shardId)
+          // Kinesis sequence numbers are monotonically increasing strings, 
therefore we can do
+          // safely do the string comparison
+          if (lastSeqNum == null || latestSeqNum > lastSeqNum) {
+            /* Perform the checkpoint */
+            
KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 
100)
+            logDebug(s"Checkpoint:  WorkerId $workerId completed checkpoint at 
sequence number" +
+              s" $latestSeqNum for shardId $shardId")
+            lastCheckpointedSeqNums.put(shardId, latestSeqNum)
+          }
+        }
+      } else {
+        logDebug(s"Checkpointing skipped for shardId $shardId. Checkpointer 
not set.")
+      }
+    } catch {
+      case NonFatal(e) =>
+        logWarning(s"Failed to checkpoint shardId $shardId to DynamoDB.", e)
+    }
+  }
+
+  /** Checkpoint the latest saved sequence numbers for all active shardId's. */
+  private def checkpointAll(): Unit = synchronized {
+    // if this method throws an exception, then the scheduled task will not 
run again
+    try {
+      val shardIds = checkpointers.keys()
+      while (shardIds.hasMoreElements) {
+        val shardId = shardIds.nextElement()
+        checkpoint(shardId, checkpointers.get(shardId))
+      }
+    } catch {
+      case NonFatal(e) =>
+        logWarning("Failed to checkpoint to DynamoDB.", e)
+    }
+  }
+
+  /**
+   * Start the checkpointer thread with the given checkpoint duration.
+   */
+  private def startCheckpointerThread(): RecurringTimer = {
+    val period = checkpointInterval.milliseconds
+    val threadName = s"Kinesis Checkpointer - Worker $workerId"
+    val timer = new RecurringTimer(clock, period, _ => checkpointAll(), 
threadName)
+    timer.start()
+    logDebug(s"Started checkpointer thread: $threadName")
+    timer
+  }
+
+  /**
+   * Shutdown the checkpointer. Should be called on the onStop of the Receiver.
+   */
+  def shutdown(): Unit = {
+    // the recurring timer checkpoints for us one last time.
+    checkpointerThread.stop(interruptTimer = false)
+    checkpointers.clear()
+    lastCheckpointedSeqNums.clear()
+    logInfo("Successfully shutdown Kinesis Checkpointer.")
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/9e80db7c/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
----------------------------------------------------------------------
diff --git 
a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
 
b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
index 134d627..50993f1 100644
--- 
a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
+++ 
b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala
@@ -23,7 +23,7 @@ import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import com.amazonaws.auth.{AWSCredentials, AWSCredentialsProvider, 
DefaultAWSCredentialsProviderChain}
-import 
com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessor, 
IRecordProcessorFactory}
+import 
com.amazonaws.services.kinesis.clientlibrary.interfaces.{IRecordProcessorCheckpointer,
 IRecordProcessor, IRecordProcessorFactory}
 import 
com.amazonaws.services.kinesis.clientlibrary.lib.worker.{InitialPositionInStream,
 KinesisClientLibConfiguration, Worker}
 import com.amazonaws.services.kinesis.model.Record
 
@@ -31,8 +31,7 @@ import org.apache.spark.storage.{StorageLevel, StreamBlockId}
 import org.apache.spark.streaming.Duration
 import org.apache.spark.streaming.receiver.{BlockGenerator, 
BlockGeneratorListener, Receiver}
 import org.apache.spark.util.Utils
-import org.apache.spark.{Logging, SparkEnv}
-
+import org.apache.spark.Logging
 
 private[kinesis]
 case class SerializableAWSCredentials(accessKeyId: String, secretKey: String)
@@ -128,6 +127,11 @@ private[kinesis] class KinesisReceiver[T](
     with mutable.SynchronizedMap[StreamBlockId, SequenceNumberRanges]
 
   /**
+   * The centralized kinesisCheckpointer that checkpoints based on the given 
checkpointInterval.
+   */
+  @volatile private var kinesisCheckpointer: KinesisCheckpointer = null
+
+  /**
    * Latest sequence number ranges that have been stored successfully.
    * This is used for checkpointing through KCL */
   private val shardIdToLatestStoredSeqNum = new mutable.HashMap[String, String]
@@ -141,6 +145,7 @@ private[kinesis] class KinesisReceiver[T](
 
     workerId = Utils.localHostName() + ":" + UUID.randomUUID()
 
+    kinesisCheckpointer = new KinesisCheckpointer(receiver, 
checkpointInterval, workerId)
     // KCL config instance
     val awsCredProvider = resolveAWSCredentialsProvider()
     val kinesisClientLibConfiguration =
@@ -157,8 +162,8 @@ private[kinesis] class KinesisReceiver[T](
     *  We're using our custom KinesisRecordProcessor in this case.
     */
     val recordProcessorFactory = new IRecordProcessorFactory {
-      override def createProcessor: IRecordProcessor = new 
KinesisRecordProcessor(receiver,
-        workerId, new KinesisCheckpointState(checkpointInterval))
+      override def createProcessor: IRecordProcessor =
+        new KinesisRecordProcessor(receiver, workerId)
     }
 
     worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration)
@@ -198,6 +203,10 @@ private[kinesis] class KinesisReceiver[T](
       logInfo(s"Stopped receiver for workerId $workerId")
     }
     workerId = null
+    if (kinesisCheckpointer != null) {
+      kinesisCheckpointer.shutdown()
+      kinesisCheckpointer = null
+    }
   }
 
   /** Add records of the given shard to the current block being generated */
@@ -217,6 +226,25 @@ private[kinesis] class KinesisReceiver[T](
   }
 
   /**
+   * Set the checkpointer that will be used to checkpoint sequence numbers to 
DynamoDB for the
+   * given shardId.
+   */
+  def setCheckpointer(shardId: String, checkpointer: 
IRecordProcessorCheckpointer): Unit = {
+    assert(kinesisCheckpointer != null, "Kinesis Checkpointer not 
initialized!")
+    kinesisCheckpointer.setCheckpointer(shardId, checkpointer)
+  }
+
+  /**
+   * Remove the checkpointer for the given shardId. The provided checkpointer 
will be used to
+   * checkpoint one last time for the given shard. If `checkpointer` is 
`null`, then we will not
+   * checkpoint.
+   */
+  def removeCheckpointer(shardId: String, checkpointer: 
IRecordProcessorCheckpointer): Unit = {
+    assert(kinesisCheckpointer != null, "Kinesis Checkpointer not 
initialized!")
+    kinesisCheckpointer.removeCheckpointer(shardId, checkpointer)
+  }
+
+  /**
    * Remember the range of sequence numbers that was added to the currently 
active block.
    * Internally, this is synchronized with `finalizeRangesForCurrentBlock()`.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/9e80db7c/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
----------------------------------------------------------------------
diff --git 
a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
 
b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
index 1d51787..e381ffa 100644
--- 
a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
+++ 
b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
@@ -27,26 +27,23 @@ import 
com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
 import com.amazonaws.services.kinesis.model.Record
 
 import org.apache.spark.Logging
+import org.apache.spark.streaming.Duration
 
 /**
  * Kinesis-specific implementation of the Kinesis Client Library (KCL) 
IRecordProcessor.
  * This implementation operates on the Array[Byte] from the KinesisReceiver.
  * The Kinesis Worker creates an instance of this KinesisRecordProcessor for 
each
- *   shard in the Kinesis stream upon startup.  This is normally done in 
separate threads,
- *   but the KCLs within the KinesisReceivers will balance themselves out if 
you create
- *   multiple Receivers.
+ * shard in the Kinesis stream upon startup.  This is normally done in 
separate threads,
+ * but the KCLs within the KinesisReceivers will balance themselves out if you 
create
+ * multiple Receivers.
  *
  * @param receiver Kinesis receiver
  * @param workerId for logging purposes
- * @param checkpointState represents the checkpoint state including the next 
checkpoint time.
- *   It's injected here for mocking purposes.
  */
-private[kinesis] class KinesisRecordProcessor[T](
-    receiver: KinesisReceiver[T],
-    workerId: String,
-    checkpointState: KinesisCheckpointState) extends IRecordProcessor with 
Logging {
+private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], 
workerId: String)
+  extends IRecordProcessor with Logging {
 
-  // shardId to be populated during initialize()
+  // shardId populated during initialize()
   @volatile
   private var shardId: String = _
 
@@ -74,34 +71,7 @@ private[kinesis] class KinesisRecordProcessor[T](
       try {
         receiver.addRecords(shardId, batch)
         logDebug(s"Stored: Worker $workerId stored ${batch.size} records for 
shardId $shardId")
-
-        /*
-         *
-         * Checkpoint the sequence number of the last record successfully 
stored.
-         * Note that in this current implementation, the checkpointing occurs 
only when after
-         * checkpointIntervalMillis from the last checkpoint, AND when there 
is new record
-         * to process. This leads to the checkpointing lagging behind what 
records have been
-         * stored by the receiver. Ofcourse, this can lead records processed 
more than once,
-         * under failures and restarts.
-         *
-         * TODO: Instead of checkpointing here, run a separate timer task to 
perform
-         * checkpointing so that it checkpoints in a timely manner independent 
of whether
-         * new records are available or not.
-         */
-        if (checkpointState.shouldCheckpoint()) {
-          receiver.getLatestSeqNumToCheckpoint(shardId).foreach { latestSeqNum 
=>
-            /* Perform the checkpoint */
-            
KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(latestSeqNum), 4, 
100)
-
-            /* Update the next checkpoint time */
-            checkpointState.advanceCheckpoint()
-
-            logDebug(s"Checkpoint:  WorkerId $workerId completed checkpoint of 
${batch.size}" +
-              s" records for shardId $shardId")
-            logDebug(s"Checkpoint:  Next checkpoint is at " +
-              s" ${checkpointState.checkpointClock.getTimeMillis()} for 
shardId $shardId")
-          }
-        }
+        receiver.setCheckpointer(shardId, checkpointer)
       } catch {
         case NonFatal(e) => {
           /*
@@ -142,23 +112,18 @@ private[kinesis] class KinesisRecordProcessor[T](
        * It's now OK to read from the new shards that resulted from a 
resharding event.
        */
       case ShutdownReason.TERMINATE =>
-        val latestSeqNumToCheckpointOption = 
receiver.getLatestSeqNumToCheckpoint(shardId)
-        if (latestSeqNumToCheckpointOption.nonEmpty) {
-          KinesisRecordProcessor.retryRandom(
-            checkpointer.checkpoint(latestSeqNumToCheckpointOption.get), 4, 
100)
-        }
+        receiver.removeCheckpointer(shardId, checkpointer)
 
       /*
-       * ZOMBIE Use Case.  NoOp.
+       * ZOMBIE Use Case or Unknown reason.  NoOp.
        * No checkpoint because other workers may have taken over and already 
started processing
        *    the same records.
        * This may lead to records being processed more than once.
        */
-      case ShutdownReason.ZOMBIE =>
-
-      /* Unknown reason.  NoOp */
       case _ =>
+        receiver.removeCheckpointer(shardId, null) // return null so that we 
don't checkpoint
     }
+
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/9e80db7c/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
----------------------------------------------------------------------
diff --git 
a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
 
b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
new file mode 100644
index 0000000..645e64a
--- /dev/null
+++ 
b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointerSuite.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kinesis
+
+import java.util.concurrent.{TimeoutException, ExecutorService}
+
+import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import 
com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach}
+import org.scalatest.concurrent.Eventually
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.mock.MockitoSugar
+
+import org.apache.spark.streaming.{Duration, TestSuiteBase}
+import org.apache.spark.util.ManualClock
+
+class KinesisCheckpointerSuite extends TestSuiteBase
+  with MockitoSugar
+  with BeforeAndAfterEach
+  with PrivateMethodTester
+  with Eventually {
+
+  private val workerId = "dummyWorkerId"
+  private val shardId = "dummyShardId"
+  private val seqNum = "123"
+  private val otherSeqNum = "245"
+  private val checkpointInterval = Duration(10)
+  private val someSeqNum = Some(seqNum)
+  private val someOtherSeqNum = Some(otherSeqNum)
+
+  private var receiverMock: KinesisReceiver[Array[Byte]] = _
+  private var checkpointerMock: IRecordProcessorCheckpointer = _
+  private var kinesisCheckpointer: KinesisCheckpointer = _
+  private var clock: ManualClock = _
+
+  private val checkpoint = PrivateMethod[Unit]('checkpoint)
+
+  override def beforeEach(): Unit = {
+    receiverMock = mock[KinesisReceiver[Array[Byte]]]
+    checkpointerMock = mock[IRecordProcessorCheckpointer]
+    clock = new ManualClock()
+    kinesisCheckpointer = new KinesisCheckpointer(receiverMock, 
checkpointInterval, workerId, clock)
+  }
+
+  test("checkpoint is not called twice for the same sequence number") {
+    
when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+    kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock))
+    kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock))
+
+    verify(checkpointerMock, times(1)).checkpoint(anyString())
+  }
+
+  test("checkpoint is called after sequence number increases") {
+    when(receiverMock.getLatestSeqNumToCheckpoint(shardId))
+      .thenReturn(someSeqNum).thenReturn(someOtherSeqNum)
+    kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock))
+    kinesisCheckpointer.invokePrivate(checkpoint(shardId, checkpointerMock))
+
+    verify(checkpointerMock, times(1)).checkpoint(seqNum)
+    verify(checkpointerMock, times(1)).checkpoint(otherSeqNum)
+  }
+
+  test("should checkpoint if we have exceeded the checkpoint interval") {
+    when(receiverMock.getLatestSeqNumToCheckpoint(shardId))
+      .thenReturn(someSeqNum).thenReturn(someOtherSeqNum)
+
+    kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock)
+    clock.advance(5 * checkpointInterval.milliseconds)
+
+    eventually(timeout(1 second)) {
+      verify(checkpointerMock, times(1)).checkpoint(seqNum)
+      verify(checkpointerMock, times(1)).checkpoint(otherSeqNum)
+    }
+  }
+
+  test("shouldn't checkpoint if we have not exceeded the checkpoint interval") 
{
+    
when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+
+    kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock)
+    clock.advance(checkpointInterval.milliseconds / 2)
+
+    verify(checkpointerMock, never()).checkpoint(anyString())
+  }
+
+  test("should not checkpoint for the same sequence number") {
+    
when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+
+    kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock)
+
+    clock.advance(checkpointInterval.milliseconds * 5)
+    eventually(timeout(1 second)) {
+      verify(checkpointerMock, atMost(1)).checkpoint(anyString())
+    }
+  }
+
+  test("removing checkpointer checkpoints one last time") {
+    
when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
+
+    kinesisCheckpointer.removeCheckpointer(shardId, checkpointerMock)
+    verify(checkpointerMock, times(1)).checkpoint(anyString())
+  }
+
+  test("if checkpointing is going on, wait until finished before removing and 
checkpointing") {
+    when(receiverMock.getLatestSeqNumToCheckpoint(shardId))
+      .thenReturn(someSeqNum).thenReturn(someOtherSeqNum)
+    when(checkpointerMock.checkpoint(anyString)).thenAnswer(new Answer[Unit] {
+      override def answer(invocations: InvocationOnMock): Unit = {
+        clock.waitTillTime(clock.getTimeMillis() + 
checkpointInterval.milliseconds / 2)
+      }
+    })
+
+    kinesisCheckpointer.setCheckpointer(shardId, checkpointerMock)
+    clock.advance(checkpointInterval.milliseconds)
+    eventually(timeout(1 second)) {
+      verify(checkpointerMock, times(1)).checkpoint(anyString())
+    }
+    // don't block test thread
+    val f = Future(kinesisCheckpointer.removeCheckpointer(shardId, 
checkpointerMock))(
+      ExecutionContext.global)
+
+    intercept[TimeoutException] {
+      Await.ready(f, 50 millis)
+    }
+
+    clock.advance(checkpointInterval.milliseconds / 2)
+    eventually(timeout(1 second)) {
+      verify(checkpointerMock, times(2)).checkpoint(anyString())
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/9e80db7c/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
----------------------------------------------------------------------
diff --git 
a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
 
b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
index 17ab444..e5c70db 100644
--- 
a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
+++ 
b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
@@ -25,12 +25,13 @@ import 
com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorC
 import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
 import com.amazonaws.services.kinesis.model.Record
 import org.mockito.Matchers._
+import org.mockito.Matchers.{eq => meq}
 import org.mockito.Mockito._
 import org.scalatest.mock.MockitoSugar
 import org.scalatest.{BeforeAndAfter, Matchers}
 
-import org.apache.spark.streaming.{Milliseconds, TestSuiteBase}
-import org.apache.spark.util.{Clock, ManualClock, Utils}
+import org.apache.spark.streaming.{Duration, TestSuiteBase}
+import org.apache.spark.util.Utils
 
 /**
  * Suite of Kinesis streaming receiver tests focusing mostly on the 
KinesisRecordProcessor
@@ -44,6 +45,7 @@ class KinesisReceiverSuite extends TestSuiteBase with 
Matchers with BeforeAndAft
   val workerId = "dummyWorkerId"
   val shardId = "dummyShardId"
   val seqNum = "dummySeqNum"
+  val checkpointInterval = Duration(10)
   val someSeqNum = Some(seqNum)
 
   val record1 = new Record()
@@ -54,24 +56,10 @@ class KinesisReceiverSuite extends TestSuiteBase with 
Matchers with BeforeAndAft
 
   var receiverMock: KinesisReceiver[Array[Byte]] = _
   var checkpointerMock: IRecordProcessorCheckpointer = _
-  var checkpointClockMock: ManualClock = _
-  var checkpointStateMock: KinesisCheckpointState = _
-  var currentClockMock: Clock = _
 
   override def beforeFunction(): Unit = {
     receiverMock = mock[KinesisReceiver[Array[Byte]]]
     checkpointerMock = mock[IRecordProcessorCheckpointer]
-    checkpointClockMock = mock[ManualClock]
-    checkpointStateMock = mock[KinesisCheckpointState]
-    currentClockMock = mock[Clock]
-  }
-
-  override def afterFunction(): Unit = {
-    super.afterFunction()
-    // Since this suite was originally written using EasyMock, add this to 
preserve the old
-    // mocking semantics (see SPARK-5735 for more details)
-    verifyNoMoreInteractions(receiverMock, checkpointerMock, 
checkpointClockMock,
-      checkpointStateMock, currentClockMock)
   }
 
   test("check serializability of SerializableAWSCredentials") {
@@ -79,113 +67,67 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
       Utils.serialize(new SerializableAWSCredentials("x", "y")))
   }
 
-  test("process records including store and checkpoint") {
+  test("process records including store and set checkpointer") {
     when(receiverMock.isStopped()).thenReturn(false)
-    when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
-    when(checkpointStateMock.shouldCheckpoint()).thenReturn(true)
 
-    val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+    val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
     recordProcessor.initialize(shardId)
     recordProcessor.processRecords(batch, checkpointerMock)
 
     verify(receiverMock, times(1)).isStopped()
     verify(receiverMock, times(1)).addRecords(shardId, batch)
-    verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId)
-    verify(checkpointStateMock, times(1)).shouldCheckpoint()
-    verify(checkpointerMock, times(1)).checkpoint(anyString)
-    verify(checkpointStateMock, times(1)).advanceCheckpoint()
+    verify(receiverMock, times(1)).setCheckpointer(shardId, checkpointerMock)
   }
 
-  test("shouldn't store and checkpoint when receiver is stopped") {
+  test("shouldn't store and update checkpointer when receiver is stopped") {
     when(receiverMock.isStopped()).thenReturn(true)
 
-    val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+    val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
     recordProcessor.processRecords(batch, checkpointerMock)
 
     verify(receiverMock, times(1)).isStopped()
     verify(receiverMock, never).addRecords(anyString, anyListOf(classOf[Record]))
-    verify(checkpointerMock, never).checkpoint(anyString)
+    verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock))
   }
 
-  test("shouldn't checkpoint when exception occurs during store") {
+  test("shouldn't update checkpointer when exception occurs during store") {
     when(receiverMock.isStopped()).thenReturn(false)
     when(
       receiverMock.addRecords(anyString, anyListOf(classOf[Record]))
     ).thenThrow(new RuntimeException())
 
     intercept[RuntimeException] {
-      val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+      val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
       recordProcessor.initialize(shardId)
       recordProcessor.processRecords(batch, checkpointerMock)
     }
 
     verify(receiverMock, times(1)).isStopped()
     verify(receiverMock, times(1)).addRecords(shardId, batch)
-    verify(checkpointerMock, never).checkpoint(anyString)
-  }
-
-  test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") {
-    when(currentClockMock.getTimeMillis()).thenReturn(0)
-
-    val checkpointIntervalMillis = 10
-    val checkpointState =
-      new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock)
-    assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis)
-
-    verify(currentClockMock, times(1)).getTimeMillis()
-  }
-
-  test("should checkpoint if we have exceeded the checkpoint interval") {
-    when(currentClockMock.getTimeMillis()).thenReturn(0)
-
-    val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock)
-    assert(checkpointState.shouldCheckpoint())
-
-    verify(currentClockMock, times(1)).getTimeMillis()
-  }
-
-  test("shouldn't checkpoint if we have not exceeded the checkpoint interval") {
-    when(currentClockMock.getTimeMillis()).thenReturn(0)
-
-    val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock)
-    assert(!checkpointState.shouldCheckpoint())
-
-    verify(currentClockMock, times(1)).getTimeMillis()
-  }
-
-  test("should add to time when advancing checkpoint") {
-    when(currentClockMock.getTimeMillis()).thenReturn(0)
-
-    val checkpointIntervalMillis = 10
-    val checkpointState =
-      new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock)
-    assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis)
-    checkpointState.advanceCheckpoint()
-    assert(checkpointState.checkpointClock.getTimeMillis() == (2 * checkpointIntervalMillis))
-
-    verify(currentClockMock, times(1)).getTimeMillis()
+    verify(receiverMock, never).setCheckpointer(anyString, meq(checkpointerMock))
   }
 
   test("shutdown should checkpoint if the reason is TERMINATE") {
     when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
 
-    val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+    val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
     recordProcessor.initialize(shardId)
     recordProcessor.shutdown(checkpointerMock, ShutdownReason.TERMINATE)
 
-    verify(receiverMock, times(1)).getLatestSeqNumToCheckpoint(shardId)
-    verify(checkpointerMock, times(1)).checkpoint(anyString)
+    verify(receiverMock, times(1)).removeCheckpointer(meq(shardId), meq(checkpointerMock))
   }
 
+
   test("shutdown should not checkpoint if the reason is something other than TERMINATE") {
     when(receiverMock.getLatestSeqNumToCheckpoint(shardId)).thenReturn(someSeqNum)
 
-    val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+    val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId)
     recordProcessor.initialize(shardId)
     recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE)
     recordProcessor.shutdown(checkpointerMock, null)
 
-    verify(checkpointerMock, never).checkpoint(anyString)
+    verify(receiverMock, times(2)).removeCheckpointer(meq(shardId),
+      meq[IRecordProcessorCheckpointer](null))
   }
 
   test("retry success on first attempt") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to