rednaxelafx commented on a change in pull request #28463:
URL: https://github.com/apache/spark/pull/28463#discussion_r422434842



##########
File path: core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
##########
@@ -414,6 +434,296 @@ private[spark] object ClosureCleaner extends Logging {
   }
 }
 
+private[spark] object IndylambdaScalaClosures extends Logging {
+  // internal name of java.lang.invoke.LambdaMetafactory
+  val LambdaMetafactoryClassName = "java/lang/invoke/LambdaMetafactory"
+  // the method that Scala indylambda use for bootstrap method
+  val LambdaMetafactoryMethodName = "altMetafactory"
+  val LambdaMetafactoryMethodDesc = "(Ljava/lang/invoke/MethodHandles$Lookup;" 
+
+    "Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)" +
+    "Ljava/lang/invoke/CallSite;"
+
+  /**
+   * Check if the given reference is a indylambda style Scala closure.
+   * If so (e.g. for Scala 2.12+ closures), return a non-empty serialization 
proxy
+   * (SerializedLambda) of the closure;
+   * otherwise (e.g. for Scala 2.11 closures) return None.
+   *
+   * @param maybeClosure the closure to check.
+   */
+  def getSerializationProxy(maybeClosure: AnyRef): Option[SerializedLambda] = {
+    def isClosureCandidate(cls: Class[_]): Boolean = {
+      // TODO: maybe lift this restriction to support other functional 
interfaces in the future
+      val implementedInterfaces = ClassUtils.getAllInterfaces(cls).asScala
+      implementedInterfaces.exists(_.getName.startsWith("scala.Function"))
+    }
+
+    maybeClosure.getClass match {
+      // shortcut the fast check:
+      // 1. indylambda closure classes are generated by Java's 
LambdaMetafactory, and they're
+      //    always synthetic.
+      // 2. We only care about Serializable closures, so let's check that as 
well
+      case c if !c.isSynthetic || !maybeClosure.isInstanceOf[Serializable] => 
None
+
+      case c if isClosureCandidate(c) =>
+        try {
+          Option(inspect(maybeClosure)).filter(isIndylambdaScalaClosure)
+        } catch {
+          case e: Exception =>
+            logDebug("The given reference is not an indylambda Scala 
closure.", e)
+            None
+        }
+
+      case _ => None
+    }
+  }
+
+  def isIndylambdaScalaClosure(lambdaProxy: SerializedLambda): Boolean = {
+    lambdaProxy.getImplMethodKind == MethodHandleInfo.REF_invokeStatic &&
+      lambdaProxy.getImplMethodName.contains("$anonfun$")
+  }
+
+  def inspect(closure: AnyRef): SerializedLambda = {
+    val writeReplace = closure.getClass.getDeclaredMethod("writeReplace")
+    writeReplace.setAccessible(true)
+    writeReplace.invoke(closure).asInstanceOf[SerializedLambda]
+  }
+
+  /**
+   * Check if the handle represents the LambdaMetafactory that indylambda 
Scala closures
+   * use for creating the lambda class and getting a closure instance.
+   */
+  def isLambdaMetafactory(bsmHandle: Handle): Boolean = {
+    bsmHandle.getOwner == LambdaMetafactoryClassName &&
+      bsmHandle.getName == LambdaMetafactoryMethodName &&
+      bsmHandle.getDesc == LambdaMetafactoryMethodDesc
+  }
+
+  /**
+   * Check if the handle represents a target method that is:
+   * - a STATIC method that implements a Scala lambda body in the indylambda 
style
+   * - captures the enclosing `this`, i.e. the first argument is a reference 
to the same type as
+   *   the owning class.
+   * Returns true if both criteria above are met.
+   */
+  def isLambdaBodyCapturingOuter(handle: Handle, ownerInternalName: String): 
Boolean = {
+    handle.getTag == H_INVOKESTATIC &&
+      handle.getName.contains("$anonfun$") &&
+      handle.getOwner == ownerInternalName &&
+      handle.getDesc.startsWith(s"(L$ownerInternalName;")
+  }
+
+  /**
+   * Check if the callee of a call site is a inner class constructor.
+   * - A constructor has to be invoked via INVOKESPECIAL
+   * - A constructor's internal name is "<init>" and the return type is 
"V" (void)
+   * - An inner class' first argument in the signature has to be a reference 
to the
+   *   enclosing "this", aka `$outer` in Scala.
+   */
+  def isInnerClassCtorCapturingOuter(
+      op: Int, owner: String, name: String, desc: String, callerInternalName: 
String): Boolean = {
+    op == INVOKESPECIAL && name == "<init>" && 
desc.startsWith(s"(L$callerInternalName;")
+  }
+
+  /**
+   * Scans an indylambda Scala closure, along with its lexically nested 
closures, and populate
+   * the accessed fields info on which fields on the outer object are accessed.
+   *
+   * This is equivalent to getInnerClosureClasses() + InnerClosureFinder + 
FieldAccessFinder fused
+   * into one for processing indylambda closures. The traversal order along 
the call graph is the
+   * same for all three combined, so they can be fused together easily while 
maintaining the same
+   * ordering as the existing implementation.
+   *
+   * Precondition: this function expects the `accessedFields` to be populated 
with all known
+   *               outer classes and their super classes to be in the map as 
keys, e.g.
+   *               initializing via ClosureCleaner.initAccessedFields.
+   */
+  // scalastyle:off line.size.limit
+  // Example: run the following code snippet in a Spark Shell w/ Scala 2.12+:
+  //   val topLevelValue = "someValue"; val closure = (j: Int) => {
+  //     class InnerFoo {
+  //       val innerClosure = (x: Int) => (1 to x).map { y => y + 
topLevelValue }
+  //     }
+  //     val innerFoo = new InnerFoo
+  //     (1 to j).flatMap(innerFoo.innerClosure)
+  //   }
+  //   sc.parallelize(0 to 2).map(closure).collect
+  //
+  // produces the following trace-level logs:
+  // (slightly simplified:
+  //   - omitting the "ignoring ..." lines;
+  //   - "$iw" is actually "$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw";
+  //   - "invokedynamic" lines are simplified to just show the name+desc, 
omitting the bsm info)
+  //   Cleaning indylambda closure: $anonfun$closure$1$adapted
+  //     scanning 
$iw.$anonfun$closure$1$adapted(L$iw;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq;
+  //       found intra class call to 
$iw.$anonfun$closure$1(L$iw;I)Lscala/collection/immutable/IndexedSeq;
+  //     scanning 
$iw.$anonfun$closure$1(L$iw;I)Lscala/collection/immutable/IndexedSeq;
+  //       found inner class $iw$InnerFoo$1
+  //       found call to outer $iw$InnerFoo$1.innerClosure()Lscala/Function1;
+  //     scanning $iw$InnerFoo$1.innerClosure()Lscala/Function1;
+  //     scanning 
$iw$InnerFoo$1.$deserializeLambda$(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;
+  //       invokedynamic: 
lambdaDeserialize(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;, 
bsm...)
+  //     scanning 
$iw$InnerFoo$1.$anonfun$innerClosure$1$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq;
+  //       found intra class call to 
$iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq;
+  //     scanning 
$iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq;
+  //       invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...)
+  //       found inner closure 
$iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String;
 (6)
+  //     scanning 
$iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String;
+  //       found intra class call to 
$iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String;
+  //     scanning 
$iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String;
+  //       found call to outer $iw.topLevelValue()Ljava/lang/String;
+  //     scanning $iw.topLevelValue()Ljava/lang/String;
+  //       found field access topLevelValue on $iw
+  //     scanning 
$iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String;
+  //       found intra class call to 
$iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String;
+  //     scanning $iw$InnerFoo$1.<init>(L$iw;)V
+  //       invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...)
+  //       found inner closure 
$iw$InnerFoo$1.$anonfun$innerClosure$1$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq;
 (6)
+  //     scanning 
$iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq;
+  //       invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...)
+  //       found inner closure 
$iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String;
 (6)
+  //     scanning 
$iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String;
+  //       found call to outer $iw.topLevelValue()Ljava/lang/String;
+  //     scanning $iw$InnerFoo$1.innerClosure()Lscala/Function1;
+  //    + fields accessed by starting closure: 2 classes
+  //        (class java.lang.Object,Set())
+  //        (class $iw,Set(topLevelValue))
+  //    + cloning instance of REPL class $iw
+  //    +++ indylambda closure ($anonfun$closure$1$adapted) is now cleaned +++
+  //
+  // scalastyle:on line.size.limit
+  def findAccessedFields(
+      lambdaProxy: SerializedLambda,
+      lambdaClassLoader: ClassLoader,
+      accessedFields: Map[Class[_], Set[String]],
+      findTransitively: Boolean): Unit = {
+
+    // We may need to visit the same class multiple times for different 
methods on it, and we'll
+    // need to lookup by name. So we use ASM's Tree API and cache the 
ClassNode/MethodNode.
+    val classInfoByInternalName = Map.empty[String, (Class[_], ClassNode)]
+    val methodNodeById = Map.empty[MethodIdentifier[_], MethodNode]
+    def getOrUpdateClassInfo(classInternalName: String): (Class[_], ClassNode) 
= {
+      val classInfo = 
classInfoByInternalName.getOrElseUpdate(classInternalName, {
+        val classExternalName = classInternalName.replace('/', '.')
+        // scalastyle:off classforname
+        val clazz = Class.forName(classExternalName, false, lambdaClassLoader)
+        // scalastyle:on classforname
+        val classNode = new ClassNode()
+        val classReader = ClosureCleaner.getClassReader(clazz)
+        classReader.accept(classNode, 0)
+
+        for (m <- classNode.methods.asScala) {
+          methodNodeById(MethodIdentifier(clazz, m.name, m.desc)) = m
+        }
+
+        (clazz, classNode)
+      })
+      classInfo
+    }
+
+    val implClassInternalName = lambdaProxy.getImplClass
+    val (implClass, _) = getOrUpdateClassInfo(implClassInternalName)
+
+    val implMethodId = MethodIdentifier(
+      implClass, lambdaProxy.getImplMethodName, 
lambdaProxy.getImplMethodSignature)
+
+    // The set of classes that we would consider following the calls into.
+    // Candidates are: known outer class which happens to be the starting 
closure's impl class,
+    // and all inner classes discovered below.
+    val trackedClassesByInternalName = Map[String, 
Class[_]](implClassInternalName -> implClass)
+
+    // Depth-first search for inner closures and track the fields that were 
accessed in them.
+    // Start from the lambda body's implementation method, follow method 
invocations
+    val visited = Set.empty[MethodIdentifier[_]]
+    val stack = Stack[MethodIdentifier[_]](implMethodId)
+    def pushIfNotVisited(methodId: MethodIdentifier[_]): Unit = {
+      if (!visited.contains(methodId)) {
+        stack.push(methodId)
+      }
+    }
+
+    while (!stack.isEmpty) {
+      val currentId = stack.pop
+      visited += currentId
+
+      val currentClass = currentId.cls
+      val currentMethodNode = methodNodeById(currentId)
+      logTrace(s"  scanning 
${currentId.cls.getName}.${currentId.name}${currentId.desc}")
+      currentMethodNode.accept(new MethodVisitor(ASM7) {
+        val currentClassName = currentClass.getName
+        val currentClassInternalName = currentClassName.replace('.', '/')
+
+        // Find and update the accessedFields info. Only fields on known outer 
classes are tracked.
+        // This is the FieldAccessFinder equivalent.
+        override def visitFieldInsn(op: Int, owner: String, name: String, 
desc: String): Unit = {
+          if (op == GETFIELD || op == PUTFIELD) {
+            val ownerExternalName = owner.replace('/', '.')
+            for (cl <- accessedFields.keys if cl.getName == ownerExternalName) 
{
+              logTrace(s"    found field access $name on $ownerExternalName")
+              accessedFields(cl) += name
+            }
+          }
+        }
+
+        override def visitMethodInsn(
+            op: Int, owner: String, name: String, desc: String, itf: Boolean): 
Unit = {
+          val ownerExternalName = owner.replace('/', '.')
+          if (owner == currentClassInternalName) {
+            logTrace(s"    found intra class call to 
$ownerExternalName.$name$desc")
+            // could be invoking a helper method or a field accessor method, 
just follow it.
+            pushIfNotVisited(MethodIdentifier(currentClass, name, desc))
+          } else if (isInnerClassCtorCapturingOuter(
+              op, owner, name, desc, currentClassInternalName)) {
+            // Discover inner classes.
+            // This this the InnerClassFinder equivalent for inner classes, 
which still use the
+            // `$outer` chain. So this is NOT controlled by the 
`findTransitively` flag.
+            logTrace(s"    found inner class $ownerExternalName")
+            val innerClassInfo = getOrUpdateClassInfo(owner)
+            val innerClass = innerClassInfo._1
+            val innerClassNode = innerClassInfo._2
+            trackedClassesByInternalName(owner) = innerClass
+            // We need to visit all methods on the inner class so that we 
don't missing anything.
+            for (m <- innerClassNode.methods.asScala) {
+              pushIfNotVisited(MethodIdentifier(innerClass, m.name, m.desc))
+            }
+          } else if (findTransitively && 
trackedClassesByInternalName.contains(owner)) {
+            logTrace(s"    found call to outer $ownerExternalName.$name$desc")

Review comment:
       The "outer" here is relative. For example, given:
   ```
   starting closure
     inner class A
       inner class B
         inner closure
   ```
   Relative to the "inner closure" in this example, both inner class B and 
inner class A are "outer". I wanted to make this distinction because I'm only 
tracking two types of calls (relative to the current class) and one type of 
invokedynamic:
   - call to method defined on the same class: always follow
   - call to method defined on some level of outer class: effectively always 
follow, but including the "findTransitively" flag here just to look closer to 
the old code (I might do a future cleanup to move the old code to the new style 
and remove this flag if possible).
   - invokedynamic where the BSM (bootstrap method) is LMF (LambdaMetafactory) 
and the impl method is on the same class: always follow




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to