maropu commented on a change in pull request #28463: URL: https://github.com/apache/spark/pull/28463#discussion_r422425889
########## File path: core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala ########## @@ -414,6 +434,296 @@ private[spark] object ClosureCleaner extends Logging { } } +private[spark] object IndylambdaScalaClosures extends Logging { + // internal name of java.lang.invoke.LambdaMetafactory + val LambdaMetafactoryClassName = "java/lang/invoke/LambdaMetafactory" + // the method that Scala indylambda use for bootstrap method + val LambdaMetafactoryMethodName = "altMetafactory" + val LambdaMetafactoryMethodDesc = "(Ljava/lang/invoke/MethodHandles$Lookup;" + + "Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)" + + "Ljava/lang/invoke/CallSite;" + + /** + * Check if the given reference is a indylambda style Scala closure. + * If so (e.g. for Scala 2.12+ closures), return a non-empty serialization proxy + * (SerializedLambda) of the closure; + * otherwise (e.g. for Scala 2.11 closures) return None. + * + * @param maybeClosure the closure to check. + */ + def getSerializationProxy(maybeClosure: AnyRef): Option[SerializedLambda] = { + def isClosureCandidate(cls: Class[_]): Boolean = { + // TODO: maybe lift this restriction to support other functional interfaces in the future + val implementedInterfaces = ClassUtils.getAllInterfaces(cls).asScala + implementedInterfaces.exists(_.getName.startsWith("scala.Function")) + } + + maybeClosure.getClass match { + // shortcut the fast check: + // 1. indylambda closure classes are generated by Java's LambdaMetafactory, and they're + // always synthetic. + // 2. 
We only care about Serializable closures, so let's check that as well + case c if !c.isSynthetic || !maybeClosure.isInstanceOf[Serializable] => None + + case c if isClosureCandidate(c) => + try { + Option(inspect(maybeClosure)).filter(isIndylambdaScalaClosure) + } catch { + case e: Exception => + logDebug("The given reference is not an indylambda Scala closure.", e) + None + } + + case _ => None + } + } + + def isIndylambdaScalaClosure(lambdaProxy: SerializedLambda): Boolean = { + lambdaProxy.getImplMethodKind == MethodHandleInfo.REF_invokeStatic && + lambdaProxy.getImplMethodName.contains("$anonfun$") + } + + def inspect(closure: AnyRef): SerializedLambda = { + val writeReplace = closure.getClass.getDeclaredMethod("writeReplace") + writeReplace.setAccessible(true) + writeReplace.invoke(closure).asInstanceOf[SerializedLambda] + } + + /** + * Check if the handle represents the LambdaMetafactory that indylambda Scala closures + * use for creating the lambda class and getting a closure instance. + */ + def isLambdaMetafactory(bsmHandle: Handle): Boolean = { + bsmHandle.getOwner == LambdaMetafactoryClassName && + bsmHandle.getName == LambdaMetafactoryMethodName && + bsmHandle.getDesc == LambdaMetafactoryMethodDesc + } + + /** + * Check if the handle represents a target method that is: + * - a STATIC method that implements a Scala lambda body in the indylambda style + * - captures the enclosing `this`, i.e. the first argument is a reference to the same type as + * the owning class. + * Returns true if both criteria above are met. + */ + def isLambdaBodyCapturingOuter(handle: Handle, ownerInternalName: String): Boolean = { + handle.getTag == H_INVOKESTATIC && + handle.getName.contains("$anonfun$") && + handle.getOwner == ownerInternalName && + handle.getDesc.startsWith(s"(L$ownerInternalName;") + } + + /** + * Check if the callee of a call site is a inner class constructor. 
+ * - A constructor has to be invoked via INVOKESPECIAL + * - A constructor's internal name is "<init>" and the return type is "V" (void) + * - An inner class' first argument in the signature has to be a reference to the + * enclosing "this", aka `$outer` in Scala. + */ + def isInnerClassCtorCapturingOuter( + op: Int, owner: String, name: String, desc: String, callerInternalName: String): Boolean = { + op == INVOKESPECIAL && name == "<init>" && desc.startsWith(s"(L$callerInternalName;") + } + + /** + * Scans an indylambda Scala closure, along with its lexically nested closures, and populate + * the accessed fields info on which fields on the outer object are accessed. + * + * This is equivalent to getInnerClosureClasses() + InnerClosureFinder + FieldAccessFinder fused + * into one for processing indylambda closures. The traversal order along the call graph is the + * same for all three combined, so they can be fused together easily while maintaining the same + * ordering as the existing implementation. + * + * Precondition: this function expects the `accessedFields` to be populated with all known + * outer classes and their super classes to be in the map as keys, e.g. + * initializing via ClosureCleaner.initAccessedFields. + */ + // scalastyle:off line.size.limit + // Example: run the following code snippet in a Spark Shell w/ Scala 2.12+: + // val topLevelValue = "someValue"; val closure = (j: Int) => { + // class InnerFoo { + // val innerClosure = (x: Int) => (1 to x).map { y => y + topLevelValue } + // } + // val innerFoo = new InnerFoo + // (1 to j).flatMap(innerFoo.innerClosure) + // } + // sc.parallelize(0 to 2).map(closure).collect + // + // produces the following trace-level logs: + // (slightly simplified: + // - omitting the "ignoring ..." 
lines; + // - "$iw" is actually "$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw"; + // - "invokedynamic" lines are simplified to just show the name+desc, omitting the bsm info) + // Cleaning indylambda closure: $anonfun$closure$1$adapted + // scanning $iw.$anonfun$closure$1$adapted(L$iw;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; + // found intra class call to $iw.$anonfun$closure$1(L$iw;I)Lscala/collection/immutable/IndexedSeq; + // scanning $iw.$anonfun$closure$1(L$iw;I)Lscala/collection/immutable/IndexedSeq; + // found inner class $iw$InnerFoo$1 + // found call to outer $iw$InnerFoo$1.innerClosure()Lscala/Function1; + // scanning $iw$InnerFoo$1.innerClosure()Lscala/Function1; + // scanning $iw$InnerFoo$1.$deserializeLambda$(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object; + // invokedynamic: lambdaDeserialize(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;, bsm...) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$1$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; + // found intra class call to $iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...) 
+ // found inner closure $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; (6) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; + // found intra class call to $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // found call to outer $iw.topLevelValue()Ljava/lang/String; + // scanning $iw.topLevelValue()Ljava/lang/String; + // found field access topLevelValue on $iw + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; + // found intra class call to $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // scanning $iw$InnerFoo$1.<init>(L$iw;)V + // invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...) + // found inner closure $iw$InnerFoo$1.$anonfun$innerClosure$1$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; (6) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...) 
+ // found inner closure $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; (6) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // found call to outer $iw.topLevelValue()Ljava/lang/String; + // scanning $iw$InnerFoo$1.innerClosure()Lscala/Function1; + // + fields accessed by starting closure: 2 classes + // (class java.lang.Object,Set()) + // (class $iw,Set(topLevelValue)) + // + cloning instance of REPL class $iw + // +++ indylambda closure ($anonfun$closure$1$adapted) is now cleaned +++ + // + // scalastyle:on line.size.limit + def findAccessedFields( + lambdaProxy: SerializedLambda, + lambdaClassLoader: ClassLoader, + accessedFields: Map[Class[_], Set[String]], + findTransitively: Boolean): Unit = { + + // We may need to visit the same class multiple times for different methods on it, and we'll + // need to lookup by name. So we use ASM's Tree API and cache the ClassNode/MethodNode. + val classInfoByInternalName = Map.empty[String, (Class[_], ClassNode)] + val methodNodeById = Map.empty[MethodIdentifier[_], MethodNode] + def getOrUpdateClassInfo(classInternalName: String): (Class[_], ClassNode) = { + val classInfo = classInfoByInternalName.getOrElseUpdate(classInternalName, { + val classExternalName = classInternalName.replace('/', '.') + // scalastyle:off classforname + val clazz = Class.forName(classExternalName, false, lambdaClassLoader) + // scalastyle:on classforname + val classNode = new ClassNode() + val classReader = ClosureCleaner.getClassReader(clazz) + classReader.accept(classNode, 0) + + for (m <- classNode.methods.asScala) { + methodNodeById(MethodIdentifier(clazz, m.name, m.desc)) = m + } + + (clazz, classNode) + }) + classInfo + } + + val implClassInternalName = lambdaProxy.getImplClass + val (implClass, _) = getOrUpdateClassInfo(implClassInternalName) + + val implMethodId = MethodIdentifier( + implClass, lambdaProxy.getImplMethodName, 
lambdaProxy.getImplMethodSignature) + + // The set of classes that we would consider following the calls into. + // Candidates are: known outer class which happens to be the starting closure's impl class, + // and all inner classes discovered below. + val trackedClassesByInternalName = Map[String, Class[_]](implClassInternalName -> implClass) Review comment: Could `trackedClassesByInternalName` be a `Set` of internal names instead of a `Map`? The `Class[_]` values stored in it don't seem to be used anywhere in the code below. ########## File path: core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala ########## @@ -414,6 +434,296 @@ private[spark] object ClosureCleaner extends Logging { } } +private[spark] object IndylambdaScalaClosures extends Logging { + // internal name of java.lang.invoke.LambdaMetafactory + val LambdaMetafactoryClassName = "java/lang/invoke/LambdaMetafactory" + // the method that Scala indylambda use for bootstrap method + val LambdaMetafactoryMethodName = "altMetafactory" + val LambdaMetafactoryMethodDesc = "(Ljava/lang/invoke/MethodHandles$Lookup;" + + "Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)" + + "Ljava/lang/invoke/CallSite;" + + /** + * Check if the given reference is a indylambda style Scala closure. + * If so (e.g. for Scala 2.12+ closures), return a non-empty serialization proxy + * (SerializedLambda) of the closure; + * otherwise (e.g. for Scala 2.11 closures) return None. + * + * @param maybeClosure the closure to check. + */ + def getSerializationProxy(maybeClosure: AnyRef): Option[SerializedLambda] = { + def isClosureCandidate(cls: Class[_]): Boolean = { + // TODO: maybe lift this restriction to support other functional interfaces in the future + val implementedInterfaces = ClassUtils.getAllInterfaces(cls).asScala + implementedInterfaces.exists(_.getName.startsWith("scala.Function")) + } + + maybeClosure.getClass match { + // shortcut the fast check: + // 1. indylambda closure classes are generated by Java's LambdaMetafactory, and they're + // always synthetic. + // 2.
We only care about Serializable closures, so let's check that as well + case c if !c.isSynthetic || !maybeClosure.isInstanceOf[Serializable] => None + + case c if isClosureCandidate(c) => + try { + Option(inspect(maybeClosure)).filter(isIndylambdaScalaClosure) + } catch { + case e: Exception => + logDebug("The given reference is not an indylambda Scala closure.", e) + None + } + + case _ => None + } + } + + def isIndylambdaScalaClosure(lambdaProxy: SerializedLambda): Boolean = { + lambdaProxy.getImplMethodKind == MethodHandleInfo.REF_invokeStatic && + lambdaProxy.getImplMethodName.contains("$anonfun$") + } + + def inspect(closure: AnyRef): SerializedLambda = { + val writeReplace = closure.getClass.getDeclaredMethod("writeReplace") + writeReplace.setAccessible(true) + writeReplace.invoke(closure).asInstanceOf[SerializedLambda] + } + + /** + * Check if the handle represents the LambdaMetafactory that indylambda Scala closures + * use for creating the lambda class and getting a closure instance. + */ + def isLambdaMetafactory(bsmHandle: Handle): Boolean = { + bsmHandle.getOwner == LambdaMetafactoryClassName && + bsmHandle.getName == LambdaMetafactoryMethodName && + bsmHandle.getDesc == LambdaMetafactoryMethodDesc + } + + /** + * Check if the handle represents a target method that is: + * - a STATIC method that implements a Scala lambda body in the indylambda style + * - captures the enclosing `this`, i.e. the first argument is a reference to the same type as + * the owning class. + * Returns true if both criteria above are met. + */ + def isLambdaBodyCapturingOuter(handle: Handle, ownerInternalName: String): Boolean = { + handle.getTag == H_INVOKESTATIC && + handle.getName.contains("$anonfun$") && + handle.getOwner == ownerInternalName && + handle.getDesc.startsWith(s"(L$ownerInternalName;") + } + + /** + * Check if the callee of a call site is a inner class constructor. 
+ * - A constructor has to be invoked via INVOKESPECIAL + * - A constructor's internal name is "<init>" and the return type is "V" (void) + * - An inner class' first argument in the signature has to be a reference to the + * enclosing "this", aka `$outer` in Scala. + */ + def isInnerClassCtorCapturingOuter( + op: Int, owner: String, name: String, desc: String, callerInternalName: String): Boolean = { + op == INVOKESPECIAL && name == "<init>" && desc.startsWith(s"(L$callerInternalName;") + } + + /** + * Scans an indylambda Scala closure, along with its lexically nested closures, and populate + * the accessed fields info on which fields on the outer object are accessed. + * + * This is equivalent to getInnerClosureClasses() + InnerClosureFinder + FieldAccessFinder fused + * into one for processing indylambda closures. The traversal order along the call graph is the + * same for all three combined, so they can be fused together easily while maintaining the same + * ordering as the existing implementation. + * + * Precondition: this function expects the `accessedFields` to be populated with all known + * outer classes and their super classes to be in the map as keys, e.g. + * initializing via ClosureCleaner.initAccessedFields. + */ + // scalastyle:off line.size.limit + // Example: run the following code snippet in a Spark Shell w/ Scala 2.12+: + // val topLevelValue = "someValue"; val closure = (j: Int) => { + // class InnerFoo { + // val innerClosure = (x: Int) => (1 to x).map { y => y + topLevelValue } + // } + // val innerFoo = new InnerFoo + // (1 to j).flatMap(innerFoo.innerClosure) + // } + // sc.parallelize(0 to 2).map(closure).collect + // + // produces the following trace-level logs: + // (slightly simplified: + // - omitting the "ignoring ..." 
lines; + // - "$iw" is actually "$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw"; + // - "invokedynamic" lines are simplified to just show the name+desc, omitting the bsm info) + // Cleaning indylambda closure: $anonfun$closure$1$adapted + // scanning $iw.$anonfun$closure$1$adapted(L$iw;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; + // found intra class call to $iw.$anonfun$closure$1(L$iw;I)Lscala/collection/immutable/IndexedSeq; + // scanning $iw.$anonfun$closure$1(L$iw;I)Lscala/collection/immutable/IndexedSeq; + // found inner class $iw$InnerFoo$1 + // found call to outer $iw$InnerFoo$1.innerClosure()Lscala/Function1; + // scanning $iw$InnerFoo$1.innerClosure()Lscala/Function1; + // scanning $iw$InnerFoo$1.$deserializeLambda$(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object; + // invokedynamic: lambdaDeserialize(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;, bsm...) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$1$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; + // found intra class call to $iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...) 
+ // found inner closure $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; (6) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; + // found intra class call to $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // found call to outer $iw.topLevelValue()Ljava/lang/String; + // scanning $iw.topLevelValue()Ljava/lang/String; + // found field access topLevelValue on $iw + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; + // found intra class call to $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // scanning $iw$InnerFoo$1.<init>(L$iw;)V + // invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...) + // found inner closure $iw$InnerFoo$1.$anonfun$innerClosure$1$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; (6) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...) 
+ // found inner closure $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; (6) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // found call to outer $iw.topLevelValue()Ljava/lang/String; + // scanning $iw$InnerFoo$1.innerClosure()Lscala/Function1; + // + fields accessed by starting closure: 2 classes + // (class java.lang.Object,Set()) + // (class $iw,Set(topLevelValue)) + // + cloning instance of REPL class $iw + // +++ indylambda closure ($anonfun$closure$1$adapted) is now cleaned +++ + // + // scalastyle:on line.size.limit + def findAccessedFields( + lambdaProxy: SerializedLambda, + lambdaClassLoader: ClassLoader, + accessedFields: Map[Class[_], Set[String]], + findTransitively: Boolean): Unit = { + + // We may need to visit the same class multiple times for different methods on it, and we'll + // need to lookup by name. So we use ASM's Tree API and cache the ClassNode/MethodNode. + val classInfoByInternalName = Map.empty[String, (Class[_], ClassNode)] + val methodNodeById = Map.empty[MethodIdentifier[_], MethodNode] + def getOrUpdateClassInfo(classInternalName: String): (Class[_], ClassNode) = { + val classInfo = classInfoByInternalName.getOrElseUpdate(classInternalName, { + val classExternalName = classInternalName.replace('/', '.') + // scalastyle:off classforname + val clazz = Class.forName(classExternalName, false, lambdaClassLoader) + // scalastyle:on classforname + val classNode = new ClassNode() + val classReader = ClosureCleaner.getClassReader(clazz) + classReader.accept(classNode, 0) + + for (m <- classNode.methods.asScala) { + methodNodeById(MethodIdentifier(clazz, m.name, m.desc)) = m + } + + (clazz, classNode) + }) + classInfo + } + + val implClassInternalName = lambdaProxy.getImplClass + val (implClass, _) = getOrUpdateClassInfo(implClassInternalName) + + val implMethodId = MethodIdentifier( + implClass, lambdaProxy.getImplMethodName, 
lambdaProxy.getImplMethodSignature) + + // The set of classes that we would consider following the calls into. + // Candidates are: known outer class which happens to be the starting closure's impl class, + // and all inner classes discovered below. + val trackedClassesByInternalName = Map[String, Class[_]](implClassInternalName -> implClass) + + // Depth-first search for inner closures and track the fields that were accessed in them. + // Start from the lambda body's implementation method, follow method invocations + val visited = Set.empty[MethodIdentifier[_]] + val stack = Stack[MethodIdentifier[_]](implMethodId) + def pushIfNotVisited(methodId: MethodIdentifier[_]): Unit = { + if (!visited.contains(methodId)) { + stack.push(methodId) + } + } + + while (!stack.isEmpty) { + val currentId = stack.pop + visited += currentId + + val currentClass = currentId.cls + val currentMethodNode = methodNodeById(currentId) + logTrace(s" scanning ${currentId.cls.getName}.${currentId.name}${currentId.desc}") + currentMethodNode.accept(new MethodVisitor(ASM7) { + val currentClassName = currentClass.getName + val currentClassInternalName = currentClassName.replace('.', '/') + + // Find and update the accessedFields info. Only fields on known outer classes are tracked. + // This is the FieldAccessFinder equivalent. 
+ override def visitFieldInsn(op: Int, owner: String, name: String, desc: String): Unit = { + if (op == GETFIELD || op == PUTFIELD) { + val ownerExternalName = owner.replace('/', '.') + for (cl <- accessedFields.keys if cl.getName == ownerExternalName) { + logTrace(s" found field access $name on $ownerExternalName") + accessedFields(cl) += name + } + } + } + + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = { + val ownerExternalName = owner.replace('/', '.') + if (owner == currentClassInternalName) { + logTrace(s" found intra class call to $ownerExternalName.$name$desc") + // could be invoking a helper method or a field accessor method, just follow it. + pushIfNotVisited(MethodIdentifier(currentClass, name, desc)) + } else if (isInnerClassCtorCapturingOuter( + op, owner, name, desc, currentClassInternalName)) { + // Discover inner classes. + // This this the InnerClassFinder equivalent for inner classes, which still use the + // `$outer` chain. So this is NOT controlled by the `findTransitively` flag. + logTrace(s" found inner class $ownerExternalName") + val innerClassInfo = getOrUpdateClassInfo(owner) + val innerClass = innerClassInfo._1 + val innerClassNode = innerClassInfo._2 + trackedClassesByInternalName(owner) = innerClass + // We need to visit all methods on the inner class so that we don't missing anything. + for (m <- innerClassNode.methods.asScala) { Review comment: How about logging found methods here (e.g., `logTrace(s" found method ${m.name}${m.desc}")`), too? 
The additional log looks like this; ``` 20/05/09 08:33:01 TRACE IndylambdaScalaClosures: scanning $line14.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw.$anonfun$closure$1(L$line14/$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw;I)Lscala/collection/immutable/IndexedSeq; 20/05/09 08:33:01 TRACE IndylambdaScalaClosures: found inner class $line14.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$InnerFoo$1 20/05/09 08:33:01 TRACE IndylambdaScalaClosures: found method innerClosure()Lscala/Function1; 20/05/09 08:33:01 TRACE IndylambdaScalaClosures: found method $anonfun$innerClosure$2(L$line14/$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$InnerFoo$1;I)Ljava/lang/String; 20/05/09 08:33:01 TRACE IndylambdaScalaClosures: found method $anonfun$innerClosure$1(L$line14/$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; 20/05/09 08:33:01 TRACE IndylambdaScalaClosures: found method <init>(L$line14/$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw;Ljava/lang/String;)V 20/05/09 08:33:01 TRACE IndylambdaScalaClosures: found method $anonfun$innerClosure$2$adapted(L$line14/$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; 20/05/09 08:33:01 TRACE IndylambdaScalaClosures: found method $anonfun$innerClosure$1$adapted(L$line14/$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$InnerFoo$1;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; 20/05/09 08:33:01 TRACE IndylambdaScalaClosures: found method $deserializeLambda$(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object; 20/05/09 08:33:01 TRACE IndylambdaScalaClosures: ignoring call to scala.Predef$.intWrapper(I)I 20/05/09 08:33:01 TRACE IndylambdaScalaClosures: ignoring call to scala.runtime.RichInt$.to$extension0(II)Lscala/collection/immutable/Range$Inclusive; 20/05/09 08:33:01 TRACE IndylambdaScalaClosures: found call to outer $line14.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$InnerFoo$1.innerClosure()Lscala/Function1; ... 
``` ########## File path: core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala ########## @@ -414,6 +434,296 @@ private[spark] object ClosureCleaner extends Logging { } } +private[spark] object IndylambdaScalaClosures extends Logging { + // internal name of java.lang.invoke.LambdaMetafactory + val LambdaMetafactoryClassName = "java/lang/invoke/LambdaMetafactory" + // the method that Scala indylambda use for bootstrap method + val LambdaMetafactoryMethodName = "altMetafactory" + val LambdaMetafactoryMethodDesc = "(Ljava/lang/invoke/MethodHandles$Lookup;" + + "Ljava/lang/String;Ljava/lang/invoke/MethodType;[Ljava/lang/Object;)" + + "Ljava/lang/invoke/CallSite;" + + /** + * Check if the given reference is a indylambda style Scala closure. + * If so (e.g. for Scala 2.12+ closures), return a non-empty serialization proxy + * (SerializedLambda) of the closure; + * otherwise (e.g. for Scala 2.11 closures) return None. + * + * @param maybeClosure the closure to check. + */ + def getSerializationProxy(maybeClosure: AnyRef): Option[SerializedLambda] = { + def isClosureCandidate(cls: Class[_]): Boolean = { + // TODO: maybe lift this restriction to support other functional interfaces in the future + val implementedInterfaces = ClassUtils.getAllInterfaces(cls).asScala + implementedInterfaces.exists(_.getName.startsWith("scala.Function")) + } + + maybeClosure.getClass match { + // shortcut the fast check: + // 1. indylambda closure classes are generated by Java's LambdaMetafactory, and they're + // always synthetic. + // 2. 
We only care about Serializable closures, so let's check that as well + case c if !c.isSynthetic || !maybeClosure.isInstanceOf[Serializable] => None + + case c if isClosureCandidate(c) => + try { + Option(inspect(maybeClosure)).filter(isIndylambdaScalaClosure) + } catch { + case e: Exception => + logDebug("The given reference is not an indylambda Scala closure.", e) + None + } + + case _ => None + } + } + + def isIndylambdaScalaClosure(lambdaProxy: SerializedLambda): Boolean = { + lambdaProxy.getImplMethodKind == MethodHandleInfo.REF_invokeStatic && + lambdaProxy.getImplMethodName.contains("$anonfun$") + } + + def inspect(closure: AnyRef): SerializedLambda = { + val writeReplace = closure.getClass.getDeclaredMethod("writeReplace") + writeReplace.setAccessible(true) + writeReplace.invoke(closure).asInstanceOf[SerializedLambda] + } + + /** + * Check if the handle represents the LambdaMetafactory that indylambda Scala closures + * use for creating the lambda class and getting a closure instance. + */ + def isLambdaMetafactory(bsmHandle: Handle): Boolean = { + bsmHandle.getOwner == LambdaMetafactoryClassName && + bsmHandle.getName == LambdaMetafactoryMethodName && + bsmHandle.getDesc == LambdaMetafactoryMethodDesc + } + + /** + * Check if the handle represents a target method that is: + * - a STATIC method that implements a Scala lambda body in the indylambda style + * - captures the enclosing `this`, i.e. the first argument is a reference to the same type as + * the owning class. + * Returns true if both criteria above are met. + */ + def isLambdaBodyCapturingOuter(handle: Handle, ownerInternalName: String): Boolean = { + handle.getTag == H_INVOKESTATIC && + handle.getName.contains("$anonfun$") && + handle.getOwner == ownerInternalName && + handle.getDesc.startsWith(s"(L$ownerInternalName;") + } + + /** + * Check if the callee of a call site is a inner class constructor. 
+ * - A constructor has to be invoked via INVOKESPECIAL + * - A constructor's internal name is "<init>" and the return type is "V" (void) + * - An inner class' first argument in the signature has to be a reference to the + * enclosing "this", aka `$outer` in Scala. + */ + def isInnerClassCtorCapturingOuter( + op: Int, owner: String, name: String, desc: String, callerInternalName: String): Boolean = { + op == INVOKESPECIAL && name == "<init>" && desc.startsWith(s"(L$callerInternalName;") + } + + /** + * Scans an indylambda Scala closure, along with its lexically nested closures, and populate + * the accessed fields info on which fields on the outer object are accessed. + * + * This is equivalent to getInnerClosureClasses() + InnerClosureFinder + FieldAccessFinder fused + * into one for processing indylambda closures. The traversal order along the call graph is the + * same for all three combined, so they can be fused together easily while maintaining the same + * ordering as the existing implementation. + * + * Precondition: this function expects the `accessedFields` to be populated with all known + * outer classes and their super classes to be in the map as keys, e.g. + * initializing via ClosureCleaner.initAccessedFields. + */ + // scalastyle:off line.size.limit + // Example: run the following code snippet in a Spark Shell w/ Scala 2.12+: + // val topLevelValue = "someValue"; val closure = (j: Int) => { + // class InnerFoo { + // val innerClosure = (x: Int) => (1 to x).map { y => y + topLevelValue } + // } + // val innerFoo = new InnerFoo + // (1 to j).flatMap(innerFoo.innerClosure) + // } + // sc.parallelize(0 to 2).map(closure).collect + // + // produces the following trace-level logs: + // (slightly simplified: + // - omitting the "ignoring ..." 
lines; + // - "$iw" is actually "$line16.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw"; + // - "invokedynamic" lines are simplified to just show the name+desc, omitting the bsm info) + // Cleaning indylambda closure: $anonfun$closure$1$adapted + // scanning $iw.$anonfun$closure$1$adapted(L$iw;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; + // found intra class call to $iw.$anonfun$closure$1(L$iw;I)Lscala/collection/immutable/IndexedSeq; + // scanning $iw.$anonfun$closure$1(L$iw;I)Lscala/collection/immutable/IndexedSeq; + // found inner class $iw$InnerFoo$1 + // found call to outer $iw$InnerFoo$1.innerClosure()Lscala/Function1; + // scanning $iw$InnerFoo$1.innerClosure()Lscala/Function1; + // scanning $iw$InnerFoo$1.$deserializeLambda$(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object; + // invokedynamic: lambdaDeserialize(Ljava/lang/invoke/SerializedLambda;)Ljava/lang/Object;, bsm...) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$1$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; + // found intra class call to $iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...) 
+ // found inner closure $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; (6) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; + // found intra class call to $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // found call to outer $iw.topLevelValue()Ljava/lang/String; + // scanning $iw.topLevelValue()Ljava/lang/String; + // found field access topLevelValue on $iw + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; + // found intra class call to $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // scanning $iw$InnerFoo$1.<init>(L$iw;)V + // invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...) + // found inner closure $iw$InnerFoo$1.$anonfun$innerClosure$1$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Lscala/collection/immutable/IndexedSeq; (6) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$1(L$iw$InnerFoo$1;I)Lscala/collection/immutable/IndexedSeq; + // invokedynamic: apply(L$iw$InnerFoo$1;)Lscala/Function1;, bsm...) 
+ // found inner closure $iw$InnerFoo$1.$anonfun$innerClosure$2$adapted(L$iw$InnerFoo$1;Ljava/lang/Object;)Ljava/lang/String; (6) + // scanning $iw$InnerFoo$1.$anonfun$innerClosure$2(L$iw$InnerFoo$1;I)Ljava/lang/String; + // found call to outer $iw.topLevelValue()Ljava/lang/String; + // scanning $iw$InnerFoo$1.innerClosure()Lscala/Function1; + // + fields accessed by starting closure: 2 classes + // (class java.lang.Object,Set()) + // (class $iw,Set(topLevelValue)) + // + cloning instance of REPL class $iw + // +++ indylambda closure ($anonfun$closure$1$adapted) is now cleaned +++ + // + // scalastyle:on line.size.limit + def findAccessedFields( + lambdaProxy: SerializedLambda, + lambdaClassLoader: ClassLoader, + accessedFields: Map[Class[_], Set[String]], + findTransitively: Boolean): Unit = { + + // We may need to visit the same class multiple times for different methods on it, and we'll + // need to lookup by name. So we use ASM's Tree API and cache the ClassNode/MethodNode. + val classInfoByInternalName = Map.empty[String, (Class[_], ClassNode)] + val methodNodeById = Map.empty[MethodIdentifier[_], MethodNode] + def getOrUpdateClassInfo(classInternalName: String): (Class[_], ClassNode) = { + val classInfo = classInfoByInternalName.getOrElseUpdate(classInternalName, { + val classExternalName = classInternalName.replace('/', '.') + // scalastyle:off classforname + val clazz = Class.forName(classExternalName, false, lambdaClassLoader) + // scalastyle:on classforname + val classNode = new ClassNode() + val classReader = ClosureCleaner.getClassReader(clazz) + classReader.accept(classNode, 0) + + for (m <- classNode.methods.asScala) { + methodNodeById(MethodIdentifier(clazz, m.name, m.desc)) = m + } + + (clazz, classNode) + }) + classInfo + } + + val implClassInternalName = lambdaProxy.getImplClass + val (implClass, _) = getOrUpdateClassInfo(implClassInternalName) + + val implMethodId = MethodIdentifier( + implClass, lambdaProxy.getImplMethodName, 
lambdaProxy.getImplMethodSignature) + + // The set of classes that we would consider following the calls into. + // Candidates are: known outer class which happens to be the starting closure's impl class, + // and all inner classes discovered below. + val trackedClassesByInternalName = Map[String, Class[_]](implClassInternalName -> implClass) + + // Depth-first search for inner closures and track the fields that were accessed in them. + // Start from the lambda body's implementation method, follow method invocations + val visited = Set.empty[MethodIdentifier[_]] + val stack = Stack[MethodIdentifier[_]](implMethodId) + def pushIfNotVisited(methodId: MethodIdentifier[_]): Unit = { + if (!visited.contains(methodId)) { + stack.push(methodId) + } + } + + while (!stack.isEmpty) { + val currentId = stack.pop + visited += currentId + + val currentClass = currentId.cls + val currentMethodNode = methodNodeById(currentId) + logTrace(s" scanning ${currentId.cls.getName}.${currentId.name}${currentId.desc}") + currentMethodNode.accept(new MethodVisitor(ASM7) { + val currentClassName = currentClass.getName + val currentClassInternalName = currentClassName.replace('.', '/') + + // Find and update the accessedFields info. Only fields on known outer classes are tracked. + // This is the FieldAccessFinder equivalent. 
+ override def visitFieldInsn(op: Int, owner: String, name: String, desc: String): Unit = { + if (op == GETFIELD || op == PUTFIELD) { + val ownerExternalName = owner.replace('/', '.') + for (cl <- accessedFields.keys if cl.getName == ownerExternalName) { + logTrace(s" found field access $name on $ownerExternalName") + accessedFields(cl) += name + } + } + } + + override def visitMethodInsn( + op: Int, owner: String, name: String, desc: String, itf: Boolean): Unit = { + val ownerExternalName = owner.replace('/', '.') + if (owner == currentClassInternalName) { + logTrace(s" found intra class call to $ownerExternalName.$name$desc") + // could be invoking a helper method or a field accessor method, just follow it. + pushIfNotVisited(MethodIdentifier(currentClass, name, desc)) + } else if (isInnerClassCtorCapturingOuter( + op, owner, name, desc, currentClassInternalName)) { + // Discover inner classes. + // This this the InnerClassFinder equivalent for inner classes, which still use the + // `$outer` chain. So this is NOT controlled by the `findTransitively` flag. + logTrace(s" found inner class $ownerExternalName") + val innerClassInfo = getOrUpdateClassInfo(owner) + val innerClass = innerClassInfo._1 + val innerClassNode = innerClassInfo._2 + trackedClassesByInternalName(owner) = innerClass + // We need to visit all methods on the inner class so that we don't missing anything. + for (m <- innerClassNode.methods.asScala) { + pushIfNotVisited(MethodIdentifier(innerClass, m.name, m.desc)) + } + } else if (findTransitively && trackedClassesByInternalName.contains(owner)) { + logTrace(s" found call to outer $ownerExternalName.$name$desc") Review comment: Is this log only for outer cases? In [the log](https://github.com/apache/spark/pull/28463/files#diff-4928e25ed331cc478162f750f53652e2R562) shown in the comment above, it seems to capture inner cases, too? 
---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org