This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6d0fed9a18f [SPARK-43744][CONNECT] Fix class loading problem caused by stub user classes not found on the server classpath
6d0fed9a18f is described below

commit 6d0fed9a18ff87e73fdf1ee46b6b0d2df8dd5a1b
Author: Zhen Li <zhenli...@users.noreply.github.com>
AuthorDate: Fri Jul 28 22:59:07 2023 -0400

    [SPARK-43744][CONNECT] Fix class loading problem caused by stub user classes not found on the server classpath
    
    ### What changes were proposed in this pull request?
    This PR introduces a stub class loader for unpacking Scala UDFs on the driver and the executors. When it encounters a user class that is not found on the server session classpath, the stub class loader tries to generate a stub for that class.
    
    This solves the problem that, when serializing a UDF, the Java serializer may pull in unnecessary user code, e.g. user classes that appear in lambda definition signatures in the same class where the UDF is defined.
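    
    As a condensed sketch, the StubClassDummyUdf test resource added in this patch shows the pattern:
    
        class StubClassDummyUdf {
          val udf: Int => Int = (x: Int) => x + 1  // the UDF being registered
          val dummy = (x: Int) => A(x)             // pulls A into the capturing class
        }
        case class A(x: Int) { def get: Int = x + 5 }
    
    Deserializing `udf` on a server that does not have `A` fails even though the UDF never calls `A`; the stub class loader replaces `A` with a generated stub so the UDF can still be loaded.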
    
    If the user code is actually needed to execute the UDF, we return an error message suggesting that the user add the missing classes with the `addArtifact` method.
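    
    For instance, a hedged sketch of the remedy (the jar path is a placeholder):
    
        // If the stubbed class is genuinely needed to execute the UDF, upload the
        // jar that defines it to the session classpath:
        spark.addArtifact("/path/to/user-classes.jar")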
    
    ### Why are the changes needed?
    To enhance the user experience of UDFs. This PR should be merged to master and 3.5.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Added tests for both Scala 2.12 and 2.13.
    
    4 tests in SparkSessionE2ESuite still fail to run with Maven after the fix, because the client test jar is installed on the system classpath (added using `--jars` at server start), while the stub class loader can only stub classes missing from the session classpath (added using `session.addArtifact`).
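    
    For context, a sketch of the two classpaths involved (jar names are placeholders):
    
        // System classpath: supplied at server start, shared by all sessions,
        // and NOT eligible for stubbing:
        //   ./sbin/start-connect-server.sh --jars client-test.jar
        // Session classpath: supplied per session; classes missing from it can be stubbed:
        session.addArtifact("client-test.jar")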
    
    Moving the test jar to the session classpath currently causes failures in the `flatMapGroupsWithState` tests (SPARK-44576). We will finish moving the test jar to the session classpath once those test failures are fixed.
    
    Closes #42069 from zhenlineo/ref-spark-result.
    
    Authored-by: Zhen Li <zhenli...@users.noreply.github.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../scala/org/apache/spark/sql/SparkSession.scala  |   2 +-
 .../sql/expressions/UserDefinedFunction.scala      |   2 +-
 .../jvm/src/test/resources/StubClassDummyUdf.scala |  56 +++++++++
 .../connect/client/jvm/src/test/resources/udf2.12  | Bin 0 -> 1520 bytes
 .../client/jvm/src/test/resources/udf2.12.jar      | Bin 0 -> 5332 bytes
 .../connect/client/jvm/src/test/resources/udf2.13  | Bin 0 -> 1630 bytes
 .../client/jvm/src/test/resources/udf2.13.jar      | Bin 0 -> 5674 bytes
 .../connect/client/UDFClassLoadingE2ESuite.scala   |  83 +++++++++++++
 .../connect/client/util/IntegrationTestUtils.scala |   2 +-
 .../connect/client/util/RemoteSparkSession.scala   |   2 +-
 .../artifact/SparkConnectArtifactManager.scala     |  17 ++-
 .../sql/connect/planner/SparkConnectPlanner.scala  |  23 +++-
 connector/connect/server/src/test/resources/udf    | Bin 0 -> 973 bytes
 .../connect/server/src/test/resources/udf_noA.jar  | Bin 0 -> 5545 bytes
 .../connect/artifact/StubClassLoaderSuite.scala    | 132 +++++++++++++++++++++
 .../spark/util/ChildFirstURLClassLoader.java       |   9 ++
 .../scala/org/apache/spark/executor/Executor.scala |  86 +++++++++++---
 .../org/apache/spark/internal/config/package.scala |  14 +++
 .../org/apache/spark/util/StubClassLoader.scala    |  79 ++++++++++++
 19 files changed, 480 insertions(+), 27 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index d1832e65f3e..4b3de91b56f 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -554,7 +554,7 @@ class SparkSession private[sql] (
     val command = proto.Command.newBuilder().setRegisterFunction(udf).build()
     val plan = proto.Plan.newBuilder().setCommand(command).build()
 
-    client.execute(plan)
+    client.execute(plan).asScala.foreach(_ => ())
   }
 
   @DeveloperApi
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 18aef8a2e4c..e5c89d90c19 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -92,7 +92,7 @@ sealed abstract class UserDefinedFunction {
 /**
  * Holder class for a scalar user-defined function and it's input/output encoder(s).
  */
-case class ScalarUserDefinedFunction private (
+case class ScalarUserDefinedFunction private[sql] (
    // SPARK-43198: Eagerly serialize to prevent the UDF from containing a reference to this class.
     serializedUdfPacket: Array[Byte],
     inputTypes: Seq[proto.DataType],
diff --git a/connector/connect/client/jvm/src/test/resources/StubClassDummyUdf.scala b/connector/connect/client/jvm/src/test/resources/StubClassDummyUdf.scala
new file mode 100644
index 00000000000..ff1b3deafaf
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/resources/StubClassDummyUdf.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client
+
+// To generate a jar from the source file:
+// `scalac StubClassDummyUdf.scala -d udf.jar`
+// To remove class A from the jar:
+// `jar -xvf udf.jar` -> delete A.class and A$.class
+// `jar -cvf udf_noA.jar org/`
+class StubClassDummyUdf {
+  val udf: Int => Int = (x: Int) => x + 1
+  val dummy = (x: Int) => A(x)
+}
+
+case class A(x: Int) { def get: Int = x + 5 }
+
+// The code to generate the udf file
+object StubClassDummyUdf {
+  import java.io.{BufferedOutputStream, File, FileOutputStream}
+  import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.PrimitiveIntEncoder
+  import org.apache.spark.sql.connect.common.UdfPacket
+  import org.apache.spark.util.Utils
+
+  def packDummyUdf(): String = {
+    val byteArray =
+      Utils.serialize[UdfPacket](
+        new UdfPacket(
+          new StubClassDummyUdf().udf,
+          Seq(PrimitiveIntEncoder),
+          PrimitiveIntEncoder
+        )
+      )
+    val file = new File("src/test/resources/udf")
+    val target = new BufferedOutputStream(new FileOutputStream(file))
+    try {
+      target.write(byteArray)
+      file.getAbsolutePath
+    } finally {
+      target.close
+    }
+  }
+}
diff --git a/connector/connect/client/jvm/src/test/resources/udf2.12 b/connector/connect/client/jvm/src/test/resources/udf2.12
new file mode 100644
index 00000000000..1090bc90d9b
Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/udf2.12 differ
diff --git a/connector/connect/client/jvm/src/test/resources/udf2.12.jar b/connector/connect/client/jvm/src/test/resources/udf2.12.jar
new file mode 100644
index 00000000000..6ce6799678f
Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/udf2.12.jar differ
diff --git a/connector/connect/client/jvm/src/test/resources/udf2.13 b/connector/connect/client/jvm/src/test/resources/udf2.13
new file mode 100644
index 00000000000..863ac32a76d
Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/udf2.13 differ
diff --git a/connector/connect/client/jvm/src/test/resources/udf2.13.jar b/connector/connect/client/jvm/src/test/resources/udf2.13.jar
new file mode 100644
index 00000000000..c89830f127c
Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/udf2.13.jar differ
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/UDFClassLoadingE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/UDFClassLoadingE2ESuite.scala
new file mode 100644
index 00000000000..8fdb7efbcba
--- /dev/null
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/UDFClassLoadingE2ESuite.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client
+
+import java.io.File
+import java.nio.file.{Files, Paths}
+
+import scala.util.Properties
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connect.client.util.RemoteSparkSession
+import org.apache.spark.sql.connect.common.ProtoDataTypes
+import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
+
+class UDFClassLoadingE2ESuite extends RemoteSparkSession {
+
+  private val scalaVersion = Properties.versionNumberString
+    .split("\\.")
+    .take(2)
+    .mkString(".")
+
+  // See src/test/resources/StubClassDummyUdf for how the UDFs and jars are created.
+  private val udfByteArray: Array[Byte] =
+    Files.readAllBytes(Paths.get(s"src/test/resources/udf$scalaVersion"))
+  private val udfJar =
+    new File(s"src/test/resources/udf$scalaVersion.jar").toURI.toURL
+
+  private def registerUdf(session: SparkSession): Unit = {
+    val udf = ScalarUserDefinedFunction(
+      serializedUdfPacket = udfByteArray,
+      inputTypes = Seq(ProtoDataTypes.IntegerType),
+      outputType = ProtoDataTypes.IntegerType,
+      name = Some("dummyUdf"),
+      nullable = true,
+      deterministic = true)
+    session.registerUdf(udf.toProto)
+  }
+
+  test("update class loader after stubbing: new session") {
+    // Session1 should stub the missing class, but fail to call methods on it
+    val session1 = spark.newSession()
+
+    assert(
+      intercept[Exception] {
+        registerUdf(session1)
+      }.getMessage.contains(
+        "java.lang.NoSuchMethodException: 
org.apache.spark.sql.connect.client.StubClassDummyUdf"))
+
+    // Session2 uses the real class
+    val session2 = spark.newSession()
+    session2.addArtifact(udfJar.toURI)
+    registerUdf(session2)
+  }
+
+  test("update class loader after stubbing: same session") {
+    // Session should stub the missing class, but fail to call methods on it
+    val session = spark.newSession()
+
+    assert(
+      intercept[Exception] {
+        registerUdf(session)
+      }.getMessage.contains(
+        "java.lang.NoSuchMethodException: 
org.apache.spark.sql.connect.client.StubClassDummyUdf"))
+
+    // Session uses the real class
+    session.addArtifact(udfJar.toURI)
+    registerUdf(session)
+  }
+}
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
index 819df5fc25b..4d88565308f 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/IntegrationTestUtils.scala
@@ -30,7 +30,7 @@ object IntegrationTestUtils {
 
   // System properties used for testing and debugging
   private val DEBUG_SC_JVM_CLIENT = "spark.debug.sc.jvm.client"
-  // Enable this flag to print all client debug log + server logs to the console
+  // Enable this flag to print all server logs to the console
  private[connect] val isDebug = System.getProperty(DEBUG_SC_JVM_CLIENT, "false").toBoolean
 
   private[sql] lazy val scalaVersion = {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index 594d3c369fe..1c1cb1403fe 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -96,7 +96,7 @@ object SparkConnectServerUtils {
     // To find InMemoryTableCatalog for V2 writer tests
     val catalystTestJar =
       tryFindJar("sql/catalyst", "spark-catalyst", "spark-catalyst", test = 
true)
-        .map(clientTestJar => Seq("--jars", clientTestJar.getCanonicalPath))
+        .map(clientTestJar => Seq(clientTestJar.getCanonicalPath))
         .getOrElse(Seq.empty)
 
    // For UDF maven E2E tests, the server needs the client code to find the UDFs defined in tests.
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
index d8f290639c2..03391cef68b 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
@@ -31,12 +31,13 @@ import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}
 
 import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkContext, SparkEnv}
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.CONNECT_SCALA_UDF_STUB_CLASSES
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
 import org.apache.spark.sql.connect.config.Connect.CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL
 import org.apache.spark.sql.connect.service.SessionHolder
 import org.apache.spark.storage.{CacheId, StorageLevel}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader, Utils}
 
 /**
  * The Artifact Manager for the [[SparkConnectService]].
@@ -161,7 +162,19 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
    */
   def classloader: ClassLoader = {
     val urls = getSparkConnectAddedJars :+ classDir.toUri.toURL
-    new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader)
+    val loader = if (SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_CLASSES).nonEmpty) {
+      val stubClassLoader =
+        StubClassLoader(null, SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_CLASSES))
+      new ChildFirstURLClassLoader(
+        urls.toArray,
+        stubClassLoader,
+        Utils.getContextOrSparkClassLoader)
+    } else {
+      new URLClassLoader(urls.toArray, Utils.getContextOrSparkClassLoader)
+    }
+
+    logDebug(s"Using class loader: $loader, containing urls: $urls")
+    loader
   }
 
   /**
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index e4ac34715fb..ebed8af48f0 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.connect.planner
 
+import java.io.IOException
+
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 import scala.util.Try
@@ -1504,15 +1506,24 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
   }
 
  private def unpackUdf(fun: proto.CommonInlineUserDefinedFunction): UdfPacket = {
-    Utils.deserialize[UdfPacket](
-      fun.getScalarScalaUdf.getPayload.toByteArray,
-      Utils.getContextOrSparkClassLoader)
+    unpackScalarScalaUDF[UdfPacket](fun.getScalarScalaUdf)
   }
 
  private def unpackForeachWriter(fun: proto.ScalarScalaUDF): ForeachWriterPacket = {
-    Utils.deserialize[ForeachWriterPacket](
-      fun.getPayload.toByteArray,
-      Utils.getContextOrSparkClassLoader)
+    unpackScalarScalaUDF[ForeachWriterPacket](fun)
+  }
+
+  private def unpackScalarScalaUDF[T](fun: proto.ScalarScalaUDF): T = {
+    try {
+      logDebug(s"Unpack using class loader: 
${Utils.getContextOrSparkClassLoader}")
+      Utils.deserialize[T](fun.getPayload.toByteArray, 
Utils.getContextOrSparkClassLoader)
+    } catch {
+      case e: IOException if e.getCause.isInstanceOf[NoSuchMethodException] =>
+        throw new ClassNotFoundException(
+          s"Failed to load class correctly due to ${e.getCause}. " +
+            "Make sure the artifact where the class is defined is installed by 
calling" +
+            " session.addArtifact.")
+    }
   }
 
   /**
diff --git a/connector/connect/server/src/test/resources/udf b/connector/connect/server/src/test/resources/udf
new file mode 100644
index 00000000000..55a3264a017
Binary files /dev/null and b/connector/connect/server/src/test/resources/udf differ
diff --git a/connector/connect/server/src/test/resources/udf_noA.jar b/connector/connect/server/src/test/resources/udf_noA.jar
new file mode 100644
index 00000000000..4d8c423ab6d
Binary files /dev/null and b/connector/connect/server/src/test/resources/udf_noA.jar differ
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala
new file mode 100644
index 00000000000..0f6e0543151
--- /dev/null
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.artifact
+
+import java.io.File
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader}
+
+class StubClassLoaderSuite extends SparkFunSuite {
+
+  // See src/test/resources/StubClassDummyUdf for how the UDFs and jars are created.
+  private val udfNoAJar = new File("src/test/resources/udf_noA.jar").toURI.toURL
+  private val classDummyUdf = "org.apache.spark.sql.connect.client.StubClassDummyUdf"
+  private val classA = "org.apache.spark.sql.connect.client.A"
+
+  test("find class with stub class") {
+    val cl = new RecordedStubClassLoader(getClass().getClassLoader(), _ => true)
+    val cls = cl.findClass("my.name.HelloWorld")
+    assert(cls.getName === "my.name.HelloWorld")
+    assert(cl.lastStubbed === "my.name.HelloWorld")
+  }
+
+  test("class for name with stub class") {
+    val cl = new RecordedStubClassLoader(getClass().getClassLoader(), _ => true)
+    // scalastyle:off classforname
+    val cls = Class.forName("my.name.HelloWorld", false, cl)
+    // scalastyle:on classforname
+    assert(cls.getName === "my.name.HelloWorld")
+    assert(cl.lastStubbed === "my.name.HelloWorld")
+  }
+
+  test("filter class to stub") {
+    val list = "my.name" :: Nil
+    val cl = StubClassLoader(getClass().getClassLoader(), list)
+    val cls = cl.findClass("my.name.HelloWorld")
+    assert(cls.getName === "my.name.HelloWorld")
+
+    intercept[ClassNotFoundException] {
+      cl.findClass("name.my.GoodDay")
+    }
+  }
+
+  test("stub missing class") {
+    val sysClassLoader = getClass.getClassLoader()
+    val stubClassLoader = new RecordedStubClassLoader(null, _ => true)
+
+    // Install artifact without class A.
+    val sessionClassLoader =
+      new ChildFirstURLClassLoader(Array(udfNoAJar), stubClassLoader, sysClassLoader)
+    // Load udf with A used in the same class.
+    loadDummyUdf(sessionClassLoader)
+    // Class A should be stubbed.
+    assert(stubClassLoader.lastStubbed === classA)
+  }
+
+  test("unload stub class") {
+    val sysClassLoader = getClass.getClassLoader()
+    val stubClassLoader = new RecordedStubClassLoader(null, _ => true)
+
+    val cl1 = new ChildFirstURLClassLoader(Array.empty, stubClassLoader, sysClassLoader)
+
+    // Failed to load DummyUdf
+    intercept[Exception] {
+      loadDummyUdf(cl1)
+    }
+    // Successfully stubbed the missing class.
+    assert(stubClassLoader.lastStubbed === classDummyUdf)
+
+    // Creating a new class loader will unpack the udf correctly.
+    val cl2 = new ChildFirstURLClassLoader(
+      Array(udfNoAJar),
+      stubClassLoader, // even with the same stub class loader.
+      sysClassLoader)
+    // Should be able to load after the artifact is added
+    loadDummyUdf(cl2)
+  }
+
+  test("throw no such method if trying to access methods on stub class") {
+    val sysClassLoader = getClass.getClassLoader()
+    val stubClassLoader = new RecordedStubClassLoader(null, _ => true)
+
+    val sessionClassLoader =
+      new ChildFirstURLClassLoader(Array.empty, stubClassLoader, sysClassLoader)
+
+    // Failed to load DummyUdf because of missing methods
+    assert(intercept[NoSuchMethodException] {
+      loadDummyUdf(sessionClassLoader)
+    }.getMessage.contains(classDummyUdf))
+    // Successfully stubbed the missing class.
+    assert(stubClassLoader.lastStubbed === classDummyUdf)
+  }
+
+  private def loadDummyUdf(sessionClassLoader: ClassLoader): Unit = {
+    // Load DummyUdf and call a method on it.
+    // scalastyle:off classforname
+    val cls = Class.forName(classDummyUdf, false, sessionClassLoader)
+    // scalastyle:on classforname
+    cls.getDeclaredMethod("dummy")
+
+    // Load class A used inside DummyUdf
+    // scalastyle:off classforname
+    Class.forName(classA, false, sessionClassLoader)
+    // scalastyle:on classforname
+  }
+}
+
+class RecordedStubClassLoader(parent: ClassLoader, shouldStub: String => Boolean)
+    extends StubClassLoader(parent, shouldStub) {
+  var lastStubbed: String = _
+
+  override def findClass(name: String): Class[_] = {
+    if (shouldStub(name)) {
+      lastStubbed = name
+    }
+    super.findClass(name)
+  }
+}
diff --git a/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java b/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java
index 57d96756c8b..2791209e019 100644
--- a/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java
+++ b/core/src/main/java/org/apache/spark/util/ChildFirstURLClassLoader.java
@@ -40,6 +40,15 @@ public class ChildFirstURLClassLoader extends MutableURLClassLoader {
     this.parent = new ParentClassLoader(parent);
   }
 
+  /**
+   * Specify the grandparent if there is a need to load in the order of
+   * `grandparent -&gt; urls (child) -&gt; parent`.
+   */
+  public ChildFirstURLClassLoader(URL[] urls, ClassLoader parent, ClassLoader grandparent) {
+    super(urls, grandparent);
+    this.parent = new ParentClassLoader(parent);
+  }
+
   @Override
  public Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
     try {
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index b30569dc964..9327ea4d3dd 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -56,11 +56,12 @@ import org.apache.spark.util._
 
 private[spark] class IsolatedSessionState(
   val sessionUUID: String,
-  val urlClassLoader: MutableURLClassLoader,
+  var urlClassLoader: MutableURLClassLoader,
   var replClassLoader: ClassLoader,
   val currentFiles: HashMap[String, Long],
   val currentJars: HashMap[String, Long],
-  val currentArchives: HashMap[String, Long])
+  val currentArchives: HashMap[String, Long],
+  val replClassDirUri: Option[String])
 
 /**
  * Spark executor, backed by a threadpool to run tasks.
@@ -173,14 +174,20 @@ private[spark] class Executor(
     val currentFiles = new HashMap[String, Long]
     val currentJars = new HashMap[String, Long]
     val currentArchives = new HashMap[String, Long]
-    val urlClassLoader = createClassLoader(currentJars)
+    val urlClassLoader = createClassLoader(currentJars, !isDefaultState(jobArtifactState.uuid))
     val replClassLoader = addReplClassLoaderIfNeeded(
-      urlClassLoader, jobArtifactState.replClassDirUri)
+      urlClassLoader, jobArtifactState.replClassDirUri, jobArtifactState.uuid)
     new IsolatedSessionState(
       jobArtifactState.uuid, urlClassLoader, replClassLoader,
-      currentFiles, currentJars, currentArchives)
+      currentFiles,
+      currentJars,
+      currentArchives,
+      jobArtifactState.replClassDirUri
+    )
   }
 
+  private def isDefaultState(name: String) = name == "default"
+
   // Classloader isolation
   // The default isolation group
   val defaultSessionState = newSessionState(JobArtifactState("default", None))
@@ -514,9 +521,8 @@ private[spark] class Executor(
 
       // Classloader isolation
       val isolatedSession = taskDescription.artifacts.state match {
-        case Some(jobArtifactState) => isolatedSessionCache.get(
-          jobArtifactState.uuid,
-          () => newSessionState(jobArtifactState))
+        case Some(jobArtifactState) =>
+          isolatedSessionCache.get(jobArtifactState.uuid, () => newSessionState(jobArtifactState))
         case _ => defaultSessionState
       }
 
@@ -548,6 +554,9 @@ private[spark] class Executor(
           taskDescription.artifacts.jars,
           taskDescription.artifacts.archives,
           isolatedSession)
+        // Always reset the thread's context class loader so that, after any update, all threads
+        // (not only the thread that updated the dependencies) pick up the new class loader.
+        Thread.currentThread.setContextClassLoader(isolatedSession.replClassLoader)
         task = ser.deserialize[Task[Any]](
          taskDescription.serializedTask, Thread.currentThread.getContextClassLoader)
         task.localProperties = taskDescription.properties
@@ -999,7 +1008,9 @@ private[spark] class Executor(
   * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
    * created by the interpreter to the search path
    */
-  private def createClassLoader(currentJars: HashMap[String, Long]): MutableURLClassLoader = {
+  private def createClassLoader(
+      currentJars: HashMap[String, Long],
+      useStub: Boolean): MutableURLClassLoader = {
     // Bootstrap the list of jars with the user class path.
     val now = System.currentTimeMillis()
     userClassPath.foreach { url =>
@@ -1011,8 +1022,23 @@ private[spark] class Executor(
     val urls = userClassPath.toArray ++ currentJars.keySet.map { uri =>
       new File(uri.split("/").last).toURI.toURL
     }
-    logInfo(s"Starting executor with user classpath (userClassPathFirst = 
$userClassPathFirst): " +
-        urls.mkString("'", ",", "'"))
+    createClassLoader(urls, useStub)
+  }
+
+  private def createClassLoader(urls: Array[URL], useStub: Boolean): MutableURLClassLoader = {
+    logInfo(
+      s"Starting executor with user classpath (userClassPathFirst = 
$userClassPathFirst): " +
+      urls.mkString("'", ",", "'")
+    )
+
+    if (useStub && conf.get(CONNECT_SCALA_UDF_STUB_CLASSES).nonEmpty) {
+      createClassLoaderWithStub(urls, conf.get(CONNECT_SCALA_UDF_STUB_CLASSES))
+    } else {
+      createClassLoader(urls)
+    }
+  }
+
+  private def createClassLoader(urls: Array[URL]): MutableURLClassLoader = {
     if (userClassPathFirst) {
       new ChildFirstURLClassLoader(urls, systemLoader)
     } else {
@@ -1020,20 +1046,39 @@ private[spark] class Executor(
     }
   }
 
+  private def createClassLoaderWithStub(
+      urls: Array[URL],
+      binaryName: Seq[String]): MutableURLClassLoader = {
+    if (userClassPathFirst) {
+      // user -> (sys -> stub)
+      val stubClassLoader =
+        StubClassLoader(systemLoader, binaryName)
+      new ChildFirstURLClassLoader(urls, stubClassLoader)
+    } else {
+      // sys -> user -> stub
+      val stubClassLoader =
+        StubClassLoader(null, binaryName)
+      new ChildFirstURLClassLoader(urls, stubClassLoader, systemLoader)
+    }
+  }
+
   /**
    * If the REPL is in use, add another ClassLoader that will read
    * new classes defined by the REPL as the user types code
    */
   private def addReplClassLoaderIfNeeded(
       parent: ClassLoader,
-      sessionClassUri: Option[String]): ClassLoader = {
+      sessionClassUri: Option[String],
+      sessionUUID: String): ClassLoader = {
     val classUri = sessionClassUri.getOrElse(conf.get("spark.repl.class.uri", null))
-    if (classUri != null) {
+    val classLoader = if (classUri != null) {
       logInfo("Using REPL class URI: " + classUri)
       new ExecutorClassLoader(conf, env, classUri, parent, userClassPathFirst)
     } else {
       parent
     }
+    logInfo(s"Created or updated repl class loader $classLoader for 
$sessionUUID.")
+    classLoader
   }
 
   /**
@@ -1048,6 +1093,7 @@ private[spark] class Executor(
       state: IsolatedSessionState,
       testStartLatch: Option[CountDownLatch] = None,
       testEndLatch: Option[CountDownLatch] = None): Unit = {
+    var updated = false;
     lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
     updateDependenciesLock.lockInterruptibly()
     try {
@@ -1056,7 +1102,7 @@ private[spark] class Executor(
 
       // If the session ID was specified from SparkSession, it's from a Spark Connect client.
       // Specify a dedicated directory for Spark Connect client.
-      lazy val root = if (state.sessionUUID != "default") {
+      lazy val root = if (!isDefaultState(state.sessionUUID)) {
         val newDest = new File(SparkFiles.getRootDirectory(), state.sessionUUID)
         newDest.mkdir()
         newDest
@@ -1101,11 +1147,21 @@ private[spark] class Executor(
           // Add it to our class loader
           val url = new File(root, localName).toURI.toURL
           if (!state.urlClassLoader.getURLs().contains(url)) {
-            logInfo(s"Adding $url to class loader")
+            logInfo(s"Adding $url to class loader ${state.sessionUUID}")
             state.urlClassLoader.addURL(url)
+            if (!isDefaultState(state.sessionUUID)) {
+              updated = true
+            }
           }
         }
       }
+      if (updated) {
+        // When a new url is added for a non-default class loader, recreate the class loader
+        // to ensure all classes are updated.
+        state.urlClassLoader = createClassLoader(state.urlClassLoader.getURLs, useStub = true)
+        state.replClassLoader =
+          addReplClassLoaderIfNeeded(state.urlClassLoader, state.replClassDirUri, state.sessionUUID)
+      }
       // For testing, so we can simulate a slow file download:
       testEndLatch.foreach(_.await())
     } finally {
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 83e64f6f8a8..ba809b7a3b1 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2555,4 +2555,18 @@ package object config {
       .version("3.5.0")
       .booleanConf
       .createWithDefault(false)
+
+  private[spark] val CONNECT_SCALA_UDF_STUB_CLASSES =
+    ConfigBuilder("spark.connect.scalaUdf.stubClasses")
+      .internal()
+      .doc("""
+          |Comma-separated list of binary names of classes/packages that should be stubbed during
+          |the Scala UDF serde and execution if not found on the server classpath.
+          |An empty list effectively disables stubbing for all missing classes.
+          |By default, the server stubs classes from the Scala client package.
+          |""".stripMargin)
+      .version("3.5.0")
+      .stringConf
+      .toSequence
+      .createWithDefault("org.apache.spark.sql.connect.client" :: Nil)
 }
diff --git a/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala b/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala
new file mode 100644
index 00000000000..a0bc753f488
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/StubClassLoader.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.util
+
+import org.apache.xbean.asm9.{ClassWriter, Opcodes}
+
+/**
+ * [[ClassLoader]] that replaces missing classes with stubs, if they cannot be found. It will only
+ * do this for classes that are marked for stubbing.
+ *
+ * While this is generally not a good idea, in this particular case it is used to load lambdas
+ * whose capturing class references unknown (and unneeded) classes. The lambda itself does not
+ * need those classes and can therefore safely be replaced by a stub.
+ */
+class StubClassLoader(parent: ClassLoader, shouldStub: String => Boolean)
+  extends ClassLoader(parent) {
+  override def findClass(name: String): Class[_] = {
+    if (!shouldStub(name)) {
+      throw new ClassNotFoundException(name)
+    }
+    val bytes = StubClassLoader.generateStub(name)
+    defineClass(name, bytes, 0, bytes.length)
+  }
+}
+
+object StubClassLoader {
+  def apply(parent: ClassLoader, binaryName: Seq[String]): StubClassLoader = {
+    new StubClassLoader(parent, name => binaryName.exists(p => name.startsWith(p)))
+  }
+
+  def generateStub(binaryName: String): Array[Byte] = {
+    // Convert binary names to internal names.
+    val name = binaryName.replace('.', '/')
+    val classWriter = new ClassWriter(0)
+    classWriter.visit(
+      49,
+      Opcodes.ACC_PUBLIC + Opcodes.ACC_SUPER,
+      name,
+      null,
+      "java/lang/Object",
+      null)
+    classWriter.visitSource(name + ".java", null)
+
+    // Generate constructor.
+    val ctorWriter = classWriter.visitMethod(
+      Opcodes.ACC_PUBLIC,
+      "<init>",
+      "()V",
+      null,
+      null)
+    ctorWriter.visitVarInsn(Opcodes.ALOAD, 0)
+    ctorWriter.visitMethodInsn(
+      Opcodes.INVOKESPECIAL,
+      "java/lang/Object",
+      "<init>",
+      "()V",
+      false)
+
+    ctorWriter.visitInsn(Opcodes.RETURN)
+    ctorWriter.visitMaxs(1, 1)
+    ctorWriter.visitEnd()
+    classWriter.visitEnd()
+    classWriter.toByteArray
+  }
+}

