This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

commit 890748873bd8bd72b34d3f907ecdb72a694234c9
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Mon Aug 14 05:32:57 2023 +0200

    Refine solution
---
 .../spark/sql/catalyst/encoders/OuterScopes.scala  | 49 +++++++++++++---------
 .../expressions/codegen/CodeGenerator.scala        | 18 ++------
 2 files changed, 33 insertions(+), 34 deletions(-)

diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
index c2ac504c846..6c10e8ece80 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
@@ -26,28 +26,9 @@ import org.apache.spark.util.SparkClassUtils
 
 object OuterScopes {
   private[this] val queue = new ReferenceQueue[AnyRef]
-  private class HashableWeakReference(v: AnyRef) extends 
WeakReference[AnyRef](v, queue) {
-    private[this] val hash = v.hashCode()
-    override def hashCode(): Int = hash
-    override def equals(obj: Any): Boolean = {
-      obj match {
-        case other: HashableWeakReference =>
-          // Note that referential equality is used to identify & purge
-          // references from the map whose' referent went out of scope.
-          if (this eq other) {
-            true
-          } else {
-            val referent = get()
-            val otherReferent = other.get()
-            referent != null && otherReferent != null && 
Objects.equals(referent, otherReferent)
-          }
-        case _ => false
-      }
-    }
-  }
 
   private def classLoaderRef(c: Class[_]): HashableWeakReference = {
-    new HashableWeakReference(c.getClassLoader)
+    new HashableWeakReference(c.getClassLoader, queue)
   }
 
   private[this] val outerScopes = {
@@ -154,3 +135,31 @@ object OuterScopes {
   // e.g. `ammonite.$sess.cmd8$Helper$Foo` -> 
`ammonite.$sess.cmd8.instance.Foo`
   private[this] val AmmoniteREPLClass = 
"""^(ammonite\.\$sess\.cmd(?:\d+)\$).*""".r
 }
+
+/**
+ * A [[WeakReference]] that has a stable hash-key. When the referent is still 
alive we will use
+ * the referent for equality; once it is dead we will fall back to 
referential equality. This
+ * way you can still do lookups in a map when the referent is alive, and are 
capable of removing
+ * dead entries after GC (using a [[ReferenceQueue]]).
+ */
+private[catalyst] class HashableWeakReference(v: AnyRef, queue: 
ReferenceQueue[AnyRef])
+  extends WeakReference[AnyRef](v, queue) {
+  def this(v: AnyRef) = this(v, null)
+  private[this] val hash = v.hashCode()
+  override def hashCode(): Int = hash
+  override def equals(obj: Any): Boolean = {
+    obj match {
+      case other: HashableWeakReference =>
+        // Note that referential equality is used to identify & purge
+        // references from the map whose referent went out of scope.
+        if (this eq other) {
+          true
+        } else {
+          val referent = get()
+          val otherReferent = other.get()
+          referent != null && otherReferent != null && 
Objects.equals(referent, otherReferent)
+        }
+      case _ => false
+    }
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 59688cae889..fe61cc81359 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.catalyst.expressions.codegen
 
 import java.io.ByteArrayInputStream
-import java.util.UUID
 
 import scala.annotation.tailrec
 import scala.collection.JavaConverters._
@@ -26,7 +25,6 @@ import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
-import com.google.common.cache.{CacheBuilder, CacheLoader}
 import com.google.common.util.concurrent.{ExecutionError, 
UncheckedExecutionException}
 import org.codehaus.commons.compiler.{CompileException, 
InternalCompilerException}
 import org.codehaus.janino.ClassBodyEvaluator
@@ -37,6 +35,7 @@ import org.apache.spark.executor.InputMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.metrics.source.CodegenMetrics
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.HashableWeakReference
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.types._
@@ -1441,7 +1440,8 @@ object CodeGenerator extends Logging {
    * @return a pair of a generated class and the bytecode statistics of 
generated functions.
    */
   def compile(code: CodeAndComment): (GeneratedClass, ByteCodeStats) = try {
-    cache.get((classLoaderUUID.get(Utils.getContextOrSparkClassLoader), code))
+    val classLoaderRef = new 
HashableWeakReference(Utils.getContextOrSparkClassLoader)
+    cache.get((classLoaderRef, code))
   } catch {
     // Cache.get() may wrap the original exception. See the following URL
     // https://guava.dev/releases/14.0.1/api/docs/com/google/common/cache/
@@ -1583,7 +1583,7 @@ object CodeGenerator extends Logging {
    * aborted. See [[NonFateSharingCache]] for more details.
    */
   private val cache = {
-    val loadFunc: ((ClassLoaderId, CodeAndComment)) => (GeneratedClass, 
ByteCodeStats) = {
+    val loadFunc: ((HashableWeakReference, CodeAndComment)) => 
(GeneratedClass, ByteCodeStats) = {
       case (_, code) =>
         val startTime = System.nanoTime()
         val result = doCompile(code)
@@ -1599,16 +1599,6 @@ object CodeGenerator extends Logging {
     NonFateSharingCache(loadFunc, SQLConf.get.codegenCacheMaxEntries)
   }
 
-  type ClassLoaderId = String
-  private val classLoaderUUID = {
-    NonFateSharingCache(CacheBuilder.newBuilder()
-      .weakKeys
-      .maximumSize(SQLConf.get.codegenCacheMaxEntries)
-      .build(new CacheLoader[ClassLoader, ClassLoaderId]() {
-        override def load(code: ClassLoader): ClassLoaderId = 
UUID.randomUUID.toString
-      }))
-  }
-
   /**
    * Name of Java primitive data type
    */


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to