Repository: spark
Updated Branches:
  refs/heads/branch-1.6 ff3497542 -> 19ea30d82


[SPARK-11845][STREAMING][TEST] Added unit test to verify TrackStateRDD is correctly checkpointed

To make sure that all lineage is correctly truncated for TrackStateRDD when checkpointed.
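
For context, the property being verified can be sketched standalone as follows. This is an
illustrative example only (not part of this patch); it assumes a local SparkContext and a
temporary checkpoint directory, and all names in it are made up for the sketch:

    import org.apache.spark.{SparkConf, SparkContext}

    object LineageTruncationSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("sketch"))
        sc.setCheckpointDir(java.nio.file.Files.createTempDirectory("ckpt").toString)

        // Build an RDD with a non-trivial lineage.
        val rdd = sc.makeRDD(1 to 100, 4).map(_ + 1).filter(_ % 2 == 0)
        val parentBefore = rdd.dependencies.head.rdd

        rdd.checkpoint() // mark for reliable checkpointing
        rdd.count()      // materialize; this writes the checkpoint files

        // After materialization the lineage is truncated: the RDD no longer depends on
        // its original parent but on the checkpointed data.
        assert(rdd.isCheckpointed)
        assert(rdd.dependencies.head.rdd != parentBefore)
        println(rdd.toDebugString)

        sc.stop()
      }
    }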

Author: Tathagata Das <tathagata.das1...@gmail.com>

Closes #9831 from tdas/SPARK-11845.

(cherry picked from commit b2cecb80ece59a1c086d4ae7aeebef445a4e7299)
Signed-off-by: Andrew Or <and...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/19ea30d8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/19ea30d8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/19ea30d8

Branch: refs/heads/branch-1.6
Commit: 19ea30d829a569a9263c1cd205687054a7b03e30
Parents: ff34975
Author: Tathagata Das <tathagata.das1...@gmail.com>
Authored: Thu Nov 19 16:50:08 2015 -0800
Committer: Andrew Or <and...@databricks.com>
Committed: Thu Nov 19 16:50:15 2015 -0800

----------------------------------------------------------------------
 .../org/apache/spark/CheckpointSuite.scala      | 411 ++++++++++---------
 .../streaming/rdd/TrackStateRDDSuite.scala      |  60 ++-
 2 files changed, 267 insertions(+), 204 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/19ea30d8/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index 119e5fc..ab23326 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -21,17 +21,223 @@ import java.io.File
 
 import scala.reflect.ClassTag
 
+import org.apache.spark.CheckpointSuite._
 import org.apache.spark.rdd._
 import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
 import org.apache.spark.util.Utils
 
+trait RDDCheckpointTester { self: SparkFunSuite =>
+
+  protected val partitioner = new HashPartitioner(2)
+
+  private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect()
+
+  /** Implementations of this trait must implement this method */
+  protected def sparkContext: SparkContext
+
+  /**
+   * Test checkpointing of the RDD generated by the given operation. It tests whether the
+   * serialized size of the RDD is reduced after checkpointing or not. This function should be called
+   * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.).
+   *
+   * @param op an operation to run on the RDD
+   * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
+   * @param collectFunc a function for collecting the values in the RDD, in case there are
+   *                    non-comparable types like arrays that we want to convert to something
+   *                    that supports ==
+   */
+  protected def testRDD[U: ClassTag](
+      op: (RDD[Int]) => RDD[U],
+      reliableCheckpoint: Boolean,
+      collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
+    // Generate the final RDD using given RDD operation
+    val baseRDD = generateFatRDD()
+    val operatedRDD = op(baseRDD)
+    val parentRDD = operatedRDD.dependencies.headOption.orNull
+    val rddType = operatedRDD.getClass.getSimpleName
+    val numPartitions = operatedRDD.partitions.length
+
+    // Force initialization of all the data structures in RDDs
+    // Without this, serializing the RDD will give a wrong estimate of the size of the RDD
+    initializeRdd(operatedRDD)
+
+    val partitionsBeforeCheckpoint = operatedRDD.partitions
+
+    // Find serialized sizes before and after the checkpoint
+    logInfo("RDD before checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
+    val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+    checkpoint(operatedRDD, reliableCheckpoint)
+    val result = collectFunc(operatedRDD)
+    operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
+    val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+    logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
+
+    // Test whether the checkpoint file has been created
+    if (reliableCheckpoint) {
+      assert(
+        collectFunc(sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result)
+    }
+
+    // Test whether dependencies have been changed from its earlier parent RDD
+    assert(operatedRDD.dependencies.head.rdd != parentRDD)
+
+    // Test whether the partitions have been changed from its earlier partitions
+    assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList)
+
+    // Test whether the partitions have been changed to the new Hadoop partitions
+    assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList)
+
+    // Test whether the number of partitions is same as before
+    assert(operatedRDD.partitions.length === numPartitions)
+
+    // Test whether the data in the checkpointed RDD is same as original
+    assert(collectFunc(operatedRDD) === result)
+
+    // Test whether serialized size of the RDD has reduced.
+    logInfo("Size of " + rddType +
+      " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
+    assert(
+      rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+      "Size of " + rddType + " did not reduce after checkpointing " +
+        " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+    )
+  }
+
+  /**
+   * Test whether checkpointing of the parent of the generated RDD also
+   * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to their parent
+   * RDDs' partitions. So even if the parent RDD is checkpointed and its partitions changed,
+   * the generated RDD will remember the partitions and therefore potentially the whole lineage.
+   * This function should be called only on those RDDs whose partitions refer to the parent RDD's
+   * partitions (i.e., do not call it on simple RDDs like MappedRDD).
+   *
+   * @param op an operation to run on the RDD
+   * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
+   * @param collectFunc a function for collecting the values in the RDD, in case there are
+   *                    non-comparable types like arrays that we want to convert to something
+   *                    that supports ==
+   */
+  protected def testRDDPartitions[U: ClassTag](
+      op: (RDD[Int]) => RDD[U],
+      reliableCheckpoint: Boolean,
+      collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
+    // Generate the final RDD using given RDD operation
+    val baseRDD = generateFatRDD()
+    val operatedRDD = op(baseRDD)
+    val parentRDDs = operatedRDD.dependencies.map(_.rdd)
+    val rddType = operatedRDD.getClass.getSimpleName
+
+    // Force initialization of all the data structures in RDDs
+    // Without this, serializing the RDD will give a wrong estimate of the size of the RDD
+    initializeRdd(operatedRDD)
+
+    // Find serialized sizes before and after the checkpoint
+    logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
+    val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+    // checkpoint the parent RDD, not the generated one
+    parentRDDs.foreach { rdd =>
+      checkpoint(rdd, reliableCheckpoint)
+    }
+    val result = collectFunc(operatedRDD) // force checkpointing
+    operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
+    val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+    logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
+
+    // Test whether the data in the checkpointed RDD is same as original
+    assert(collectFunc(operatedRDD) === result)
+
+    // Test whether serialized size of the partitions has reduced
+    logInfo("Size of partitions of " + rddType +
+      " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]")
+    assert(
+      partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint,
+      "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" +
+        " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]"
+    )
+  }
+
+  /**
+   * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks
+   * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
+   */
+  private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
+    val rddSize = Utils.serialize(rdd).size
+    val rddCpDataSize = Utils.serialize(rdd.checkpointData).size
+    val rddPartitionSize = Utils.serialize(rdd.partitions).size
+    val rddDependenciesSize = Utils.serialize(rdd.dependencies).size
+
+    // Print detailed size, helps in debugging
+    logInfo("Serialized sizes of " + rdd +
+      ": RDD = " + rddSize +
+      ", RDD checkpoint data = " + rddCpDataSize +
+      ", RDD partitions = " + rddPartitionSize +
+      ", RDD dependencies = " + rddDependenciesSize
+    )
+    // this makes sure that serializing the RDD's checkpoint data does not
+    // serialize the whole RDD as well
+    assert(
+      rddSize > rddCpDataSize,
+      "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " +
+        "whole RDD with checkpoint data (" + rddSize + ")"
+    )
+    (rddSize - rddCpDataSize, rddPartitionSize)
+  }
+
+  /**
+   * Serialize and deserialize an object. This is useful to verify the object's
+   * contents after deserialization (e.g., the contents of an RDD split after
+   * it is sent to a slave along with a task)
+   */
+  protected def serializeDeserialize[T](obj: T): T = {
+    val bytes = Utils.serialize(obj)
+    Utils.deserialize[T](bytes)
+  }
+
+  /**
+   * Recursively force the initialization of all members of an RDD and its parents.
+   */
+  private def initializeRdd(rdd: RDD[_]): Unit = {
+    rdd.partitions // forces the initialization of the partitions
+    rdd.dependencies.map(_.rdd).foreach(initializeRdd)
+  }
+
+  /** Checkpoint the RDD either locally or reliably. */
+  protected def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = {
+    if (reliableCheckpoint) {
+      rdd.checkpoint()
+    } else {
+      rdd.localCheckpoint()
+    }
+  }
+
+  /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */
+  protected def runTest(name: String)(body: Boolean => Unit): Unit = {
+    test(name + " [reliable checkpoint]")(body(true))
+    test(name + " [local checkpoint]")(body(false))
+  }
+
+  /**
+   * Generate an RDD such that both the RDD and its partitions have large size.
+   */
+  protected def generateFatRDD(): RDD[Int] = {
+    new FatRDD(sparkContext.makeRDD(1 to 100, 4)).map(x => x)
+  }
+
+  /**
+   * Generate a pair RDD (with partitioner) such that both the RDD and its partitions
+   * have large size.
+   */
+  protected def generateFatPairRDD(): RDD[(Int, Int)] = {
+    new FatPairRDD(sparkContext.makeRDD(1 to 100, 4), partitioner).mapValues(x => x)
+  }
+}
+
 /**
  * Test suite for end-to-end checkpointing functionality.
  * This tests both reliable checkpoints and local checkpoints.
  */
-class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging {
+class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalSparkContext {
   private var checkpointDir: File = _
-  private val partitioner = new HashPartitioner(2)
 
   override def beforeEach(): Unit = {
     super.beforeEach()
@@ -46,6 +252,8 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
     Utils.deleteRecursively(checkpointDir)
   }
 
+  override def sparkContext: SparkContext = sc
+
   runTest("basic checkpointing") { reliableCheckpoint: Boolean =>
     val parCollection = sc.makeRDD(1 to 4)
     val flatMappedRDD = parCollection.flatMap(x => 1 to x)
@@ -250,204 +458,6 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
     assert(rdd.isCheckpointedAndMaterialized === true)
     assert(rdd.partitions.size === 0)
   }
-
-  // Utility test methods
-
-  /** Checkpoint the RDD either locally or reliably. */
-  private def checkpoint(rdd: RDD[_], reliableCheckpoint: Boolean): Unit = {
-    if (reliableCheckpoint) {
-      rdd.checkpoint()
-    } else {
-      rdd.localCheckpoint()
-    }
-  }
-
-  /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */
-  private def runTest(name: String)(body: Boolean => Unit): Unit = {
-    test(name + " [reliable checkpoint]")(body(true))
-    test(name + " [local checkpoint]")(body(false))
-  }
-
-  private def defaultCollectFunc[T](rdd: RDD[T]): Any = rdd.collect()
-
-  /**
-   * Test checkpointing of the RDD generated by the given operation. It tests whether the
-   * serialized size of the RDD is reduced after checkpointing or not. This function should be called
-   * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.).
-   *
-   * @param op an operation to run on the RDD
-   * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
-   * @param collectFunc a function for collecting the values in the RDD, in case there are
-   *   non-comparable types like arrays that we want to convert to something that supports ==
-   */
-  private def testRDD[U: ClassTag](
-      op: (RDD[Int]) => RDD[U],
-      reliableCheckpoint: Boolean,
-      collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
-    // Generate the final RDD using given RDD operation
-    val baseRDD = generateFatRDD()
-    val operatedRDD = op(baseRDD)
-    val parentRDD = operatedRDD.dependencies.headOption.orNull
-    val rddType = operatedRDD.getClass.getSimpleName
-    val numPartitions = operatedRDD.partitions.length
-
-    // Force initialization of all the data structures in RDDs
-    // Without this, serializing the RDD will give a wrong estimate of the size of the RDD
-    initializeRdd(operatedRDD)
-
-    val partitionsBeforeCheckpoint = operatedRDD.partitions
-
-    // Find serialized sizes before and after the checkpoint
-    logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
-    val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
-    checkpoint(operatedRDD, reliableCheckpoint)
-    val result = collectFunc(operatedRDD)
-    operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
-    val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
-    logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
-
-    // Test whether the checkpoint file has been created
-    if (reliableCheckpoint) {
-      assert(collectFunc(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result)
-    }
-
-    // Test whether dependencies have been changed from its earlier parent RDD
-    assert(operatedRDD.dependencies.head.rdd != parentRDD)
-
-    // Test whether the partitions have been changed from its earlier partitions
-    assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList)
-
-    // Test whether the partitions have been changed to the new Hadoop partitions
-    assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList)
-
-    // Test whether the number of partitions is same as before
-    assert(operatedRDD.partitions.length === numPartitions)
-
-    // Test whether the data in the checkpointed RDD is same as original
-    assert(collectFunc(operatedRDD) === result)
-
-    // Test whether serialized size of the RDD has reduced.
-    logInfo("Size of " + rddType +
-      " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
-    assert(
-      rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
-      "Size of " + rddType + " did not reduce after checkpointing " +
-        " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
-    )
-  }
-
-  /**
-   * Test whether checkpointing of the parent of the generated RDD also
-   * truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to their parent
-   * RDDs' partitions. So even if the parent RDD is checkpointed and its partitions changed,
-   * the generated RDD will remember the partitions and therefore potentially the whole lineage.
-   * This function should be called only on those RDDs whose partitions refer to the parent RDD's
-   * partitions (i.e., do not call it on simple RDDs like MappedRDD).
-   *
-   * @param op an operation to run on the RDD
-   * @param reliableCheckpoint if true, use reliable checkpoints, otherwise use local checkpoints
-   * @param collectFunc a function for collecting the values in the RDD, in case there are
-   *   non-comparable types like arrays that we want to convert to something that supports ==
-   */
-  private def testRDDPartitions[U: ClassTag](
-      op: (RDD[Int]) => RDD[U],
-      reliableCheckpoint: Boolean,
-      collectFunc: RDD[U] => Any = defaultCollectFunc[U] _): Unit = {
-    // Generate the final RDD using given RDD operation
-    val baseRDD = generateFatRDD()
-    val operatedRDD = op(baseRDD)
-    val parentRDDs = operatedRDD.dependencies.map(_.rdd)
-    val rddType = operatedRDD.getClass.getSimpleName
-
-    // Force initialization of all the data structures in RDDs
-    // Without this, serializing the RDD will give a wrong estimate of the size of the RDD
-    initializeRdd(operatedRDD)
-
-    // Find serialized sizes before and after the checkpoint
-    logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
-    val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
-    // checkpoint the parent RDD, not the generated one
-    parentRDDs.foreach { rdd =>
-      checkpoint(rdd, reliableCheckpoint)
-    }
-    val result = collectFunc(operatedRDD)  // force checkpointing
-    operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
-    val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
-    logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
-
-    // Test whether the data in the checkpointed RDD is same as original
-    assert(collectFunc(operatedRDD) === result)
-
-    // Test whether serialized size of the partitions has reduced
-    logInfo("Size of partitions of " + rddType +
-      " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]")
-    assert(
-      partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint,
-      "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" +
-        " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]"
-    )
-  }
-
-  /**
-   * Generate an RDD such that both the RDD and its partitions have large size.
-   */
-  private def generateFatRDD(): RDD[Int] = {
-    new FatRDD(sc.makeRDD(1 to 100, 4)).map(x => x)
-  }
-
-  /**
-   * Generate a pair RDD (with partitioner) such that both the RDD and its partitions
-   * have large size.
-   */
-  private def generateFatPairRDD(): RDD[(Int, Int)] = {
-    new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x)
-  }
-
-  /**
-   * Get serialized sizes of the RDD and its partitions, in order to test whether the size shrinks
-   * upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
-   */
-  private def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
-    val rddSize = Utils.serialize(rdd).size
-    val rddCpDataSize = Utils.serialize(rdd.checkpointData).size
-    val rddPartitionSize = Utils.serialize(rdd.partitions).size
-    val rddDependenciesSize = Utils.serialize(rdd.dependencies).size
-
-    // Print detailed size, helps in debugging
-    logInfo("Serialized sizes of " + rdd +
-      ": RDD = " + rddSize +
-      ", RDD checkpoint data = " + rddCpDataSize +
-      ", RDD partitions = " + rddPartitionSize +
-      ", RDD dependencies = " + rddDependenciesSize
-    )
-    // this makes sure that serializing the RDD's checkpoint data does not
-    // serialize the whole RDD as well
-    assert(
-      rddSize > rddCpDataSize,
-      "RDD's checkpoint data (" + rddCpDataSize  + ") is equal or larger than the " +
-        "whole RDD with checkpoint data (" + rddSize + ")"
-    )
-    (rddSize - rddCpDataSize, rddPartitionSize)
-  }
-
-  /**
-   * Serialize and deserialize an object. This is useful to verify the object's
-   * contents after deserialization (e.g., the contents of an RDD split after
-   * it is sent to a slave along with a task)
-   */
-  private def serializeDeserialize[T](obj: T): T = {
-    val bytes = Utils.serialize(obj)
-    Utils.deserialize[T](bytes)
-  }
-
-  /**
-   * Recursively force the initialization of all members of an RDD and its parents.
-   */
-  private def initializeRdd(rdd: RDD[_]): Unit = {
-    rdd.partitions // forces the
-    rdd.dependencies.map(_.rdd).foreach(initializeRdd)
-  }
-
 }
 
 /** RDD partition that has large serialized size. */
@@ -494,5 +504,4 @@ object CheckpointSuite {
       part
     ).asInstanceOf[RDD[(K, Array[Iterable[V]])]]
   }
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/19ea30d8/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
----------------------------------------------------------------------
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
index 19ef5a1..0feb3af 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/TrackStateRDDSuite.scala
@@ -17,31 +17,40 @@
 
 package org.apache.spark.streaming.rdd
 
+import java.io.File
+
 import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
-import org.apache.spark.streaming.{Time, State}
-import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.streaming.{State, Time}
+import org.apache.spark.util.Utils
 
-class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
+class TrackStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll {
 
   private var sc: SparkContext = null
+  private var checkpointDir: File = _
 
   override def beforeAll(): Unit = {
     sc = new SparkContext(
       new SparkConf().setMaster("local").setAppName("TrackStateRDDSuite"))
+    checkpointDir = Utils.createTempDir()
+    sc.setCheckpointDir(checkpointDir.toString)
   }
 
   override def afterAll(): Unit = {
     if (sc != null) {
       sc.stop()
     }
+    Utils.deleteRecursively(checkpointDir)
   }
 
+  override def sparkContext: SparkContext = sc
+
   test("creation from pair RDD") {
     val data = Seq((1, "1"), (2, "2"), (3, "3"))
     val partitioner = new HashPartitioner(10)
@@ -278,6 +287,51 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
       rdd7, Seq(("k3", 2)), Set())
   }
 
+  test("checkpointing") {
+    /**
+     * This tests whether the TrackStateRDD correctly truncates any references to its parent RDDs -
+     * the data RDD and the parent TrackStateRDD.
+     */
+    def rddCollectFunc(rdd: RDD[TrackStateRDDRecord[Int, Int, Int]])
+      : Set[(List[(Int, Int, Long)], List[Int])] = {
+      rdd.map { record => (record.stateMap.getAll().toList, record.emittedRecords.toList) }
+         .collect.toSet
+    }
+
+    /** Generate TrackStateRDD with data RDD having a long lineage */
+    def makeStateRDDWithLongLineageDataRDD(longLineageRDD: RDD[Int])
+      : TrackStateRDD[Int, Int, Int, Int] = {
+      TrackStateRDD.createFromPairRDD(longLineageRDD.map { _ -> 1}, partitioner, Time(0))
+    }
+
+    testRDD(
+      makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _)
+    testRDDPartitions(
+      makeStateRDDWithLongLineageDataRDD, reliableCheckpoint = true, rddCollectFunc _)
+
+    /** Generate TrackStateRDD with parent state RDD having a long lineage */
+    def makeStateRDDWithLongLineageParentStateRDD(
+        longLineageRDD: RDD[Int]): TrackStateRDD[Int, Int, Int, Int] = {
+
+      // Create a TrackStateRDD that has a long lineage using the data RDD with a long lineage
+      val stateRDDWithLongLineage = makeStateRDDWithLongLineageDataRDD(longLineageRDD)
+
+      // Create a new TrackStateRDD, with the long lineage TrackStateRDD as the parent
+      new TrackStateRDD[Int, Int, Int, Int](
+        stateRDDWithLongLineage,
+        stateRDDWithLongLineage.sparkContext.emptyRDD[(Int, Int)].partitionBy(partitioner),
+        (time: Time, key: Int, value: Option[Int], state: State[Int]) => None,
+        Time(10),
+        None
+      )
+    }
+
+    testRDD(
+      makeStateRDDWithLongLineageParentStateRDD, reliableCheckpoint = true, rddCollectFunc _)
+    testRDDPartitions(
+      makeStateRDDWithLongLineageParentStateRDD, reliableCheckpoint = true, rddCollectFunc _)
+  }
+
   /** Assert whether the `trackStateByKey` operation generates expected results */
   private def assertOperation[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
       testStateRDD: TrackStateRDD[K, V, S, T],


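A brief usage note on the refactoring above: the checkpoint-testing utilities now live in the
reusable RDDCheckpointTester trait rather than being private to CheckpointSuite. A hypothetical
suite (illustrative only, not part of this commit; it assumes placement under core's test sources
so that SparkFunSuite, LocalSparkContext and the existing Fat*RDD helpers are visible) could
reuse them like this:

    package org.apache.spark

    import org.apache.spark.util.Utils

    // Hypothetical example; the suite and test names below are illustrative only.
    class ExampleCheckpointSuite extends SparkFunSuite
      with RDDCheckpointTester with LocalSparkContext {

      override def sparkContext: SparkContext = sc

      // runTest generates two tests: one using reliable checkpoints, one using local checkpoints.
      runTest("map lineage is truncated") { reliableCheckpoint: Boolean =>
        sc = new SparkContext("local", "ExampleCheckpointSuite")
        sc.setCheckpointDir(Utils.createTempDir().toString)
        // testRDD checkpoints the RDD produced by the operation and asserts that its
        // serialized size shrinks and its dependencies are replaced.
        testRDD(_.map(x => x), reliableCheckpoint)
      }
    }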