This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1d562904e4e [SPARK-44795][CONNECT] CodeGenerator Cache should be classloader specific
1d562904e4e is described below

commit 1d562904e4e75aec3ea8d4999ede0183fda326c7
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Tue Aug 15 03:10:42 2023 +0200

    [SPARK-44795][CONNECT] CodeGenerator Cache should be classloader specific
    
    ### What changes were proposed in this pull request?
    When you currently use a REPL-generated class in a UDF, you can get an error
    saying that a class is not equal to another class with the same name. This
    error is thrown in a code-generated class. The problem is that the two classes
    have been loaded by different classloaders. We cache generated code and use the
    textual code as the cache key. The problem with this is that in Spark Connect
    users are free to supply user classes with arbitrary names, so a name can point
    to an entirely different class, or it [...]
    
    There are roughly two ways this problem can arise:
    1. Two sessions use the same class names. This happens easily when you use
    the REPL, because the REPL always generates the same class names.
    2. You run in single-process mode. In this case whole-stage codegen will
    test-compile the class using a different classloader than the 'executor',
    while sharing the same code generator cache.
    
    ### Why are the changes needed?
    We want to be able to use REPL-generated (and other user-supplied) classes in UDFs.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    I added a test to the `ReplE2ESuite`.
    
    Closes #42478 from hvanhovell/SPARK-44795.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../spark/sql/application/ReplE2ESuite.scala       |  9 ++++
 .../spark/sql/catalyst/encoders/OuterScopes.scala  | 49 +++++++++++++---------
 .../expressions/codegen/CodeGenerator.scala        | 30 ++++++-------
 3 files changed, 54 insertions(+), 34 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 0e69b5afa45..0cab66eef3d 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -276,4 +276,13 @@ class ReplE2ESuite extends RemoteSparkSession with 
BeforeAndAfterEach {
     val output = runCommandsInShell(input)
     assertContains("Array[MyTestClass] = Array(MyTestClass(1), 
MyTestClass(3))", output)
   }
+
+  test("REPL class in UDF") {
+    val input = """
+        |case class MyTestClass(value: Int)
+        |spark.range(2).map(i => MyTestClass(i.toInt)).collect()
+      """.stripMargin
+    val output = runCommandsInShell(input)
+    assertContains("Array[MyTestClass] = Array(MyTestClass(0), 
MyTestClass(1))", output)
+  }
 }
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
index c2ac504c846..6c10e8ece80 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
@@ -26,28 +26,9 @@ import org.apache.spark.util.SparkClassUtils
 
 object OuterScopes {
   private[this] val queue = new ReferenceQueue[AnyRef]
-  private class HashableWeakReference(v: AnyRef) extends 
WeakReference[AnyRef](v, queue) {
-    private[this] val hash = v.hashCode()
-    override def hashCode(): Int = hash
-    override def equals(obj: Any): Boolean = {
-      obj match {
-        case other: HashableWeakReference =>
-          // Note that referential equality is used to identify & purge
-          // references from the map whose' referent went out of scope.
-          if (this eq other) {
-            true
-          } else {
-            val referent = get()
-            val otherReferent = other.get()
-            referent != null && otherReferent != null && 
Objects.equals(referent, otherReferent)
-          }
-        case _ => false
-      }
-    }
-  }
 
   private def classLoaderRef(c: Class[_]): HashableWeakReference = {
-    new HashableWeakReference(c.getClassLoader)
+    new HashableWeakReference(c.getClassLoader, queue)
   }
 
   private[this] val outerScopes = {
@@ -154,3 +135,31 @@ object OuterScopes {
   // e.g. `ammonite.$sess.cmd8$Helper$Foo` -> 
`ammonite.$sess.cmd8.instance.Foo`
   private[this] val AmmoniteREPLClass = 
"""^(ammonite\.\$sess\.cmd(?:\d+)\$).*""".r
 }
+
+/**
+ * A [[WeakReference]] that has a stable hash-key. When the referent is still 
alive we will use
+ * the referent for equality, once it is dead it we will fallback to 
referential equality. This
+ * way you can still do lookups in a map when the referent is alive, and are 
capable of removing
+ * dead entries after GC (using a [[ReferenceQueue]]).
+ */
+private[catalyst] class HashableWeakReference(v: AnyRef, queue: 
ReferenceQueue[AnyRef])
+  extends WeakReference[AnyRef](v, queue) {
+  def this(v: AnyRef) = this(v, null)
+  private[this] val hash = v.hashCode()
+  override def hashCode(): Int = hash
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: HashableWeakReference =>
+        // Note that referential equality is used to identify & purge
+        // references from the map whose' referent went out of scope.
+        if (this eq other) {
+          true
+        } else {
+          val referent = get()
+          val otherReferent = other.get()
+          referent != null && otherReferent != null && 
Objects.equals(referent, otherReferent)
+        }
+      case _ => false
+    }
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 8d10f6cd295..fe61cc81359 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -35,6 +35,7 @@ import org.apache.spark.executor.InputMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.metrics.source.CodegenMetrics
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.HashableWeakReference
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.types._
@@ -1439,7 +1440,8 @@ object CodeGenerator extends Logging {
    * @return a pair of a generated class and the bytecode statistics of 
generated functions.
    */
   def compile(code: CodeAndComment): (GeneratedClass, ByteCodeStats) = try {
-    cache.get(code)
+    val classLoaderRef = new 
HashableWeakReference(Utils.getContextOrSparkClassLoader)
+    cache.get((classLoaderRef, code))
   } catch {
     // Cache.get() may wrap the original exception. See the following URL
     // https://guava.dev/releases/14.0.1/api/docs/com/google/common/cache/
@@ -1581,20 +1583,20 @@ object CodeGenerator extends Logging {
    * aborted. See [[NonFateSharingCache]] for more details.
    */
   private val cache = {
-    def loadFunc: CodeAndComment => (GeneratedClass, ByteCodeStats) = code => {
-      val startTime = System.nanoTime()
-      val result = doCompile(code)
-      val endTime = System.nanoTime()
-      val duration = endTime - startTime
-      val timeMs: Double = duration.toDouble / NANOS_PER_MILLIS
-      CodegenMetrics.METRIC_SOURCE_CODE_SIZE.update(code.body.length)
-      CodegenMetrics.METRIC_COMPILATION_TIME.update(timeMs.toLong)
-      logInfo(s"Code generated in $timeMs ms")
-      _compileTime.add(duration)
-      result
+    val loadFunc: ((HashableWeakReference, CodeAndComment)) => 
(GeneratedClass, ByteCodeStats) = {
+      case (_, code) =>
+        val startTime = System.nanoTime()
+        val result = doCompile(code)
+        val endTime = System.nanoTime()
+        val duration = endTime - startTime
+        val timeMs: Double = duration.toDouble / NANOS_PER_MILLIS
+        CodegenMetrics.METRIC_SOURCE_CODE_SIZE.update(code.body.length)
+        CodegenMetrics.METRIC_COMPILATION_TIME.update(timeMs.toLong)
+        logInfo(s"Code generated in $timeMs ms")
+        _compileTime.add(duration)
+        result
     }
-    NonFateSharingCache[CodeAndComment, (GeneratedClass, ByteCodeStats)](
-      loadFunc, SQLConf.get.codegenCacheMaxEntries)
+    NonFateSharingCache(loadFunc, SQLConf.get.codegenCacheMaxEntries)
   }
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to