Repository: spark
Updated Branches:
  refs/heads/master 2e0f3579f -> 7394e7ade


[SPARK-7120] [SPARK-7121] Closure cleaner nesting + documentation + tests

Note: ~600 lines of this are test code, and ~100 lines are documentation.

**[SPARK-7121]** ClosureCleaner does not handle nested closures properly. For 
instance, in SparkContext, I tried to do the following:
```
def scope[T](body: => T): T = body // no-op
def myCoolMethod(path: String): RDD[String] = scope {
  parallelize(1 to 10).map { _ => path }
}
```
and I got an exception complaining that SparkContext is not serializable. The 
issue here is that the inner closure is getting its path from the outer closure 
(the scope), but the outer closure references the SparkContext object itself to 
get the `parallelize` method.

Note, however, that the inner closure doesn't actually need the SparkContext; 
it just needs a field from the outer closure. If we modify ClosureCleaner to 
clean the outer closure recursively using only the fields accessed by the inner 
closure, then we can serialize the inner closure.
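
The idea can be sketched outside of Spark with hand-written stand-ins for the 
compiler-generated closure classes. This is only an illustration under stated 
assumptions, not the ClosureCleaner implementation: `Context`, `OuterClosure`, 
`InnerClosure`, `NestedClosureSketch` and the path used below are hypothetical 
names, and the field to keep (`path`) is hard-coded here rather than discovered 
by bytecode analysis.
```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

// Hypothetical stand-in for SparkContext: deliberately NOT Serializable.
class Context

// Models the outer closure (the `scope` body): a parent pointer plus the captured `path`.
class OuterClosure(val parent: Context, val path: String) extends Serializable

// Models the inner closure (the `map` function): it reads `path` through its parent pointer.
class InnerClosure(var parent: OuterClosure) extends (Int => String) with Serializable {
  def apply(i: Int): String = parent.path
}

object NestedClosureSketch {
  // True if plain Java serialization accepts the object graph rooted at `obj`.
  def isSerializable(obj: AnyRef): Boolean =
    try {
      new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(obj)
      true
    } catch {
      case _: NotSerializableException => false
    }

  // Clone the enclosing closure, copy over only the field the inner closure
  // reads (`path`), leave the clone's parent pointer null, then repoint `inner`.
  // The original OuterClosure is left untouched, since other code may still use it.
  def clean(inner: InnerClosure): Unit = {
    val original = inner.parent
    inner.parent = new OuterClosure(null, original.path)
  }
}
```
Before `clean` is called, serializing an `InnerClosure` drags in the `Context` 
through the outer closure and fails; afterwards only `path` is reachable, so 
serialization succeeds and `apply` still returns the same value.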

**[SPARK-7120]** In addition, ClosureCleaner.scala is one of the least 
understood files, partly because it is very low level and was written a long 
time ago. This patch attempts to change that by adding the missing documentation.

This is blocking my effort on a separate task #5729.

Author: Andrew Or <and...@databricks.com>

Closes #5685 from andrewor14/closure-cleaner and squashes the following commits:

cd46230 [Andrew Or] Revert a small change that affected streaming
0bbe77f [Andrew Or] Fix style
ea874bc [Andrew Or] Fix tests
26c5072 [Andrew Or] Address comments
16fbcfd [Andrew Or] Merge branch 'master' of github.com:apache/spark into 
closure-cleaner
26c7aba [Andrew Or] Revert "In sc.runJob, actually clean the inner closure"
6f75784 [Andrew Or] Revert "Guard against NPE if CC is used outside of an 
application"
e909a42 [Andrew Or] Guard against NPE if CC is used outside of an application
3998168 [Andrew Or] In sc.runJob, actually clean the inner closure
9187066 [Andrew Or] Merge branch 'master' of github.com:apache/spark into 
closure-cleaner
d889950 [Andrew Or] Revert "Bypass SerializationDebugger for now (SPARK-7180)"
9419efe [Andrew Or] Bypass SerializationDebugger for now (SPARK-7180)
6d4d3f1 [Andrew Or] Fix scala style?
4aab379 [Andrew Or] Merge branch 'master' of github.com:apache/spark into 
closure-cleaner
e45e904 [Andrew Or] More minor updates (wording, renaming etc.)
8b71cdb [Andrew Or] Update a few comments
eb127e5 [Andrew Or] Use private method tester for a few things
a3aa465 [Andrew Or] Add more tests for individual closure cleaner operations
e672170 [Andrew Or] Guard against potential infinite cycles in method visitor
6d36f38 [Andrew Or] Fix closure cleaner visibility
2106f12 [Andrew Or] Merge branch 'master' of github.com:apache/spark into 
closure-cleaner
263593d [Andrew Or] Finalize tests
06fd668 [Andrew Or] Make closure cleaning idempotent
a4866e3 [Andrew Or] Add tests (still WIP)
438c68f [Andrew Or] Minor changes
2390a60 [Andrew Or] Feature flag this new behavior
86f7823 [Andrew Or] Implement transitive cleaning + add missing documentation


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7394e7ad
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7394e7ad
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7394e7ad

Branch: refs/heads/master
Commit: 7394e7adeb03df159978f1d10061d9ec6a913968
Parents: 2e0f357
Author: Andrew Or <and...@databricks.com>
Authored: Fri May 1 23:57:58 2015 -0700
Committer: Patrick Wendell <patr...@databricks.com>
Committed: Fri May 1 23:57:58 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/util/ClosureCleaner.scala  | 305 ++++++++--
 .../apache/spark/util/ClosureCleanerSuite.scala |  13 +-
 .../spark/util/ClosureCleanerSuite2.scala       | 571 +++++++++++++++++++
 3 files changed, 831 insertions(+), 58 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7394e7ad/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala 
b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index e3f52f6..4ac0382 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -19,17 +19,20 @@ package org.apache.spark.util
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 
-import scala.collection.mutable.Map
-import scala.collection.mutable.Set
+import scala.collection.mutable.{Map, Set}
 
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.{ClassReader, 
ClassVisitor, MethodVisitor, Type}
 import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._
 
 import org.apache.spark.{Logging, SparkEnv, SparkException}
 
+/**
+ * A cleaner that renders closures serializable if this can be done safely.
+ */
 private[spark] object ClosureCleaner extends Logging {
+
   // Get an ASM class reader for a given class from the JAR that loaded it
-  private def getClassReader(cls: Class[_]): ClassReader = {
+  private[util] def getClassReader(cls: Class[_]): ClassReader = {
     // Copy data over, before delegating to ClassReader - else we can run out 
of open file handles.
     val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
     val resourceStream = cls.getResourceAsStream(className)
@@ -55,10 +58,14 @@ private[spark] object ClosureCleaner extends Logging {
   private def getOuterClasses(obj: AnyRef): List[Class[_]] = {
     for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
       f.setAccessible(true)
-      if (isClosure(f.getType)) {
-        return f.getType :: getOuterClasses(f.get(obj))
-      } else {
-        return f.getType :: Nil // Stop at the first $outer that is not a 
closure
+      val outer = f.get(obj)
+      // The outer pointer may be null if we have cleaned this closure before
+      if (outer != null) {
+        if (isClosure(f.getType)) {
+          return f.getType :: getOuterClasses(outer)
+        } else {
+          return f.getType :: Nil // Stop at the first $outer that is not a 
closure
+        }
       }
     }
     Nil
@@ -68,16 +75,23 @@ private[spark] object ClosureCleaner extends Logging {
   private def getOuterObjects(obj: AnyRef): List[AnyRef] = {
     for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") {
       f.setAccessible(true)
-      if (isClosure(f.getType)) {
-        return f.get(obj) :: getOuterObjects(f.get(obj))
-      } else {
-        return f.get(obj) :: Nil // Stop at the first $outer that is not a 
closure
+      val outer = f.get(obj)
+      // The outer pointer may be null if we have cleaned this closure before
+      if (outer != null) {
+        if (isClosure(f.getType)) {
+          return outer :: getOuterObjects(outer)
+        } else {
+          return outer :: Nil // Stop at the first $outer that is not a closure
+        }
       }
     }
     Nil
   }
 
-  private def getInnerClasses(obj: AnyRef): List[Class[_]] = {
+  /**
+   * Return a list of classes that represent closures enclosed in the given 
closure object.
+   */
+  private def getInnerClosureClasses(obj: AnyRef): List[Class[_]] = {
     val seen = Set[Class[_]](obj.getClass)
     var stack = List[Class[_]](obj.getClass)
     while (!stack.isEmpty) {
@@ -90,7 +104,7 @@ private[spark] object ClosureCleaner extends Logging {
         stack = cls :: stack
       }
     }
-    return (seen - obj.getClass).toList
+    (seen - obj.getClass).toList
   }
 
   private def createNullValue(cls: Class[_]): AnyRef = {
@@ -101,21 +115,124 @@ private[spark] object ClosureCleaner extends Logging {
     }
   }
 
-  def clean(func: AnyRef, checkSerializable: Boolean = true) {
+  /**
+   * Clean the given closure in place.
+   *
+   * More specifically, this renders the given closure serializable as long as 
it does not
+   * explicitly reference unserializable objects.
+   *
+   * @param closure the closure to clean
+   * @param checkSerializable whether to verify that the closure is 
serializable after cleaning
+   * @param cleanTransitively whether to clean enclosing closures transitively
+   */
+  def clean(
+      closure: AnyRef,
+      checkSerializable: Boolean = true,
+      cleanTransitively: Boolean = true): Unit = {
+    clean(closure, checkSerializable, cleanTransitively, Map.empty)
+  }
+
+  /**
+   * Helper method to clean the given closure in place.
+   *
+   * The mechanism is to traverse the hierarchy of enclosing closures and null 
out any
+   * references along the way that are not actually used by the starting 
closure, but are
+   * nevertheless included in the compiled anonymous classes. Note that it is 
unsafe to
+   * simply mutate the enclosing closures in place, as other code paths may 
depend on them.
+   * Instead, we clone each enclosing closure and set the parent pointers 
accordingly.
+   *
+   * By default, closures are cleaned transitively. This means we detect 
whether enclosing
+   * objects are actually referenced by the starting one, either directly or 
transitively,
+   * and, if not, sever these closures from the hierarchy. In other words, in 
addition to
+   * nulling out unused field references, we also null out any parent pointers 
that refer
+   * to enclosing objects not actually needed by the starting closure. We 
determine
+   * transitivity by tracing through the tree of all methods ultimately 
invoked by the
+   * inner closure and record all the fields referenced in the process.
+   *
+   * For instance, transitive cleaning is necessary in the following scenario:
+   *
+   *   class SomethingNotSerializable {
+   *     def someValue = 1
+   *     def scope(name: String)(body: => Unit) = body
+   *     def someMethod(): Unit = scope("one") {
+   *       def x = someValue
+   *       def y = 2
+   *       scope("two") { println(y + 1) }
+   *     }
+   *   }
+   *
+   * In this example, scope "two" is not serializable because it references 
scope "one", which
+   * references SomethingNotSerializable. Note that, however, the body of 
scope "two" does not
+   * actually depend on SomethingNotSerializable. This means we can safely 
null out the parent
+   * pointer of a cloned scope "one" and set it as the parent of scope "two", 
such that scope "two"
+   * no longer references SomethingNotSerializable transitively.
+   *
+   * @param func the starting closure to clean
+   * @param checkSerializable whether to verify that the closure is 
serializable after cleaning
+   * @param cleanTransitively whether to clean enclosing closures transitively
+   * @param accessedFields a map from a class to a set of its fields that are 
accessed by
+   *                       the starting closure
+   */
+  private def clean(
+      func: AnyRef,
+      checkSerializable: Boolean,
+      cleanTransitively: Boolean,
+      accessedFields: Map[Class[_], Set[String]]): Unit = {
+
+    // TODO: clean all inner closures first. This requires us to find the 
inner objects.
     // TODO: cache outerClasses / innerClasses / accessedFields
+
+    if (func == null) {
+      return
+    }
+
+    logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++")
+
+    // A list of classes that represents closures enclosed in the given one
+    val innerClasses = getInnerClosureClasses(func)
+
+    // A list of enclosing objects and their respective classes, from 
innermost to outermost
+    // An outer object at a given index is of type outer class at the same 
index
     val outerClasses = getOuterClasses(func)
-    val innerClasses = getInnerClasses(func)
     val outerObjects = getOuterObjects(func)
 
-    val accessedFields = Map[Class[_], Set[String]]()
-    
+    // For logging purposes only
+    val declaredFields = func.getClass.getDeclaredFields
+    val declaredMethods = func.getClass.getDeclaredMethods
+
+    logDebug(" + declared fields: " + declaredFields.size)
+    declaredFields.foreach { f => logDebug("     " + f) }
+    logDebug(" + declared methods: " + declaredMethods.size)
+    declaredMethods.foreach { m => logDebug("     " + m) }
+    logDebug(" + inner classes: " + innerClasses.size)
+    innerClasses.foreach { c => logDebug("     " + c.getName) }
+    logDebug(" + outer classes: " + outerClasses.size)
+    outerClasses.foreach { c => logDebug("     " + c.getName) }
+    logDebug(" + outer objects: " + outerObjects.size)
+    outerObjects.foreach { o => logDebug("     " + o) }
+
+    // Fail fast if we detect return statements in closures
     getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)
-    
-    for (cls <- outerClasses)
-      accessedFields(cls) = Set[String]()
-    for (cls <- func.getClass :: innerClasses)
-      getClassReader(cls).accept(new FieldAccessFinder(accessedFields), 0)
-    // logInfo("accessedFields: " + accessedFields)
+
+    // If accessed fields is not populated yet, we assume that
+    // the closure we are trying to clean is the starting one
+    if (accessedFields.isEmpty) {
+      logDebug(s" + populating accessed fields because this is the starting 
closure")
+      // Initialize accessed fields with the outer classes first
+      // This step is needed to associate the fields to the correct classes 
later
+      for (cls <- outerClasses) {
+        accessedFields(cls) = Set[String]()
+      }
+      // Populate accessed fields by visiting all fields and methods accessed 
by this and
+      // all of its inner closures. If transitive cleaning is enabled, this 
may recursively
+      // visit methods that belong to other classes in search of transitively 
referenced fields.
+      for (cls <- func.getClass :: innerClasses) {
+        getClassReader(cls).accept(new FieldAccessFinder(accessedFields, 
cleanTransitively), 0)
+      }
+    }
+
+    logDebug(s" + fields accessed by starting closure: " + accessedFields.size)
+    accessedFields.foreach { f => logDebug("     " + f) }
 
     val inInterpreter = {
       try {
@@ -126,34 +243,68 @@ private[spark] object ClosureCleaner extends Logging {
       }
     }
 
+    // List of outer (class, object) pairs, ordered from outermost to innermost
+    // Note that all outer objects but the outermost one (first one in this 
list) must be closures
     var outerPairs: List[(Class[_], AnyRef)] = (outerClasses zip 
outerObjects).reverse
-    var outer: AnyRef = null
+    var parent: AnyRef = null
     if (outerPairs.size > 0 && !isClosure(outerPairs.head._1)) {
       // The closure is ultimately nested inside a class; keep the object of 
that
       // class without cloning it since we don't want to clone the user's 
objects.
-      outer = outerPairs.head._2
+      // Note that we still need to keep around the outermost object itself 
because
+      // we need it to clone its child closure later (see below).
+      logDebug(s" + outermost object is not a closure, so do not clone it: 
${outerPairs.head}")
+      parent = outerPairs.head._2 // e.g. SparkContext
       outerPairs = outerPairs.tail
+    } else if (outerPairs.size > 0) {
+      logDebug(s" + outermost object is a closure, so we just keep it: 
${outerPairs.head}")
+    } else {
+      logDebug(" + there are no enclosing objects!")
     }
+
     // Clone the closure objects themselves, nulling out any fields that are 
not
     // used in the closure we're working on or any of its inner closures.
     for ((cls, obj) <- outerPairs) {
-      outer = instantiateClass(cls, outer, inInterpreter)
+      logDebug(s" + cloning the object $obj of class ${cls.getName}")
+      // We null out these unused references by cloning each object and then 
filling in all
+      // required fields from the original object. We need the parent here 
because the Java
+      // language specification requires the first constructor parameter of 
any closure to be
+      // its enclosing object.
+      val clone = instantiateClass(cls, parent, inInterpreter)
       for (fieldName <- accessedFields(cls)) {
         val field = cls.getDeclaredField(fieldName)
         field.setAccessible(true)
         val value = field.get(obj)
-        // logInfo("1: Setting " + fieldName + " on " + cls + " to " + value);
-        field.set(outer, value)
+        field.set(clone, value)
+      }
+      // If transitive cleaning is enabled, we recursively clean any enclosing 
closure using
+      // the already populated accessed fields map of the starting closure
+      if (cleanTransitively && isClosure(clone.getClass)) {
+        logDebug(s" + cleaning cloned closure $clone recursively 
(${cls.getName})")
+        // No need to check serializable here for the outer closures because 
we're
+        // only interested in the serializability of the starting closure
+        clean(clone, checkSerializable = false, cleanTransitively, 
accessedFields)
       }
+      parent = clone
     }
 
-    if (outer != null) {
-      // logInfo("2: Setting $outer on " + func.getClass + " to " + outer);
+    // Update the parent pointer ($outer) of this closure
+    if (parent != null) {
       val field = func.getClass.getDeclaredField("$outer")
       field.setAccessible(true)
-      field.set(func, outer)
+      // If the starting closure doesn't actually need our enclosing object, 
then just null it out
+      if (accessedFields.contains(func.getClass) &&
+        !accessedFields(func.getClass).contains("$outer")) {
+        logDebug(s" + the starting closure doesn't actually need $parent, so 
we null it out")
+        field.set(func, null)
+      } else {
+        // Update this closure's parent pointer to point to our enclosing 
object,
+        // which could either be a cloned closure or the original user object
+        field.set(func, parent)
+      }
     }
-    
+
+    logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned 
+++")
+
     if (checkSerializable) {
       ensureSerializable(func)
     }
@@ -167,15 +318,17 @@ private[spark] object ClosureCleaner extends Logging {
     }
   }
 
-  private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: 
Boolean): AnyRef = {
-    // logInfo("Creating a " + cls + " with outer = " + outer)
+  private def instantiateClass(
+      cls: Class[_],
+      enclosingObject: AnyRef,
+      inInterpreter: Boolean): AnyRef = {
     if (!inInterpreter) {
       // This is a bona fide closure class, whose constructor has no effects
       // other than to set its fields, so use its constructor
       val cons = cls.getConstructors()(0)
       val params = cons.getParameterTypes.map(createNullValue).toArray
-      if (outer != null) {
-        params(0) = outer // First param is always outer object
+      if (enclosingObject != null) {
+        params(0) = enclosingObject // First param is always enclosing object
       }
       return cons.newInstance(params: _*).asInstanceOf[AnyRef]
     } else {
@@ -184,19 +337,17 @@ private[spark] object ClosureCleaner extends Logging {
       val parentCtor = classOf[java.lang.Object].getDeclaredConstructor()
       val newCtor = rf.newConstructorForSerialization(cls, parentCtor)
       val obj = newCtor.newInstance().asInstanceOf[AnyRef]
-      if (outer != null) {
-        // logInfo("3: Setting $outer on " + cls + " to " + outer);
+      if (enclosingObject != null) {
         val field = cls.getDeclaredField("$outer")
         field.setAccessible(true)
-        field.set(obj, outer)
+        field.set(obj, enclosingObject)
       }
       obj
     }
   }
 }
 
-private[spark]
-class ReturnStatementFinder extends ClassVisitor(ASM4) {
+private class ReturnStatementFinder extends ClassVisitor(ASM4) {
   override def visitMethod(access: Int, name: String, desc: String,
       sig: String, exceptions: Array[String]): MethodVisitor = {
     if (name.contains("apply")) {
@@ -213,26 +364,65 @@ class ReturnStatementFinder extends ClassVisitor(ASM4) {
   }
 }
 
-private[spark]
-class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends 
ClassVisitor(ASM4) {
-  override def visitMethod(access: Int, name: String, desc: String,
-      sig: String, exceptions: Array[String]): MethodVisitor = {
+/** Helper class to identify a method. */
+private case class MethodIdentifier[T](cls: Class[T], name: String, desc: 
String)
+
+/**
+ * Find the fields accessed by a given class.
+ *
+ * The resulting fields are stored in the mutable map passed in through the 
constructor.
+ * This map is assumed to have its keys already populated with the classes of 
interest.
+ *
+ * @param fields the mutable map that stores the fields to return
+ * @param findTransitively if true, find fields indirectly referenced through 
method calls
+ * @param specificMethod if not empty, visit only this specific method
+ * @param visitedMethods a set of visited methods to avoid cycles
+ */
+private[util] class FieldAccessFinder(
+    fields: Map[Class[_], Set[String]],
+    findTransitively: Boolean,
+    specificMethod: Option[MethodIdentifier[_]] = None,
+    visitedMethods: Set[MethodIdentifier[_]] = Set.empty)
+  extends ClassVisitor(ASM4) {
+
+  override def visitMethod(
+      access: Int,
+      name: String,
+      desc: String,
+      sig: String,
+      exceptions: Array[String]): MethodVisitor = {
+
+    // If we are told to visit only a certain method and this is not the one, 
ignore it
+    if (specificMethod.isDefined &&
+        (specificMethod.get.name != name || specificMethod.get.desc != desc)) {
+      return null
+    }
+
     new MethodVisitor(ASM4) {
       override def visitFieldInsn(op: Int, owner: String, name: String, desc: 
String) {
         if (op == GETFIELD) {
-          for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
-            output(cl) += name
+          for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
+            fields(cl) += name
           }
         }
       }
 
-      override def visitMethodInsn(op: Int, owner: String, name: String,
-          desc: String) {
-        // Check for calls a getter method for a variable in an interpreter 
wrapper object.
-        // This means that the corresponding field will be accessed, so we 
should save it.
-        if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && 
!name.endsWith("$outer")) {
-          for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
-            output(cl) += name
+      override def visitMethodInsn(op: Int, owner: String, name: String, desc: 
String) {
+        for (cl <- fields.keys if cl.getName == owner.replace('/', '.')) {
+          // Check for calls a getter method for a variable in an interpreter 
wrapper object.
+          // This means that the corresponding field will be accessed, so we 
should save it.
+          if (op == INVOKEVIRTUAL && owner.endsWith("$iwC") && 
!name.endsWith("$outer")) {
+            fields(cl) += name
+          }
+          // Optionally visit other methods to find fields that are 
transitively referenced
+          if (findTransitively) {
+            val m = MethodIdentifier(cl, name, desc)
+            if (!visitedMethods.contains(m)) {
+              // Keep track of visited methods to avoid potential infinite 
cycles
+              visitedMethods += m
+              ClosureCleaner.getClassReader(cl).accept(
+                new FieldAccessFinder(fields, findTransitively, Some(m), 
visitedMethods), 0)
+            }
           }
         }
       }
@@ -240,9 +430,14 @@ class FieldAccessFinder(output: Map[Class[_], 
Set[String]]) extends ClassVisitor
   }
 }
 
-private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends 
ClassVisitor(ASM4) {
+private class InnerClosureFinder(output: Set[Class[_]]) extends 
ClassVisitor(ASM4) {
   var myName: String = null
 
+  // TODO: Recursively find inner closures that we indirectly reference, e.g.
+  //   val closure1 = () => { () => 1 }
+  //   val closure2 = () => { (1 to 5).map(closure1) }
+  // The second closure technically has two inner closures, but this finder 
only finds one
+
   override def visit(version: Int, access: Int, name: String, sig: String,
       superName: String, interfaces: Array[String]) {
     myName = name

http://git-wip-us.apache.org/repos/asf/spark/blob/7394e7ad/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala 
b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index c471627..ff1bfe0 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -50,7 +50,7 @@ class ClosureCleanerSuite extends FunSuite {
     val obj = new TestClassWithNesting(1)
     assert(obj.run() === 96) // 4 * (1+2+3+4) + 4 * (1+2+3+4) + 16 * 1
   }
-  
+
   test("toplevel return statements in closures are identified at cleaning 
time") {
     val ex = intercept[SparkException] {
       TestObjectWithBogusReturns.run()
@@ -61,13 +61,20 @@ class ClosureCleanerSuite extends FunSuite {
 
   test("return statements from named functions nested in closures don't raise 
exceptions") {
     val result = TestObjectWithNestedReturns.run()
-    assert(result == 1)
+    assert(result === 1)
   }
 }
 
 // A non-serializable class we create in closures to make sure that we aren't
 // keeping references to unneeded variables from our outer closures.
-class NonSerializable {}
+class NonSerializable(val id: Int = -1) {
+  override def equals(other: Any): Boolean = {
+    other match {
+      case o: NonSerializable => id == o.id
+      case _ => false
+    }
+  }
+}
 
 object TestObject {
   def run(): Int = {

http://git-wip-us.apache.org/repos/asf/spark/blob/7394e7ad/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
----------------------------------------------------------------------
diff --git 
a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala 
b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
new file mode 100644
index 0000000..5945679
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
@@ -0,0 +1,571 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.NotSerializableException
+
+import scala.collection.mutable
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite, PrivateMethodTester}
+
+import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.serializer.SerializerInstance
+
+/**
+ * Another test suite for the closure cleaner that is finer-grained.
+ * For tests involving end-to-end Spark jobs, see {{ClosureCleanerSuite}}.
+ */
+class ClosureCleanerSuite2 extends FunSuite with BeforeAndAfterAll with 
PrivateMethodTester {
+
+  // Start a SparkContext so that the closure serializer is accessible
+  // We do not actually use this explicitly otherwise
+  private var sc: SparkContext = null
+  private var closureSerializer: SerializerInstance = null
+
+  override def beforeAll(): Unit = {
+    sc = new SparkContext("local", "test")
+    closureSerializer = sc.env.closureSerializer.newInstance()
+  }
+
+  override def afterAll(): Unit = {
+    sc.stop()
+    sc = null
+    closureSerializer = null
+  }
+
+  // Some fields and methods to reference in inner closures later
+  private val someSerializableValue = 1
+  private val someNonSerializableValue = new NonSerializable
+  private def someSerializableMethod() = 1
+  private def someNonSerializableMethod() = new NonSerializable
+
+  /** Assert that the given closure is serializable (or not). */
+  private def assertSerializable(closure: AnyRef, serializable: Boolean): Unit 
= {
+    if (serializable) {
+      closureSerializer.serialize(closure)
+    } else {
+      intercept[NotSerializableException] {
+        closureSerializer.serialize(closure)
+      }
+    }
+  }
+
+  /**
+   * Helper method for testing whether closure cleaning works as expected.
+   * This cleans the given closure twice, with and without transitive cleaning.
+   *
+   * @param closure closure to test cleaning with
+   * @param serializableBefore if true, verify that the closure is serializable
+   *                           before cleaning, otherwise assert that it is not
+   * @param serializableAfter if true, assert that the closure is serializable
+   *                          after cleaning otherwise assert that it is not
+   */
+  private def verifyCleaning(
+      closure: AnyRef,
+      serializableBefore: Boolean,
+      serializableAfter: Boolean): Unit = {
+    verifyCleaning(closure, serializableBefore, serializableAfter, transitive 
= true)
+    verifyCleaning(closure, serializableBefore, serializableAfter, transitive 
= false)
+  }
+
+  /** Helper method for testing whether closure cleaning works as expected. */
+  private def verifyCleaning(
+      closure: AnyRef,
+      serializableBefore: Boolean,
+      serializableAfter: Boolean,
+      transitive: Boolean): Unit = {
+    assertSerializable(closure, serializableBefore)
+    // If the resulting closure is not serializable even after
+    // cleaning, we expect ClosureCleaner to throw a SparkException
+    if (serializableAfter) {
+      ClosureCleaner.clean(closure, checkSerializable = true, transitive)
+    } else {
+      intercept[SparkException] {
+        ClosureCleaner.clean(closure, checkSerializable = true, transitive)
+      }
+    }
+    assertSerializable(closure, serializableAfter)
+  }
+
+  /**
+   * Return the fields accessed by the given closure by class.
+   * This also optionally finds the fields transitively referenced through 
methods invocations.
+   */
+  private def findAccessedFields(
+      closure: AnyRef,
+      outerClasses: Seq[Class[_]],
+      findTransitively: Boolean): Map[Class[_], Set[String]] = {
+    val fields = new mutable.HashMap[Class[_], mutable.Set[String]]
+    outerClasses.foreach { c => fields(c) = new mutable.HashSet[String] }
+    ClosureCleaner.getClassReader(closure.getClass)
+      .accept(new FieldAccessFinder(fields, findTransitively), 0)
+    fields.mapValues(_.toSet).toMap
+  }
+
+  // Accessors for private methods
+  private val _isClosure = PrivateMethod[Boolean]('isClosure)
+  private val _getInnerClosureClasses = 
PrivateMethod[List[Class[_]]]('getInnerClosureClasses)
+  private val _getOuterClasses = 
PrivateMethod[List[Class[_]]]('getOuterClasses)
+  private val _getOuterObjects = PrivateMethod[List[AnyRef]]('getOuterObjects)
+
+  private def isClosure(obj: AnyRef): Boolean = {
+    ClosureCleaner invokePrivate _isClosure(obj)
+  }
+
+  private def getInnerClosureClasses(closure: AnyRef): List[Class[_]] = {
+    ClosureCleaner invokePrivate _getInnerClosureClasses(closure)
+  }
+
+  private def getOuterClasses(closure: AnyRef): List[Class[_]] = {
+    ClosureCleaner invokePrivate _getOuterClasses(closure)
+  }
+
+  private def getOuterObjects(closure: AnyRef): List[AnyRef] = {
+    ClosureCleaner invokePrivate _getOuterObjects(closure)
+  }
+
+  test("get inner closure classes") {
+    val closure1 = () => 1
+    val closure2 = () => { () => 1 }
+    val closure3 = (i: Int) => {
+      (1 to i).map { x => x + 1 }.filter { x => x > 5 }
+    }
+    val closure4 = (j: Int) => {
+      (1 to j).flatMap { x =>
+        (1 to x).flatMap { y =>
+          (1 to y).map { z => z + 1 }
+        }
+      }
+    }
+    val inner1 = getInnerClosureClasses(closure1)
+    val inner2 = getInnerClosureClasses(closure2)
+    val inner3 = getInnerClosureClasses(closure3)
+    val inner4 = getInnerClosureClasses(closure4)
+    assert(inner1.isEmpty)
+    assert(inner2.size === 1)
+    assert(inner3.size === 2)
+    assert(inner4.size === 3)
+    assert(inner2.forall(isClosure))
+    assert(inner3.forall(isClosure))
+    assert(inner4.forall(isClosure))
+  }
+
+  test("get outer classes and objects") {
+    val localValue = someSerializableValue
+    val closure1 = () => 1
+    val closure2 = () => localValue
+    val closure3 = () => someSerializableValue
+    val closure4 = () => someSerializableMethod()
+    val outerClasses1 = getOuterClasses(closure1)
+    val outerClasses2 = getOuterClasses(closure2)
+    val outerClasses3 = getOuterClasses(closure3)
+    val outerClasses4 = getOuterClasses(closure4)
+    val outerObjects1 = getOuterObjects(closure1)
+    val outerObjects2 = getOuterObjects(closure2)
+    val outerObjects3 = getOuterObjects(closure3)
+    val outerObjects4 = getOuterObjects(closure4)
+
+    // The classes and objects should have the same size
+    assert(outerClasses1.size === outerObjects1.size)
+    assert(outerClasses2.size === outerObjects2.size)
+    assert(outerClasses3.size === outerObjects3.size)
+    assert(outerClasses4.size === outerObjects4.size)
+
+    // These do not have $outer pointers because they reference only local variables
+    assert(outerClasses1.isEmpty)
+    assert(outerClasses2.isEmpty)
+
+    // These closures do have $outer pointers because they ultimately reference `this`
+    // The first $outer pointer refers to the closure that defines this test (see FunSuite#test)
+    // The second $outer pointer refers to ClosureCleanerSuite2
+    assert(outerClasses3.size === 2)
+    assert(outerClasses4.size === 2)
+    assert(isClosure(outerClasses3(0)))
+    assert(isClosure(outerClasses4(0)))
+    assert(outerClasses3(0) === outerClasses4(0)) // part of the same "FunSuite#test" scope
+    assert(outerClasses3(1) === this.getClass)
+    assert(outerClasses4(1) === this.getClass)
+    assert(outerObjects3(1) === this)
+    assert(outerObjects4(1) === this)
+  }
+
+  test("get outer classes and objects with nesting") {
+    val localValue = someSerializableValue
+
+    val test1 = () => {
+      val x = 1
+      val closure1 = () => 1
+      val closure2 = () => x
+      val outerClasses1 = getOuterClasses(closure1)
+      val outerClasses2 = getOuterClasses(closure2)
+      val outerObjects1 = getOuterObjects(closure1)
+      val outerObjects2 = getOuterObjects(closure2)
+      assert(outerClasses1.size === outerObjects1.size)
+      assert(outerClasses2.size === outerObjects2.size)
+      // These inner closures only reference local variables, and so do not have $outer pointers
+      assert(outerClasses1.isEmpty)
+      assert(outerClasses2.isEmpty)
+    }
+
+    val test2 = () => {
+      def y = 1
+      val closure1 = () => 1
+      val closure2 = () => y
+      val closure3 = () => localValue
+      val outerClasses1 = getOuterClasses(closure1)
+      val outerClasses2 = getOuterClasses(closure2)
+      val outerClasses3 = getOuterClasses(closure3)
+      val outerObjects1 = getOuterObjects(closure1)
+      val outerObjects2 = getOuterObjects(closure2)
+      val outerObjects3 = getOuterObjects(closure3)
+      assert(outerClasses1.size === outerObjects1.size)
+      assert(outerClasses2.size === outerObjects2.size)
+      assert(outerClasses3.size === outerObjects3.size)
+      // Same as above, this closure only references local variables
+      assert(outerClasses1.isEmpty)
+      // This closure references the "test2" scope because it needs to find the method `y`
+      // Scope hierarchy: "test2" < "FunSuite#test" < ClosureCleanerSuite2
+      assert(outerClasses2.size === 3)
+      // This closure references the "test2" scope because it needs to find the `localValue`
+      // defined outside of this scope
+      assert(outerClasses3.size === 3)
+      assert(isClosure(outerClasses2(0)))
+      assert(isClosure(outerClasses3(0)))
+      assert(isClosure(outerClasses2(1)))
+      assert(isClosure(outerClasses3(1)))
+      assert(outerClasses2(0) === outerClasses3(0)) // part of the same "test2" scope
+      assert(outerClasses2(1) === outerClasses3(1)) // part of the same "FunSuite#test" scope
+      assert(outerClasses2(2) === this.getClass)
+      assert(outerClasses3(2) === this.getClass)
+      assert(outerObjects2(2) === this)
+      assert(outerObjects3(2) === this)
+    }
+
+    test1()
+    test2()
+  }
+
+  test("find accessed fields") {
+    val localValue = someSerializableValue
+    val closure1 = () => 1
+    val closure2 = () => localValue
+    val closure3 = () => someSerializableValue
+    val outerClasses1 = getOuterClasses(closure1)
+    val outerClasses2 = getOuterClasses(closure2)
+    val outerClasses3 = getOuterClasses(closure3)
+
+    val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false)
+    val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false)
+    val fields3 = findAccessedFields(closure3, outerClasses3, findTransitively = false)
+    assert(fields1.isEmpty)
+    assert(fields2.isEmpty)
+    assert(fields3.size === 2)
+    // This corresponds to the "FunSuite#test" closure. This is empty because the
+    // `someSerializableValue` belongs to its parent (i.e. ClosureCleanerSuite2).
+    assert(fields3(outerClasses3(0)).isEmpty)
+    // This corresponds to the ClosureCleanerSuite2. This is also empty, however,
+    // because accessing `ClosureCleanerSuite2#someSerializableValue` actually involves a
+    // method call. Since we do not find fields transitively, we will not recursively trace
+    // through the fields referenced by this method.
+    assert(fields3(outerClasses3(1)).isEmpty)
+
+    val fields1t = findAccessedFields(closure1, outerClasses1, findTransitively = true)
+    val fields2t = findAccessedFields(closure2, outerClasses2, findTransitively = true)
+    val fields3t = findAccessedFields(closure3, outerClasses3, findTransitively = true)
+    assert(fields1t.isEmpty)
+    assert(fields2t.isEmpty)
+    assert(fields3t.size === 2)
+    // Because we find fields transitively now, we are able to detect that we need the
+    // $outer pointer to get the field from the ClosureCleanerSuite2
+    assert(fields3t(outerClasses3(0)).size === 1)
+    assert(fields3t(outerClasses3(0)).head === "$outer")
+    assert(fields3t(outerClasses3(1)).size === 1)
+    assert(fields3t(outerClasses3(1)).head.contains("someSerializableValue"))
+  }
+
+  test("find accessed fields with nesting") {
+    val localValue = someSerializableValue
+
+    val test1 = () => {
+      def a = localValue + 1
+      val closure1 = () => 1
+      val closure2 = () => a
+      val closure3 = () => localValue
+      val closure4 = () => someSerializableValue
+      val outerClasses1 = getOuterClasses(closure1)
+      val outerClasses2 = getOuterClasses(closure2)
+      val outerClasses3 = getOuterClasses(closure3)
+      val outerClasses4 = getOuterClasses(closure4)
+
+      // First, find only fields accessed directly, not transitively, by these closures
+      val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false)
+      val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false)
+      val fields3 = findAccessedFields(closure3, outerClasses3, findTransitively = false)
+      val fields4 = findAccessedFields(closure4, outerClasses4, findTransitively = false)
+      assert(fields1.isEmpty)
+      // Note that the size here represents the number of outer classes, not the number of fields
+      // "test1" < parameter of "FunSuite#test" < ClosureCleanerSuite2
+      assert(fields2.size === 3)
+      // Since we do not find fields transitively here, we do not look into what `def a` references
+      assert(fields2(outerClasses2(0)).isEmpty) // This corresponds to the "test1" scope
+      assert(fields2(outerClasses2(1)).isEmpty) // This corresponds to the "FunSuite#test" scope
+      assert(fields2(outerClasses2(2)).isEmpty) // This corresponds to the ClosureCleanerSuite2
+      assert(fields3.size === 3)
+      // Note that `localValue` is a field of the "test1" scope because `def a` references it,
+      // but NOT a field of the "FunSuite#test" scope because it is only a local variable there
+      assert(fields3(outerClasses3(0)).size === 1)
+      assert(fields3(outerClasses3(0)).head.contains("localValue"))
+      assert(fields3(outerClasses3(1)).isEmpty)
+      assert(fields3(outerClasses3(2)).isEmpty)
+      assert(fields4.size === 3)
+      // Because `val someSerializableValue` is an instance variable, even an explicit reference
+      // here actually involves a method call to access the underlying value of the variable.
+      // Because we are not finding fields transitively here, we do not consider the fields
+      // accessed by this "method" (i.e. the val's accessor).
+      assert(fields4(outerClasses4(0)).isEmpty)
+      assert(fields4(outerClasses4(1)).isEmpty)
+      assert(fields4(outerClasses4(2)).isEmpty)
+
+      // Now do the same, but find fields that the closures transitively reference
+      val fields1t = findAccessedFields(closure1, outerClasses1, findTransitively = true)
+      val fields2t = findAccessedFields(closure2, outerClasses2, findTransitively = true)
+      val fields3t = findAccessedFields(closure3, outerClasses3, findTransitively = true)
+      val fields4t = findAccessedFields(closure4, outerClasses4, findTransitively = true)
+      assert(fields1t.isEmpty)
+      assert(fields2t.size === 3)
+      assert(fields2t(outerClasses2(0)).size === 1) // `def a` references `localValue`
+      assert(fields2t(outerClasses2(0)).head.contains("localValue"))
+      assert(fields2t(outerClasses2(1)).isEmpty)
+      assert(fields2t(outerClasses2(2)).isEmpty)
+      assert(fields3t.size === 3)
+      assert(fields3t(outerClasses3(0)).size === 1) // as before
+      assert(fields3t(outerClasses3(0)).head.contains("localValue"))
+      assert(fields3t(outerClasses3(1)).isEmpty)
+      assert(fields3t(outerClasses3(2)).isEmpty)
+      assert(fields4t.size === 3)
+      // Through a series of method calls, we are able to detect that we ultimately access
+      // ClosureCleanerSuite2's field `someSerializableValue`. Along the way, we also accessed
+      // a few $outer parent pointers to get to the outermost object.
+      assert(fields4t(outerClasses4(0)) === Set("$outer"))
+      assert(fields4t(outerClasses4(1)) === Set("$outer"))
+      assert(fields4t(outerClasses4(2)).size === 1)
+      assert(fields4t(outerClasses4(2)).head.contains("someSerializableValue"))
+    }
+
+    test1()
+  }
+
+  test("clean basic serializable closures") {
+    val localValue = someSerializableValue
+    val closure1 = () => 1
+    val closure2 = () => Array[String]("a", "b", "c")
+    val closure3 = (s: String, arr: Array[Long]) => s + arr.mkString(", ")
+    val closure4 = () => localValue
+    val closure5 = () => new NonSerializable(5) // we're just serializing the class information
+    val closure1r = closure1()
+    val closure2r = closure2()
+    val closure3r = closure3("g", Array(1, 5, 8))
+    val closure4r = closure4()
+    val closure5r = closure5()
+
+    verifyCleaning(closure1, serializableBefore = true, serializableAfter = true)
+    verifyCleaning(closure2, serializableBefore = true, serializableAfter = true)
+    verifyCleaning(closure3, serializableBefore = true, serializableAfter = true)
+    verifyCleaning(closure4, serializableBefore = true, serializableAfter = true)
+    verifyCleaning(closure5, serializableBefore = true, serializableAfter = true)
+
+    // Verify that closures can still be invoked and the results are still the same
+    assert(closure1() === closure1r)
+    assert(closure2() === closure2r)
+    assert(closure3("g", Array(1, 5, 8)) === closure3r)
+    assert(closure4() === closure4r)
+    assert(closure5() === closure5r)
+  }
+
+  test("clean basic non-serializable closures") {
+    val closure1 = () => this // ClosureCleanerSuite2 is not serializable
+    val closure5 = () => someSerializableValue
+    val closure3 = () => someSerializableMethod()
+    val closure4 = () => someNonSerializableValue
+    val closure2 = () => someNonSerializableMethod()
+
+    // These are not cleanable because they ultimately reference the ClosureCleanerSuite2
+    verifyCleaning(closure1, serializableBefore = false, serializableAfter = false)
+    verifyCleaning(closure2, serializableBefore = false, serializableAfter = false)
+    verifyCleaning(closure3, serializableBefore = false, serializableAfter = false)
+    verifyCleaning(closure4, serializableBefore = false, serializableAfter = false)
+    verifyCleaning(closure5, serializableBefore = false, serializableAfter = false)
+  }
+
+  test("clean basic nested serializable closures") {
+    val localValue = someSerializableValue
+    val closure1 = (i: Int) => {
+      (1 to i).map { x => x + localValue } // 1 level of nesting
+    }
+    val closure2 = (j: Int) => {
+      (1 to j).flatMap { x =>
+        (1 to x).map { y => y + localValue } // 2 levels
+      }
+    }
+    val closure3 = (k: Int, l: Int, m: Int) => {
+      (1 to k).flatMap(closure2) ++ // 4 levels
+      (1 to l).flatMap(closure1) ++ // 3 levels
+      (1 to m).map { x => x + 1 } // 2 levels
+    }
+    val closure1r = closure1(1)
+    val closure2r = closure2(2)
+    val closure3r = closure3(3, 4, 5)
+
+    verifyCleaning(closure1, serializableBefore = true, serializableAfter = true)
+    verifyCleaning(closure2, serializableBefore = true, serializableAfter = true)
+    verifyCleaning(closure3, serializableBefore = true, serializableAfter = true)
+
+    // Verify that closures can still be invoked and the results are still the same
+    assert(closure1(1) === closure1r)
+    assert(closure2(2) === closure2r)
+    assert(closure3(3, 4, 5) === closure3r)
+  }
+
+  test("clean basic nested non-serializable closures") {
+    def localSerializableMethod(): Int = someSerializableValue
+    val localNonSerializableValue = someNonSerializableValue
+    // These closures ultimately reference the ClosureCleanerSuite2
+    // Note that even accessing a `val` that is an instance variable involves a method call
+    val closure1 = (i: Int) => { (1 to i).map { x => x + someSerializableValue } }
+    val closure2 = (j: Int) => { (1 to j).map { x => x + someSerializableMethod() } }
+    val closure4 = (k: Int) => { (1 to k).map { x => x + localSerializableMethod() } }
+    // This closure references a local non-serializable value
+    val closure3 = (l: Int) => { (1 to l).map { x => localNonSerializableValue } }
+    // This is non-serializable no matter how many levels we nest it
+    val closure5 = (m: Int) => {
+      (1 to m).foreach { x =>
+        (1 to x).foreach { y =>
+          (1 to y).foreach { z =>
+            someSerializableValue
+          }
+        }
+      }
+    }
+
+    verifyCleaning(closure1, serializableBefore = false, serializableAfter = false)
+    verifyCleaning(closure2, serializableBefore = false, serializableAfter = false)
+    verifyCleaning(closure3, serializableBefore = false, serializableAfter = false)
+    verifyCleaning(closure4, serializableBefore = false, serializableAfter = false)
+    verifyCleaning(closure5, serializableBefore = false, serializableAfter = false)
+  }
+
+  test("clean complicated nested serializable closures") {
+    val localValue = someSerializableValue
+
+    // Here we assume that if the outer closure is serializable,
+    // then all inner closures must also be serializable
+
+    // Reference local fields from all levels
+    val closure1 = (i: Int) => {
+      val a = 1
+      (1 to i).flatMap { x =>
+        val b = a + 1
+        (1 to x).map { y =>
+          y + a + b + localValue
+        }
+      }
+    }
+
+    // Reference local fields and methods from all levels within the outermost closure
+    val closure2 = (i: Int) => {
+      val a1 = 1
+      def a2 = 2
+      (1 to i).flatMap { x =>
+        val b1 = a1 + 1
+        def b2 = a2 + 1
+        (1 to x).map { y =>
+          // If this references a method outside the outermost closure, then it will try to pull
+          // in the ClosureCleanerSuite2. This is why `localValue` here must be a local `val`.
+          y + a1 + a2 + b1 + b2 + localValue
+        }
+      }
+    }
+
+    val closure1r = closure1(1)
+    val closure2r = closure2(2)
+    verifyCleaning(closure1, serializableBefore = true, serializableAfter = true)
+    verifyCleaning(closure2, serializableBefore = true, serializableAfter = true)
+    assert(closure1(1) == closure1r)
+    assert(closure2(2) == closure2r)
+  }
+
+  test("clean complicated nested non-serializable closures") {
+    val localValue = someSerializableValue
+
+    // Note that we are not interested in cleaning the outer closures here (they are not cleanable)
+    // The only reason why they exist is to nest the inner closures
+
+    val test1 = () => {
+      val a = localValue
+      val b = sc
+      val inner1 = (x: Int) => x + a + b.hashCode()
+      val inner2 = (x: Int) => x + a
+
+      // This closure explicitly references a non-serializable field
+      // There is no way to clean it
+      verifyCleaning(inner1, serializableBefore = false, serializableAfter = false)
+
+      // This closure is serializable to begin with since it does not need a pointer to
+      // the outer closure (it only references local variables)
+      verifyCleaning(inner2, serializableBefore = true, serializableAfter = true)
+    }
+
+    // Same as above, but the `val a` becomes `def a`
+    // The difference here is that all inner closures now have pointers to the outer closure
+    val test2 = () => {
+      def a = localValue
+      val b = sc
+      val inner1 = (x: Int) => x + a + b.hashCode()
+      val inner2 = (x: Int) => x + a
+
+      // As before, this closure is neither serializable nor cleanable
+      verifyCleaning(inner1, serializableBefore = false, serializableAfter = false)
+
+      // This closure is no longer serializable because it now has a pointer to the outer closure,
+      // which is itself not serializable because it has a pointer to the ClosureCleanerSuite2.
+      // If we do not clean transitively, we will not null out this indirect reference.
+      verifyCleaning(
+        inner2, serializableBefore = false, serializableAfter = false, transitive = false)
+
+      // If we clean transitively, we will find that method `a` does not actually reference the
+      // outer closure's parent (i.e. the ClosureCleanerSuite2), so we can additionally null out
+      // the outer closure's parent pointer. This will make `inner2` serializable.
+      verifyCleaning(
+        inner2, serializableBefore = false, serializableAfter = true, transitive = true)
+    }
+
+    // Same as above, but with more levels of nesting
+    val test3 = () => { () => test1() }
+    val test4 = () => { () => test2() }
+    val test5 = () => { () => { () => test3() } }
+    val test6 = () => { () => { () => test4() } }
+
+    test1()
+    test2()
+    test3()()
+    test4()()
+    test5()()()
+    test6()()()
+  }
+
+}
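
The tests above repeatedly rely on one distinction: a closure that reads an instance member keeps a reference to its enclosing object, while a closure that reads a local copy does not. The following is a minimal standalone sketch of that distinction using plain Java serialization. It is not part of this patch and does not use ClosureCleaner; `ClosureCaptureSketch`, `Holder`, `viaField`, `viaLocal` and `isSerializable` are hypothetical names chosen for illustration.

```scala
import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

object ClosureCaptureSketch {

  // Hypothetical stand-in for the test suite: deliberately NOT Serializable.
  class Holder {
    val value: Int = 42

    // Reads the instance member, so the closure keeps a reference to `this`.
    def viaField: () => Int = () => value

    // Copies the member into a local first, so only the Int is captured.
    def viaLocal: () => Int = {
      val local = value
      () => local
    }
  }

  // Returns true if Java serialization accepts the whole object graph.
  def isSerializable(obj: AnyRef): Boolean = {
    try {
      val out = new ObjectOutputStream(new ByteArrayOutputStream())
      out.writeObject(obj)
      out.close()
      true
    } catch {
      case _: NotSerializableException => false
    }
  }

  def main(args: Array[String]): Unit = {
    val holder = new Holder
    println(s"viaField serializable: ${isSerializable(holder.viaField)}")
    println(s"viaLocal serializable: ${isSerializable(holder.viaLocal)}")
  }
}
```

Under these assumptions, `viaField` should be rejected because the non-serializable `Holder` is reachable from the closure, and `viaLocal` should be accepted. This mirrors why the tests copy `someSerializableValue` into `localValue` before building the closures they expect to be serializable.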

