This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new cb8c653f3c7 [SPARK-44799][CONNECT] Fix outer scopes resolution on the executor side
cb8c653f3c7 is described below

commit cb8c653f3c7f76129deae366608613a968b81264
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Wed Aug 16 15:28:12 2023 +0200

    [SPARK-44799][CONNECT] Fix outer scopes resolution on the executor side
    
    ### What changes were proposed in this pull request?
    When you define a class in the REPL (with previously defined symbols), for example:
    ```scala
    val filePath = "my_path"
    case class MyTestClass(value: Int)
    ```
    This is actually declared inside a command class. In Ammonite the structure looks like this:
    ```scala
    // First command contains the `filePath`
    object cmd1 {
      val wrapper = new cmd1
      val instance = new wrapper.Helper
    }
    class cmd1 extends Serializable {
      class Helper extends Serializable {
        val filePath = "my_path"
      }
    }
    
    // Second command contains the `MyTestClass` definition
    object command2 {
      val wrapper = new command2
      val instance = new wrapper.Helper
    }
    class command2 extends Serializable {
      @_root_.scala.transient private val __amm_usedThings = _root_.ammonite.repl.ReplBridge.value.usedEarlierDefinitions.iterator.toSet
    
      private val `cmd1`: cmd1.instance.type = if (__amm_usedThings("""cmd1""")) cmd1.instance else null.asInstanceOf[cmd1.instance.type]
    
      class Helper extends Serializable {
        case class MyTestClass(value: Int)
      }
    }
    ```
    In order to create an instance of `MyTestClass` we need an instance of the `Helper`. When an instance of the class is created by Spark itself, we use `OuterScopes`, which - for Ammonite-generated classes - accesses the command object to fetch the helper instance. The problem with this is that the access triggers the creation of an instance of the command, and creating an instance of the command tries to access the REPL to figure out which of its dependencies are in use (clever [...]
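    
    To make the failure mode concrete, here is a simplified sketch (the names mirror the structure above, but the bodies are illustrative, not the actual Ammonite codegen) of why merely touching the command companion outside a REPL fails:
    ```scala
    // Illustrative stand-in: only populated inside a running Ammonite session,
    // null on the Spark Connect driver and executors.
    object ReplBridge { var usedEarlierDefinitions: Set[String] = null }
    
    class command2 extends Serializable {
      // The command's initializer asks the REPL which earlier definitions to retain.
      @transient private val __amm_usedThings: Set[String] =
        ReplBridge.usedEarlierDefinitions.iterator.toSet // NPE when no REPL is attached
    
      class Helper extends Serializable {
        case class MyTestClass(value: Int)
      }
    }
    
    object command2 {
      val wrapper = new command2 // runs on first access to the companion object
      val instance = new wrapper.Helper
    }
    
    // On an executor there is no REPL, so this throws a NullPointerException
    // while constructing `wrapper`:
    // command2.instance
    ```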
    
    This PR fixes this issue by explicitly passing a getter for the outer instance to the `ProductEncoder`. For Ammonite we actually ship the helper instance. This way the encoder always carries the information it needs to create the class.
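    
    Conceptually the encoder now looks like the self-contained sketch below (simplified stand-ins, not the exact Spark classes): it carries an optional `() => AnyRef` that returns the enclosing instance, and the deserializer prepends that instance to the constructor arguments of the inner class.
    ```scala
    import scala.reflect.ClassTag
    
    // Simplified stand-in for ProductEncoder: the last parameter eagerly captures
    // how to obtain the outer (enclosing) instance of a REPL-defined class.
    case class SimpleProductEncoder[K](
        clsTag: ClassTag[K],
        fieldNames: Seq[String],
        outerPointerGetter: Option[() => AnyRef])
    
    class Helper extends Serializable {
      case class MyTestClass(value: Int) // inner class: its constructor needs a Helper
    }
    
    val helper = new Helper
    val enc = SimpleProductEncoder(
      ClassTag[Helper#MyTestClass](classOf[Helper#MyTestClass]),
      Seq("value"),
      Some(() => helper)) // the getter travels with the encoder
    
    // A deserializer prepends the outer instance to the reflective constructor call.
    val args: Seq[AnyRef] = enc.outerPointerGetter.map(_()).toSeq :+ Int.box(42)
    val instance = enc.clsTag.runtimeClass.getConstructors.head.newInstance(args: _*)
    println(instance) // MyTestClass(42)
    ```
    The actual change threads such a getter through `ProductEncoder`, `OuterScopes` and the Arrow deserializer, as shown in the diff below.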
    
    ### Why are the changes needed?
    This fixes a bug when you try to use a REPL-defined class as the input of a UDF. For example, this now works:
    ```scala
    val filePath = "my_path" // we need some previous cell that exposes a symbol that could be captured in the class definition.
    case class MyTestClass(value: Int) {
      override def toString: String = value.toString
    }
    spark.range(10).select(col("id").cast("int").as("value")).as[MyTestClass].map(mtc => mtc.value).collect()
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added a test to `ReplE2ESuite` that illustrates the issue.
    
    Closes #42489 from hvanhovell/SPARK-44799.
    
    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  3 +-
 .../spark/sql/connect/client/SparkResult.scala     |  4 +--
 .../connect/client/arrow/ArrowDeserializer.scala   |  6 ++--
 .../sql/connect/client/arrow/ArrowSerializer.scala |  2 +-
 .../spark/sql/application/ReplE2ESuite.scala       | 13 +++++++++
 .../spark/sql/catalyst/ScalaReflection.scala       |  5 ++--
 .../sql/catalyst/encoders/AgnosticEncoder.scala    |  5 ++--
 .../spark/sql/catalyst/encoders/OuterScopes.scala  | 33 +++++++++++++++++-----
 .../sql/catalyst/DeserializerBuildHelper.scala     |  4 +--
 .../spark/sql/catalyst/SerializerBuildHelper.scala |  2 +-
 .../spark/sql/catalyst/ScalaReflectionSuite.scala  |  3 +-
 11 files changed, 58 insertions(+), 22 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 28b04fb850e..cb7d2c84df5 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -883,7 +883,8 @@ class Dataset[T] private[sql] (
         ClassTag(SparkClassUtils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple2")),
         Seq(
           EncoderField(s"_1", this.agnosticEncoder, leftNullable, Metadata.empty),
-          EncoderField(s"_2", other.agnosticEncoder, rightNullable, Metadata.empty)))
+          EncoderField(s"_2", other.agnosticEncoder, rightNullable, Metadata.empty)),
+        None)
 
     sparkSession.newDataset(tupleEncoder) { builder =>
       val joinBuilder = builder.getJoinBuilder
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 609e84779fb..48278311428 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -60,14 +60,14 @@ private[sql] class SparkResult[T](
         RowEncoder
           .encoderFor(dataType.asInstanceOf[StructType])
           .asInstanceOf[AgnosticEncoder[E]]
-      case ProductEncoder(clsTag, fields) if ProductEncoder.isTuple(clsTag) =>
+      case ProductEncoder(clsTag, fields, outer) if ProductEncoder.isTuple(clsTag) =>
         // Recursively continue updating the tuple product encoder
         val schema = dataType.asInstanceOf[StructType]
         assert(fields.length <= schema.fields.length)
         val updatedFields = fields.zipWithIndex.map { case (f, id) =>
           f.copy(enc = createEncoder(f.enc, schema.fields(id).dataType))
         }
-        ProductEncoder(clsTag, updatedFields)
+        ProductEncoder(clsTag, updatedFields, outer)
       case _ =>
         enc
     }
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 82086b9d47a..cd54966ccf5 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -34,7 +34,7 @@ import org.apache.arrow.vector.ipc.ArrowReader
 import org.apache.arrow.vector.util.Text
 
 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.connect.client.CloseableIterator
@@ -288,9 +288,9 @@ object ArrowDeserializers {
           throw unsupportedCollectionType(tag.runtimeClass)
         }
 
-      case (ProductEncoder(tag, fields), StructVectors(struct, vectors)) =>
+      case (ProductEncoder(tag, fields, outerPointerGetter), StructVectors(struct, vectors)) =>
+        val outer = outerPointerGetter.map(_()).toSeq
         // We should try to make this work with MethodHandles.
-        val outer = Option(OuterScopes.getOuterScope(tag.runtimeClass)).map(_()).toSeq
         val Some(constructor) =
           ScalaReflection.findConstructor(
             tag.runtimeClass,
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
index 9e67522711c..4c14489947f 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
@@ -413,7 +413,7 @@ object ArrowSerializer {
               serializerFor(value, structVector.getChild(MapVector.VALUE_NAME))) :: Nil)
         new ArraySerializer(v, extractor, structSerializer)
 
-      case (ProductEncoder(tag, fields), StructVectors(struct, vectors)) =>
+      case (ProductEncoder(tag, fields, _), StructVectors(struct, vectors)) =>
         if (isSubClass(classOf[Product], tag)) {
           structSerializerFor(fields, struct, vectors) { (_, i) => p =>
             p.asInstanceOf[Product].productElement(i)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 6b31b5e923d..b2971236147 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -283,6 +283,19 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
     assertContains("""String = "[MyTestClass(1), MyTestClass(3)]"""", output)
   }
 
+  test("REPL class in encoder") {
+    val input = """
+        |case class MyTestClass(value: Int)
+        |spark.range(3).
+        |  select(col("id").cast("int").as("value")).
+        |  as[MyTestClass].
+        |  map(mtc => mtc.value).
+        |  collect()
+      """.stripMargin
+    val output = runCommandsInShell(input)
+    assertContains("Array[Int] = Array(0, 1, 2)", output)
+  }
+
   test("REPL class in UDF") {
     val input = """
         |case class MyTestClass(value: Int)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 5f063b4a9a6..24317c73d85 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -28,7 +28,7 @@ import org.apache.commons.lang3.reflect.ConstructorUtils
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.errors.ExecutionErrors
 import org.apache.spark.sql.types._
@@ -394,7 +394,8 @@ object ScalaReflection extends ScalaReflection {
               isRowEncoderSupported)
             EncoderField(fieldName, encoder, encoder.nullable, Metadata.empty)
         }
-        ProductEncoder(ClassTag(getClassFromType(t)), params)
+        val cls = getClassFromType(t)
+        ProductEncoder(ClassTag(cls), params, Option(OuterScopes.getOuterScope(cls)))
       case _ =>
         throw ExecutionErrors.cannotFindEncoderForTypeError(tpe.toString)
     }
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 99f33214d20..e5e9ba644b8 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -113,7 +113,8 @@ object AgnosticEncoders {
   // This supports both Product and DefinedByConstructorParams
   case class ProductEncoder[K](
       override val clsTag: ClassTag[K],
-      override val fields: Seq[EncoderField]) extends StructEncoder[K]
+      override val fields: Seq[EncoderField],
+      outerPointerGetter: Option[() => AnyRef]) extends StructEncoder[K]
 
   object ProductEncoder {
     val cachedCls = new ConcurrentHashMap[Int, Class[_]]
@@ -123,7 +124,7 @@ object AgnosticEncoders {
       }
       val cls = cachedCls.computeIfAbsent(encoders.size,
         _ => SparkClassUtils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}"))
-      ProductEncoder[Any](ClassTag(cls), fields)
+      ProductEncoder[Any](ClassTag(cls), fields, None)
     }
 
     private[sql] def isTuple(tag: ClassTag[_]): Boolean = {
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
index 6c10e8ece80..b497cd3f386 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
@@ -82,13 +82,32 @@ object OuterScopes {
     if (outer == null) {
       outerClassName match {
         case AmmoniteREPLClass(cellClassName) =>
-          () => {
-            val objClass = SparkClassUtils.classForName(cellClassName)
-            val objInstance = objClass.getField("MODULE$").get(null)
-            val obj = objClass.getMethod("instance").invoke(objInstance)
-            addOuterScope(obj)
-            obj
-          }
+          /* A short introduction to Ammonite class generation.
+           *
+           * There are three classes generated for each command:
+           * - The command. This contains all the dependencies needed to execute the command. It
+           *   also contains some logic to only initialize dependencies it needs, the others will
+           *   be null. This logic is powered by the compiler, and it will only work when there is
+           *   an Ammonite REPL bound through the ReplBridge; it will fail with a
+           *   NullPointerException when this is not the case.
+           * - The Helper. This contains the user code. This is an inner class of the command. If
+           *   it needs one of its dependencies it will pull them from the command. The helper
+           *   instance is needed when a class is defined in the user code. This is where this code
+           *   comes in: it resolves the Helper instance.
+           * - The command companion object. This holds an instance of the Helper class and the
+           *   command. When you touch the command companion on a machine where the REPL is not
+           *   running (driver and executors for connect), and the command has dependencies, the
+           *   initialization of the command will fail because it cannot use the REPL to figure out
+           *   which dependencies to retain.
+           *
+           * To bypass the problem with executor-side helper resolution, we eagerly capture the
+           * helper instance here.
+           */
+          val objClass = SparkClassUtils.classForName(cellClassName)
+          val objInstance = objClass.getField("MODULE$").get(null)
+          val obj = objClass.getMethod("instance").invoke(objInstance)
+          addOuterScope(obj)
+          () => obj
         // If the outer class is generated by REPL, users don't need to register it as it has
         // only one instance and there is a way to retrieve it: get the `$read` object, call the
         // `INSTANCE()` method to get the single instance of class `$read`. Then call `$iw()`
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index bdf996424ad..16a7d7ff065 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -350,7 +350,7 @@ object DeserializerBuildHelper {
         createDeserializer(valueEncoder, _, newTypePath),
         tag.runtimeClass)
 
-    case ProductEncoder(tag, fields) =>
+    case ProductEncoder(tag, fields, outerPointerGetter) =>
       val cls = tag.runtimeClass
       val dt = ObjectType(cls)
       val isTuple = cls.getName.startsWith("scala.Tuple")
@@ -373,7 +373,7 @@ object DeserializerBuildHelper {
       exprs.If(
         IsNull(path),
         exprs.Literal.create(null, dt),
-        NewInstance(cls, arguments, dt, propagateNull = false))
+        NewInstance(cls, arguments, Nil, propagateNull = false, dt, outerPointerGetter))
 
     case AgnosticEncoders.RowEncoder(fields) =>
       val convertedFields = fields.zipWithIndex.map { case (f, i) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index 7a4061a4b56..27090ff6fa5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -347,7 +347,7 @@ object SerializerBuildHelper {
           validateAndSerializeElement(valueEncoder, valueContainsNull))
       )
 
-    case ProductEncoder(_, fields) =>
+    case ProductEncoder(_, fields, _) =>
       val serializedFields = fields.map { field =>
         // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNul
         // is necessary here. Because for a nullable nested inputObject with struct data
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 690e55dbe5f..bbb62acd025 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -611,7 +611,8 @@ class ScalaReflectionSuite extends SparkFunSuite {
     assert(encoderForWithRowEncoderSupport[MyClass] ===
       ProductEncoder(
         ClassTag(getClassFromType(typeTag[MyClass].tpe)),
-        Seq(EncoderField("row", UnboundRowEncoder, true, Metadata.empty))))
+        Seq(EncoderField("row", UnboundRowEncoder, true, Metadata.empty)),
+        None))
   }
 
   case class MyClass(row: Row)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
