This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 357bcac823c [SPARK-44388][CONNECT] Fix protobuf cast issue when UDF instance is updated 357bcac823c is described below commit 357bcac823c0fafbdcb95458327d35e4a492046c Author: vicennial <venkata.gud...@databricks.com> AuthorDate: Thu Jul 13 18:55:10 2023 +0900 [SPARK-44388][CONNECT] Fix protobuf cast issue when UDF instance is updated ### What changes were proposed in this pull request? This PR modifies the arguments of the `ScalarUserDefinedFunction` case class to take in `serializedUdfPacket`, `inputTypes` and `outputType` instead of calculating these types internally from `function`, `inputEncoders` and `outputEncoder`. This allows the class to copy over the serialized UDF value from the parent instance instead of re-creating it. Through this, we avoid hitting a variant of the issue mentioned in [SPARK-43198](https://issues.apache.org/jira/browse/SPARK-43198) which [...] ### Why are the changes needed? Bugfix. Consider the following code: ``` class A(x: Int) { def get = x * 7 } val myUdf = udf((x: Int) => new A(x).get) val modifiedUdf = myUdf.withName("myUdf").asNondeterministic() spark.range(5).select(modifiedUdf(col("id"))).as[Int].collect() ``` Executing this code currently fails with the following error: ``` java.lang.ClassCastException: org.apache.spark.connect.proto.ScalarScalaUDF cannot be cast to com.google.protobuf.MessageLite at com.google.protobuf.GeneratedMessageLite$SerializedForm.readResolve(GeneratedMessageLite.java:1462) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at java.io.ObjectStreamClass.invokeReadResolve(ObjectStreamClass.java:1274) at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:2196) ... 
``` If we do not include the `myUdf.withName("myUdf").asNondeterministic()`, the UDF runs as expected. ### Does this PR introduce _any_ user-facing change? Yes, fixes the bug mentioned above. ### How was this patch tested? Added a new E2E test in `ReplE2ESuite`. Closes #41959 from vicennial/SPARK-44388. Lead-authored-by: vicennial <venkata.gud...@databricks.com> Co-authored-by: Venkata Sai Akhil Gudesa <venkata.gud...@databricks.com> Co-authored-by: Venkata Sai Akhil Gudesa <gvs.akhil1...@gmail.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../sql/expressions/UserDefinedFunction.scala | 29 +++++++++++----------- .../spark/sql/application/ReplE2ESuite.scala | 11 ++++++++ 2 files changed, 25 insertions(+), 15 deletions(-) diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index d911b7efe29..7bce4b5b31a 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -92,26 +92,23 @@ sealed abstract class UserDefinedFunction { /** * Holder class for a scalar user-defined function and it's input/output encoder(s). */ -case class ScalarUserDefinedFunction( - function: AnyRef, - inputEncoders: Seq[AgnosticEncoder[_]], - outputEncoder: AgnosticEncoder[_], +case class ScalarUserDefinedFunction private ( + // SPARK-43198: Eagerly serialize to prevent the UDF from containing a reference to this class. + serializedUdfPacket: Array[Byte], + inputTypes: Seq[proto.DataType], + outputType: proto.DataType, name: Option[String], override val nullable: Boolean, override val deterministic: Boolean) extends UserDefinedFunction { - // SPARK-43198: Eagerly serialize to prevent the UDF from containing a reference to this class. 
- private[this] val udf = { - val udfPacketBytes = - SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder)) + private[this] lazy val udf = { val scalaUdfBuilder = proto.ScalarScalaUDF .newBuilder() - .setPayload(ByteString.copyFrom(udfPacketBytes)) + .setPayload(ByteString.copyFrom(serializedUdfPacket)) // Send the real inputs and return types to obtain the types without deser the udf bytes. - .addAllInputTypes( - inputEncoders.map(_.dataType).map(DataTypeProtoConverter.toConnectProtoType).asJava) - .setOutputType(DataTypeProtoConverter.toConnectProtoType(outputEncoder.dataType)) + .addAllInputTypes(inputTypes.asJava) + .setOutputType(outputType) .setNullable(nullable) scalaUdfBuilder.build() @@ -154,10 +151,12 @@ object ScalarUserDefinedFunction { function: AnyRef, inputEncoders: Seq[AgnosticEncoder[_]], outputEncoder: AgnosticEncoder[_]): ScalarUserDefinedFunction = { + val udfPacketBytes = + SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder)) ScalarUserDefinedFunction( - function = function, - inputEncoders = inputEncoders, - outputEncoder = outputEncoder, + serializedUdfPacket = udfPacketBytes, + inputTypes = inputEncoders.map(_.dataType).map(DataTypeProtoConverter.toConnectProtoType), + outputType = DataTypeProtoConverter.toConnectProtoType(outputEncoder.dataType), name = None, nullable = true, deterministic = true) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala index 676ad6b090e..40841aa3b39 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala @@ -158,6 +158,17 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach { assertContains("Array[Int] = Array(5, 47, 89, 131, 173)", output) } 
+ test("Updating UDF properties") { + val input = """ + |class A(x: Int) { def get = x * 7 } + |val myUdf = udf((x: Int) => new A(x).get) + |val modifiedUdf = myUdf.withName("myUdf").asNondeterministic() + |spark.range(5).select(modifiedUdf(col("id"))).as[Int].collect() + """.stripMargin + val output = runCommandsInShell(input) + assertContains("Array[Int] = Array(0, 7, 14, 21, 28)", output) + } + test("SPARK-43198: Filter does not throw ammonite-related class initialization exception") { val input = """ |spark.range(10).filter(n => n % 2 == 0).collect() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org