This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new cb8c653f3c7  [SPARK-44799][CONNECT] Fix outer scopes resolution on the executor side
cb8c653f3c7 is described below

commit cb8c653f3c7f76129deae366608613a968b81264
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Wed Aug 16 15:28:12 2023 +0200

[SPARK-44799][CONNECT] Fix outer scopes resolution on the executor side

### What changes were proposed in this pull request?
When you define a class in the REPL (with previously defined symbols), for example:
```scala
val filePath = "my_path"
case class MyTestClass(value: Int)
```
it is actually declared inside a command class. In Ammonite the structure looks like this:
```scala
// The first command contains the `filePath`
object cmd1 {
  val wrapper = new cmd1
  val instance = new wrapper.Helper
}
class cmd1 extends Serializable {
  class Helper extends Serializable {
    val filePath = "my_path"
  }
}

// The second command contains the `MyTestClass` definition
object cmd2 {
  val wrapper = new cmd2
  val instance = new wrapper.Helper
}
class cmd2 extends Serializable {
  @_root_.scala.transient private val __amm_usedThings =
    _root_.ammonite.repl.ReplBridge.value.usedEarlierDefinitions.iterator.toSet
  private val `cmd1`: cmd1.instance.type =
    if (__amm_usedThings("""cmd1""")) cmd1.instance else null.asInstanceOf[cmd1.instance.type]
  class Helper extends Serializable {
    case class MyTestClass(value: Int)
  }
}
```
In order to create an instance of `MyTestClass` we need an instance of `Helper`. When Spark itself creates an instance of the class, it uses `OuterScopes`, which - for Ammonite-generated classes - accesses the command object to fetch the helper instance. The problem with this is that the access triggers the creation of a command instance, and creating a command instance tries to access the REPL to figure out which of its dependencies is in use (clever [...]

This PR fixes the issue by explicitly passing a getter for the outer instance to the `ProductEncoder`. For Ammonite we actually ship the helper instance. This way the encoder always carries the information it needs to create the class.

### Why are the changes needed?
This fixes a bug when you try to use a REPL-defined class as the input of a UDF. For example, this will work now:
```scala
val filePath = "my_path" // we need some previous cell that exposes a symbol that could be captured in the class definition
case class MyTestClass(value: Int) {
  override def toString: String = value.toString
}
spark.range(10).select(col("id").cast("int").as("value")).as[MyTestClass].map(mtc => mtc.value).collect()
```

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Added a test to `ReplE2ESuite` to illustrate the issue.

Closes #42489 from hvanhovell/SPARK-44799.
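To make the outer-instance mechanism concrete, here is a minimal, self-contained Scala sketch (hypothetical demo names, not Spark or Ammonite code). A class declared inside another class, the way REPL cells wrap user code in a `Helper`, compiles to an inner class whose constructor takes the outer instance as a hidden first argument:

```scala
// Hypothetical demo, not Spark code: constructing an inner class reflectively
// requires supplying the enclosing instance explicitly.
object OuterInstanceDemo extends App {
  class Helper extends Serializable {
    case class MyTestClass(value: Int) // compiles to an inner class of Helper
  }

  val helper = new Helper

  // The compiled constructor is MyTestClass(Helper outer, int value); the
  // outer instance travels as a hidden first argument.
  val ctor = classOf[helper.MyTestClass].getConstructors.head
  val instance = ctor.newInstance(helper, Integer.valueOf(1))
  println(instance) // prints: MyTestClass(1)
}
```

In the same sketch's terms, the fix resolves the outer instance once, on the driver where the REPL is reachable, and puts a getter that simply returns the captured instance on the encoder, rather than shipping a lambda that would resolve it lazily on the executor:

```scala
// Continuation of the hypothetical demo above: eager capture on the driver.
val helperInstance: AnyRef = helper // resolved once, while the REPL is available
val outerPointerGetter: Option[() => AnyRef] = Some(() => helperInstance)
```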
Authored-by: Herman van Hovell <her...@databricks.com>
Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  3 +-
 .../spark/sql/connect/client/SparkResult.scala     |  4 +--
 .../connect/client/arrow/ArrowDeserializer.scala   |  6 ++--
 .../sql/connect/client/arrow/ArrowSerializer.scala |  2 +-
 .../spark/sql/application/ReplE2ESuite.scala       | 13 +++++++++
 .../spark/sql/catalyst/ScalaReflection.scala       |  5 ++--
 .../sql/catalyst/encoders/AgnosticEncoder.scala    |  5 ++--
 .../spark/sql/catalyst/encoders/OuterScopes.scala  | 33 +++++++++++++++++-----
 .../sql/catalyst/DeserializerBuildHelper.scala     |  4 +--
 .../spark/sql/catalyst/SerializerBuildHelper.scala |  2 +-
 .../spark/sql/catalyst/ScalaReflectionSuite.scala  |  3 +-
 11 files changed, 58 insertions(+), 22 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 28b04fb850e..cb7d2c84df5 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -883,7 +883,8 @@ class Dataset[T] private[sql] (
       ClassTag(SparkClassUtils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple2")),
       Seq(
         EncoderField(s"_1", this.agnosticEncoder, leftNullable, Metadata.empty),
-        EncoderField(s"_2", other.agnosticEncoder, rightNullable, Metadata.empty)))
+        EncoderField(s"_2", other.agnosticEncoder, rightNullable, Metadata.empty)),
+      None)

     sparkSession.newDataset(tupleEncoder) { builder =>
       val joinBuilder = builder.getJoinBuilder
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
index 609e84779fb..48278311428 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -60,14 +60,14 @@ private[sql] class SparkResult[T](
         RowEncoder
           .encoderFor(dataType.asInstanceOf[StructType])
           .asInstanceOf[AgnosticEncoder[E]]
-      case ProductEncoder(clsTag, fields) if ProductEncoder.isTuple(clsTag) =>
+      case ProductEncoder(clsTag, fields, outer) if ProductEncoder.isTuple(clsTag) =>
         // Recursively continue updating the tuple product encoder
         val schema = dataType.asInstanceOf[StructType]
         assert(fields.length <= schema.fields.length)
         val updatedFields = fields.zipWithIndex.map { case (f, id) =>
           f.copy(enc = createEncoder(f.enc, schema.fields(id).dataType))
         }
-        ProductEncoder(clsTag, updatedFields)
+        ProductEncoder(clsTag, updatedFields, outer)
       case _ =>
         enc
     }
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
index 82086b9d47a..cd54966ccf5 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala
@@ -34,7 +34,7 @@ import org.apache.arrow.vector.ipc.ArrowReader
 import org.apache.arrow.vector.util.Text

 import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.connect.client.CloseableIterator
@@ -288,9 +288,9 @@ object ArrowDeserializers {
           throw unsupportedCollectionType(tag.runtimeClass)
         }

-      case (ProductEncoder(tag, fields), StructVectors(struct, vectors)) =>
+      case (ProductEncoder(tag, fields, outerPointerGetter), StructVectors(struct, vectors)) =>
+        val outer = outerPointerGetter.map(_()).toSeq
         // We should try to make this work with MethodHandles.
-        val outer = Option(OuterScopes.getOuterScope(tag.runtimeClass)).map(_()).toSeq
         val Some(constructor) =
           ScalaReflection.findConstructor(
             tag.runtimeClass,
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
index 9e67522711c..4c14489947f 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowSerializer.scala
@@ -413,7 +413,7 @@ object ArrowSerializer {
             serializerFor(value, structVector.getChild(MapVector.VALUE_NAME))) :: Nil)
         new ArraySerializer(v, extractor, structSerializer)

-      case (ProductEncoder(tag, fields), StructVectors(struct, vectors)) =>
+      case (ProductEncoder(tag, fields, _), StructVectors(struct, vectors)) =>
         if (isSubClass(classOf[Product], tag)) {
           structSerializerFor(fields, struct, vectors) { (_, i) => p =>
             p.asInstanceOf[Product].productElement(i)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 6b31b5e923d..b2971236147 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -283,6 +283,19 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
     assertContains("""String = "[MyTestClass(1), MyTestClass(3)]"""", output)
   }

+  test("REPL class in encoder") {
+    val input = """
+        |case class MyTestClass(value: Int)
+        |spark.range(3).
+        |  select(col("id").cast("int").as("value")).
+        |  as[MyTestClass].
+        |  map(mtc => mtc.value).
+        |  collect()
+      """.stripMargin
+    val output = runCommandsInShell(input)
+    assertContains("Array[Int] = Array(0, 1, 2)", output)
+  }
+
   test("REPL class in UDF") {
     val input = """
         |case class MyTestClass(value: Int)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 5f063b4a9a6..24317c73d85 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -28,7 +28,7 @@ import org.apache.commons.lang3.reflect.ConstructorUtils

 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
+import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, OuterScopes}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
 import org.apache.spark.sql.errors.ExecutionErrors
 import org.apache.spark.sql.types._
@@ -394,7 +394,8 @@ object ScalaReflection extends ScalaReflection {
             isRowEncoderSupported)
         EncoderField(fieldName, encoder, encoder.nullable, Metadata.empty)
       }
-      ProductEncoder(ClassTag(getClassFromType(t)), params)
+      val cls = getClassFromType(t)
+      ProductEncoder(ClassTag(cls), params, Option(OuterScopes.getOuterScope(cls)))
     case _ =>
       throw ExecutionErrors.cannotFindEncoderForTypeError(tpe.toString)
   }
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 99f33214d20..e5e9ba644b8 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -113,7 +113,8 @@ object AgnosticEncoders {
   // This supports both Product and DefinedByConstructorParams
   case class ProductEncoder[K](
       override val clsTag: ClassTag[K],
-      override val fields: Seq[EncoderField]) extends StructEncoder[K]
+      override val fields: Seq[EncoderField],
+      outerPointerGetter: Option[() => AnyRef]) extends StructEncoder[K]

 object ProductEncoder {
   val cachedCls = new ConcurrentHashMap[Int, Class[_]]
@@ -123,7 +124,7 @@
     }
     val cls = cachedCls.computeIfAbsent(encoders.size,
       _ => SparkClassUtils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}"))
-    ProductEncoder[Any](ClassTag(cls), fields)
+    ProductEncoder[Any](ClassTag(cls), fields, None)
   }

   private[sql] def isTuple(tag: ClassTag[_]): Boolean = {
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
index 6c10e8ece80..b497cd3f386 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
@@ -82,13 +82,32 @@ object OuterScopes {
     if (outer == null) {
       outerClassName match {
         case AmmoniteREPLClass(cellClassName) =>
-          () => {
-            val objClass = SparkClassUtils.classForName(cellClassName)
-            val objInstance = objClass.getField("MODULE$").get(null)
-            val obj = objClass.getMethod("instance").invoke(objInstance)
-            addOuterScope(obj)
-            obj
-          }
+          /* A short introduction to Ammonite class generation.
+           *
+           * There are three classes generated for each command:
+           * - The command. This contains all the dependencies needed to execute the command.
+           *   It also contains some logic to only initialize the dependencies it needs; the
+           *   others will be null. This logic is powered by the compiler, and it will only work
+           *   when there is an Ammonite REPL bound through the ReplBridge; it will fail with a
+           *   NullPointerException when this is not the case.
+           * - The Helper. This contains the user code. This is an inner class of the command. If
+           *   it needs one of its dependencies it will pull it from the command. The helper
+           *   instance is needed when a class is defined in the user code. This is where this
+           *   code comes in: it resolves the Helper instance.
+           * - The command companion object. This holds an instance of the Helper class and the
+           *   command. When you touch the command companion on a machine where the REPL is not
+           *   running (driver and executors for connect), and the command has dependencies, the
+           *   initialization of the command will fail because it cannot use the REPL to figure
+           *   out which dependencies to retain.
+           *
+           * To bypass the problem with executor-side helper resolution, we eagerly capture the
+           * helper instance here.
+           */
+          val objClass = SparkClassUtils.classForName(cellClassName)
+          val objInstance = objClass.getField("MODULE$").get(null)
+          val obj = objClass.getMethod("instance").invoke(objInstance)
+          addOuterScope(obj)
+          () => obj
         // If the outer class is generated by REPL, users don't need to register it as it has
         // only one instance and there is a way to retrieve it: get the `$read` object, call the
         // `INSTANCE()` method to get the single instance of class `$read`. Then call `$iw()`
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index bdf996424ad..16a7d7ff065 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -350,7 +350,7 @@ object DeserializerBuildHelper {
         createDeserializer(valueEncoder, _, newTypePath),
         tag.runtimeClass)

-    case ProductEncoder(tag, fields) =>
+    case ProductEncoder(tag, fields, outerPointerGetter) =>
       val cls = tag.runtimeClass
       val dt = ObjectType(cls)
       val isTuple = cls.getName.startsWith("scala.Tuple")
@@ -373,7 +373,7 @@
       exprs.If(
         IsNull(path),
         exprs.Literal.create(null, dt),
-        NewInstance(cls, arguments, dt, propagateNull = false))
+        NewInstance(cls, arguments, Nil, propagateNull = false, dt, outerPointerGetter))

     case AgnosticEncoders.RowEncoder(fields) =>
       val convertedFields = fields.zipWithIndex.map { case (f, i) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index 7a4061a4b56..27090ff6fa5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -347,7 +347,7 @@ object SerializerBuildHelper {
           validateAndSerializeElement(valueEncoder, valueContainsNull))
       )

-    case ProductEncoder(_, fields) =>
+    case ProductEncoder(_, fields, _) =>
       val serializedFields = fields.map { field =>
         // SPARK-26730 inputObject won't be null with If's guard below. And KnownNotNull
        // is necessary here. Because for a nullable nested inputObject with struct data
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 690e55dbe5f..bbb62acd025 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -611,7 +611,8 @@ class ScalaReflectionSuite extends SparkFunSuite {
     assert(encoderForWithRowEncoderSupport[MyClass] ===
       ProductEncoder(
         ClassTag(getClassFromType(typeTag[MyClass].tpe)),
-        Seq(EncoderField("row", UnboundRowEncoder, true, Metadata.empty))))
+        Seq(EncoderField("row", UnboundRowEncoder, true, Metadata.empty)),
+        None))
   }

   case class MyClass(row: Row)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org