This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dcf3d582293 [SPARK-44791][CONNECT] Make ArrowDeserializer work with 
REPL generated classes
dcf3d582293 is described below

commit dcf3d582293c3dbb3820d12fa15b41e8bd5fe6ad
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Mon Aug 14 02:38:54 2023 +0200

    [SPARK-44791][CONNECT] Make ArrowDeserializer work with REPL generated 
classes
    
    ### What changes were proposed in this pull request?
    Spark Connect's Arrow deserialization currently does not work with REPL-generated 
classes. For example, the following code would fail:
    ```scala
    case class MyTestClass(value: Int) {
      override def toString: String = value.toString
    }
    spark.range(10).map(i => MyTestClass(i.toInt)).collect()
    ```
    
    The problem is that instantiating `MyTestClass` requires the instance of 
the class in which it was defined (its outer scope). Spark has a mechanism called 
`OuterScopes` in which these outer instances are registered. The 
`ArrowDeserializer` was not resolving this outer instance; this PR fixes that.
    
    We have a similar issue on the executor/driver side. This will be fixed in 
a different PR.
    
    ### Why are the changes needed?
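    It fixes a bug: collecting a Dataset of REPL-generated classes (such as the 
`MyTestClass` example above) fails in the Spark Connect client.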
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    I have added tests to `ReplE2ESuite` and `ArrowEncoderSuite`.
    
    Closes #42473 from hvanhovell/SPARK-44791.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../org/apache/spark/util/SparkClassUtils.scala    | 28 +++++++
 .../connect/client/arrow/ArrowDeserializer.scala   | 14 +++-
 .../spark/sql/application/ReplE2ESuite.scala       | 33 ++++-----
 .../connect/client/arrow/ArrowEncoderSuite.scala   | 12 ++-
 .../main/scala/org/apache/spark/util/Utils.scala   | 28 -------
 .../spark/sql/catalyst/encoders/OuterScopes.scala  | 85 +++++++++++++++++-----
 .../apache/spark/sql/errors/ExecutionErrors.scala  |  7 ++
 .../spark/sql/errors/QueryExecutionErrors.scala    |  7 --
 8 files changed, 138 insertions(+), 76 deletions(-)

diff --git 
a/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala 
b/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala
index a237869aef3..679d546d04c 100644
--- a/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala
+++ b/common/utils/src/main/scala/org/apache/spark/util/SparkClassUtils.scala
@@ -50,6 +50,34 @@ trait SparkClassUtils {
   def classIsLoadable(clazz: String): Boolean = {
     Try { classForName(clazz, initialize = false) }.isSuccess
   }
+
+  /**
+   * Returns true if and only if the underlying class is a member class.
+   *
+   * Note: jdk8u throws a "Malformed class name" error if a given class is a 
deeply-nested
+   * inner class (See SPARK-34607 for details). This issue has already been 
fixed in jdk9+, so
+   * we can remove this helper method safely if we drop the support of jdk8u.
+   */
+  def isMemberClass(cls: Class[_]): Boolean = {
+    try {
+      cls.isMemberClass
+    } catch {
+      case _: InternalError =>
+        // We emulate jdk8u `Class.isMemberClass` below:
+        //   public boolean isMemberClass() {
+        //     return getSimpleBinaryName() != null && 
!isLocalOrAnonymousClass();
+        //   }
+        // `getSimpleBinaryName()` returns null if a given class is a 
top-level class,
+        // so we replace it with `cls.getEnclosingClass != null`. The second 
condition checks
+        // if a given class is not a local or an anonymous class, so we 
replace it with
+        // `cls.getEnclosingMethod == null` because `cls.getEnclosingMethod()` 
return a value
+        // only in either case (JVM Spec 4.8.6).
+        //
+        // Note: The newer jdk evaluates `!isLocalOrAnonymousClass()` first,
+        // we reorder the conditions to follow it.
+        cls.getEnclosingMethod == null && cls.getEnclosingClass != null
+    }
+  }
 }
 
 object SparkClassUtils extends SparkClassUtils
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 55dd640f1b6..82086b9d47a 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -34,7 +34,7 @@ import org.apache.arrow.vector.ipc.ArrowReader
 import org.apache.arrow.vector.util.Text
 
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.connect.client.CloseableIterator
@@ -290,15 +290,23 @@ object ArrowDeserializers {
 
       case (ProductEncoder(tag, fields), StructVectors(struct, vectors)) =>
         // We should try to make this work with MethodHandles.
+        val outer = 
Option(OuterScopes.getOuterScope(tag.runtimeClass)).map(_()).toSeq
         val Some(constructor) =
-          ScalaReflection.findConstructor(tag.runtimeClass, 
fields.map(_.enc.clsTag.runtimeClass))
+          ScalaReflection.findConstructor(
+            tag.runtimeClass,
+            outer.map(_.getClass) ++ fields.map(_.enc.clsTag.runtimeClass))
         val deserializers = if (isTuple(tag.runtimeClass)) {
           fields.zip(vectors).map { case (field, vector) =>
             deserializerFor(field.enc, vector, timeZoneId)
           }
         } else {
+          val outerDeserializer = outer.map { value =>
+            new Deserializer[Any] {
+              override def get(i: Int): Any = value
+            }
+          }
           val lookup = createFieldLookup(vectors)
-          fields.map { field =>
+          outerDeserializer ++ fields.map { field =>
             deserializerFor(field.enc, lookup(field.name), timeZoneId)
           }
         }
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 0c19b8b7df1..0e69b5afa45 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -134,20 +134,6 @@ class ReplE2ESuite extends RemoteSparkSession with 
BeforeAndAfterEach {
     assertContains("Array[Int] = Array(19, 24, 29, 34, 39)", output)
   }
 
-  // SPARK-43198: Switching REPL to CodeClass generation mode causes UDFs 
defined through lambda
-  // expressions to hit deserialization issues.
-  // TODO(SPARK-43227): Enable test after fixing deserialization issue.
-  ignore("UDF containing lambda expression") {
-    val input = """
-        |class A(x: Int) { def get = x * 20 + 5 }
-        |val dummyUdf = (x: Int) => new A(x).get
-        |val myUdf = udf(dummyUdf)
-        |spark.range(5).select(myUdf(col("id"))).as[Int].collect()
-      """.stripMargin
-    val output = runCommandsInShell(input)
-    assertContains("Array[Int] = Array(5, 25, 45, 65, 85)", output)
-  }
-
   test("UDF containing in-place lambda") {
     val input = """
         |class A(x: Int) { def get = x * 42 + 5 }
@@ -238,9 +224,8 @@ class ReplE2ESuite extends RemoteSparkSession with 
BeforeAndAfterEach {
   }
 
   test("UDF Registration") {
-    // TODO SPARK-44449 make this long again when upcasting is in.
     val input = """
-        |class A(x: Int) { def get: Long = x * 100 }
+        |class A(x: Int) { def get = x * 100 }
         |val myUdf = udf((x: Int) => new A(x).get)
         |spark.udf.register("dummyUdf", myUdf)
         |spark.sql("select dummyUdf(id) from range(5)").as[Long].collect()
@@ -250,9 +235,8 @@ class ReplE2ESuite extends RemoteSparkSession with 
BeforeAndAfterEach {
   }
 
   test("UDF closure registration") {
-    // TODO SPARK-44449 make this int again when upcasting is in.
     val input = """
-        |class A(x: Int) { def get: Long = x * 15 }
+        |class A(x: Int) { def get = x * 15 }
         |spark.udf.register("directUdf", (x: Int) => new A(x).get)
         |spark.sql("select directUdf(id) from range(5)").as[Long].collect()
       """.stripMargin
@@ -279,4 +263,17 @@ class ReplE2ESuite extends RemoteSparkSession with 
BeforeAndAfterEach {
     val output = runCommandsInShell(input)
     assertContains("Array[org.apache.spark.sql.Row] = Array([id1,1], [id2,16], 
[id3,25])", output)
   }
+
+  test("Collect REPL generated class") {
+    val input = """
+        |case class MyTestClass(value: Int)
+        |spark.range(4).
+        |  filter($"id" % 2 === 1).
+        |  select($"id".cast("int").as("value")).
+        |  as[MyTestClass].
+        |  collect()
+      """.stripMargin
+    val output = runCommandsInShell(input)
+    assertContains("Array[MyTestClass] = Array(MyTestClass(1), 
MyTestClass(3))", output)
+  }
 }
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
index 7a8e8465a70..2a499cc548f 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala
@@ -33,7 +33,7 @@ import org.scalatest.BeforeAndAfterAll
 import org.apache.spark.SparkUnsupportedOperationException
 import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.{DefinedByConstructorParams, 
JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, 
BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, 
BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, 
DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, 
IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, 
NullEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, 
PrimitiveDoubleEncoder, PrimitiveFloatEncoder, Primi [...]
 import org.apache.spark.sql.catalyst.encoders.RowEncoder.{encoderFor => 
toRowEncoder}
 import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkStringUtils, 
TimestampFormatter}
@@ -759,6 +759,16 @@ class ArrowEncoderSuite extends ConnectFunSuite with 
BeforeAndAfterAll {
     }
   }
 
+  case class MyTestClass(value: Int)
+  OuterScopes.addOuterScope(this)
+
+  test("REPL generated classes") {
+    val encoder = ScalaReflection.encoderFor[MyTestClass]
+    roundTripAndCheckIdentical(encoder) { () =>
+      Iterator.tabulate(10)(MyTestClass)
+    }
+  }
+
   /* ******************************************************************** *
    * Arrow deserialization upcasting
    * ******************************************************************** */
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala 
b/core/src/main/scala/org/apache/spark/util/Utils.scala
index a35fb3c0078..35e99785f74 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2885,34 +2885,6 @@ private[spark] object Utils
     Hex.encodeHexString(secretBytes)
   }
 
-  /**
-   * Returns true if and only if the underlying class is a member class.
-   *
-   * Note: jdk8u throws a "Malformed class name" error if a given class is a 
deeply-nested
-   * inner class (See SPARK-34607 for details). This issue has already been 
fixed in jdk9+, so
-   * we can remove this helper method safely if we drop the support of jdk8u.
-   */
-  def isMemberClass(cls: Class[_]): Boolean = {
-    try {
-      cls.isMemberClass
-    } catch {
-      case _: InternalError =>
-        // We emulate jdk8u `Class.isMemberClass` below:
-        //   public boolean isMemberClass() {
-        //     return getSimpleBinaryName() != null && 
!isLocalOrAnonymousClass();
-        //   }
-        // `getSimpleBinaryName()` returns null if a given class is a 
top-level class,
-        // so we replace it with `cls.getEnclosingClass != null`. The second 
condition checks
-        // if a given class is not a local or an anonymous class, so we 
replace it with
-        // `cls.getEnclosingMethod == null` because `cls.getEnclosingMethod()` 
return a value
-        // only in either case (JVM Spec 4.8.6).
-        //
-        // Note: The newer jdk evaluates `!isLocalOrAnonymousClass()` first,
-        // we reorder the conditions to follow it.
-        cls.getEnclosingMethod == null && cls.getEnclosingClass != null
-    }
-  }
-
   /**
    * Safer than Class obj's getSimpleName which may throw Malformed class name 
error in scala.
    * This method mimics scalatest's getSimpleNameOfAnObjectsClass.
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
similarity index 58%
rename from 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
rename to 
sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
index 6f7150d8d33..c2ac504c846 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
@@ -17,17 +17,53 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
-import java.util.concurrent.ConcurrentMap
+import java.lang.ref._
+import java.util.Objects
+import java.util.concurrent.ConcurrentHashMap
 
-import com.google.common.collect.MapMaker
-
-import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.errors.ExecutionErrors
+import org.apache.spark.util.SparkClassUtils
 
 object OuterScopes {
-  @transient
-  lazy val outerScopes: ConcurrentMap[String, AnyRef] =
-    new MapMaker().weakValues().makeMap()
+  private[this] val queue = new ReferenceQueue[AnyRef]
+  private class HashableWeakReference(v: AnyRef) extends 
WeakReference[AnyRef](v, queue) {
+    private[this] val hash = v.hashCode()
+    override def hashCode(): Int = hash
+    override def equals(obj: Any): Boolean = {
+      obj match {
+        case other: HashableWeakReference =>
+          // Note that referential equality is used to identify & purge
+          // references from the map whose' referent went out of scope.
+          if (this eq other) {
+            true
+          } else {
+            val referent = get()
+            val otherReferent = other.get()
+            referent != null && otherReferent != null && 
Objects.equals(referent, otherReferent)
+          }
+        case _ => false
+      }
+    }
+  }
+
+  private def classLoaderRef(c: Class[_]): HashableWeakReference = {
+    new HashableWeakReference(c.getClassLoader)
+  }
+
+  private[this] val outerScopes = {
+    new ConcurrentHashMap[HashableWeakReference, ConcurrentHashMap[String, 
WeakReference[AnyRef]]]
+  }
+
+  /**
+   * Clean the outer scopes that have been garbage collected.
+   */
+  private def cleanOuterScopes(): Unit = {
+    var entry = queue.poll()
+    while (entry != null) {
+      outerScopes.remove(entry)
+      entry = queue.poll()
+    }
+  }
 
   /**
    * Adds a new outer scope to this context that can be used when 
instantiating an `inner class`
@@ -40,7 +76,11 @@ object OuterScopes {
    * given wrapper class.
    */
   def addOuterScope(outer: AnyRef): Unit = {
-    outerScopes.putIfAbsent(outer.getClass.getName, outer)
+    cleanOuterScopes()
+    val clz = outer.getClass
+    outerScopes
+      .computeIfAbsent(classLoaderRef(clz), _ => new ConcurrentHashMap)
+      .putIfAbsent(clz.getName, new WeakReference(outer))
   }
 
   /**
@@ -49,16 +89,24 @@ object OuterScopes {
    * useful for inner class defined in REPL.
    */
   def getOuterScope(innerCls: Class[_]): () => AnyRef = {
-    assert(Utils.isMemberClass(innerCls))
-    val outerClassName = innerCls.getDeclaringClass.getName
-    val outer = outerScopes.get(outerClassName)
+    if (!SparkClassUtils.isMemberClass(innerCls)) {
+      return null
+    }
+    val outerClass = innerCls.getDeclaringClass
+    val outerClassName = outerClass.getName
+    val outer = Option(outerScopes.get(classLoaderRef(outerClass)))
+      .flatMap(map => Option(map.get(outerClassName)))
+      .map(_.get())
+      .orNull
     if (outer == null) {
       outerClassName match {
         case AmmoniteREPLClass(cellClassName) =>
           () => {
-            val objClass = Utils.classForName(cellClassName)
+            val objClass = SparkClassUtils.classForName(cellClassName)
             val objInstance = objClass.getField("MODULE$").get(null)
-            objClass.getMethod("instance").invoke(objInstance)
+            val obj = objClass.getMethod("instance").invoke(objInstance)
+            addOuterScope(obj)
+            obj
           }
         // If the outer class is generated by REPL, users don't need to 
register it as it has
         // only one instance and there is a way to retrieve it: get the 
`$read` object, call the
@@ -66,10 +114,10 @@ object OuterScopes {
         // method multiply times to get the single instance of the inner most 
`$iw` class.
         case REPLClass(baseClassName) =>
           () => {
-            val objClass = Utils.classForName(baseClassName + "$")
+            val objClass = SparkClassUtils.classForName(baseClassName + "$")
             val objInstance = objClass.getField("MODULE$").get(null)
             val baseInstance = 
objClass.getMethod("INSTANCE").invoke(objInstance)
-            val baseClass = Utils.classForName(baseClassName)
+            val baseClass = SparkClassUtils.classForName(baseClassName)
 
             var getter = iwGetter(baseClass)
             var obj = baseInstance
@@ -79,10 +127,9 @@ object OuterScopes {
             }
 
             if (obj == null) {
-              throw 
QueryExecutionErrors.cannotGetOuterPointerForInnerClassError(innerCls)
+              throw 
ExecutionErrors.cannotGetOuterPointerForInnerClassError(innerCls)
             }
-
-            outerScopes.putIfAbsent(outerClassName, obj)
+            addOuterScope(obj)
             obj
           }
         case _ => null
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
index 1e8e0ef5f6a..c8321e81027 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/errors/ExecutionErrors.scala
@@ -206,6 +206,13 @@ private[sql] trait ExecutionErrors extends 
DataTypeErrorsBase {
       errorClass = "_LEGACY_ERROR_TEMP_2021",
       messageParameters = Map("cls" -> cls.toString))
   }
+
+  def cannotGetOuterPointerForInnerClassError(innerCls: Class[_]): 
SparkRuntimeException = {
+    new SparkRuntimeException(
+      errorClass = "_LEGACY_ERROR_TEMP_2154",
+      messageParameters = Map(
+        "innerCls" -> innerCls.getName))
+  }
 }
 
 private[sql] object ExecutionErrors extends ExecutionErrors
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 1cc79a92c4c..953d9713c7a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1365,13 +1365,6 @@ private[sql] object QueryExecutionErrors extends 
QueryErrorsBase with ExecutionE
         "objSerializer" -> objSerializer.toString()))
   }
 
-  def cannotGetOuterPointerForInnerClassError(innerCls: Class[_]): 
SparkRuntimeException = {
-    new SparkRuntimeException(
-      errorClass = "_LEGACY_ERROR_TEMP_2154",
-      messageParameters = Map(
-        "innerCls" -> innerCls.getName))
-  }
-
   def unsupportedOperandTypeForSizeFunctionError(
       dataType: DataType): SparkUnsupportedOperationException = {
     new SparkUnsupportedOperationException(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to