This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 400573d33af [SPARK-45093][CONNECT][PYTHON] Properly support error 
handling and conversion for AddArtifactHandler
400573d33af is described below

commit 400573d33af39d61748dd0b2960c47161277f85e
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Mon Dec 4 08:22:13 2023 +0900

    [SPARK-45093][CONNECT][PYTHON] Properly support error handling and 
conversion for AddArtifactHandler
    
    ### What changes were proposed in this pull request?
    This patch improves error handling in the `AddArtifact` path. Previously, 
the `AddArtifactHandler` did not properly propagate exceptions to the client; 
instead, every exception ended up as an `UNKNOWN` error. This patch makes sure 
errors in the add artifact path are wrapped the same way as errors in the 
normal query execution path.
    
    In addition, it adds tests and verification for this behavior in Scala and 
in Python.
    
    ### Why are the changes needed?
    Stability
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added unit tests for Scala and Python
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #44092 from grundprinzip/SPARK-ADD_ARTIFACT_CRAP.
    
    Authored-by: Martin Grund <martin.gr...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../service/SparkConnectAddArtifactsHandler.scala  | 26 +++++++++++++++++++---
 .../spark/sql/connect/utils/ErrorUtils.scala       | 25 +++++++++++++++++++--
 .../connect/service/AddArtifactsHandlerSuite.scala | 20 +++++++++++++++--
 python/pyspark/sql/connect/client/core.py          | 15 ++++++++-----
 .../sql/tests/connect/client/test_artifact.py      | 20 +++++++++++++++++
 5 files changed, 93 insertions(+), 13 deletions(-)

diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
index e664e07dce1..ea3b578be3b 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
@@ -24,7 +24,6 @@ import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import com.google.common.io.CountingOutputStream
-import io.grpc.StatusRuntimeException
 import io.grpc.stub.StreamObserver
 
 import org.apache.spark.connect.proto
@@ -32,6 +31,7 @@ import org.apache.spark.connect.proto.{AddArtifactsRequest, 
AddArtifactsResponse
 import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
 import org.apache.spark.sql.artifact.ArtifactManager
 import org.apache.spark.sql.artifact.util.ArtifactUtils
+import org.apache.spark.sql.connect.utils.ErrorUtils
 import org.apache.spark.util.Utils
 
 /**
@@ -51,7 +51,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver: 
StreamObserver[AddAr
   private var chunkedArtifact: StagedChunkedArtifact = _
   private var holder: SessionHolder = _
 
-  override def onNext(req: AddArtifactsRequest): Unit = {
+  override def onNext(req: AddArtifactsRequest): Unit = try {
     if (this.holder == null) {
       this.holder = SparkConnectService.getOrCreateIsolatedSession(
         req.getUserContext.getUserId,
@@ -78,6 +78,17 @@ class SparkConnectAddArtifactsHandler(val responseObserver: 
StreamObserver[AddAr
     } else {
       throw new UnsupportedOperationException(s"Unsupported data transfer 
request: $req")
     }
+  } catch {
+    ErrorUtils.handleError(
+      "addArtifacts.onNext",
+      responseObserver,
+      holder.userId,
+      holder.sessionId,
+      None,
+      false,
+      Some(() => {
+        cleanUpStagedArtifacts()
+      }))
   }
 
   override def onError(throwable: Throwable): Unit = {
@@ -128,7 +139,16 @@ class SparkConnectAddArtifactsHandler(val 
responseObserver: StreamObserver[AddAr
       responseObserver.onNext(builder.build())
       responseObserver.onCompleted()
     } catch {
-      case e: StatusRuntimeException => onError(e)
+      ErrorUtils.handleError(
+        "addArtifacts.onComplete",
+        responseObserver,
+        holder.userId,
+        holder.sessionId,
+        None,
+        false,
+        Some(() => {
+          cleanUpStagedArtifacts()
+        }))
     }
   }
 
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index 7cb555ca47e..703b11c0c73 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -153,7 +153,15 @@ private[connect] object ErrorUtils extends Logging {
       .build()
   }
 
-  private def buildStatusFromThrowable(
+  /**
+   * This is a helper method that can be used by any GRPC handler to convert 
existing Throwables
+   * into GRPC-conformant status objects.
+   *
+   * @param st
+   * @param sessionHolderOpt
+   * @return
+   */
+  private[connect] def buildStatusFromThrowable(
       st: Throwable,
       sessionHolderOpt: Option[SessionHolder]): RPCStatus = {
     val errorInfo = ErrorInfo
@@ -221,6 +229,17 @@ private[connect] object ErrorUtils extends Logging {
    *   String value indicating the operation type (analysis, execution)
    * @param observer
    *   The GRPC response observer.
+   * @param userId
+   *   The user id.
+   * @param sessionId
+   *   The session id.
+   * @param events
+   *   The ExecuteEventsManager if present to report about failures.
+   * @param isInterrupted
+   *   Whether the error is caused by an interruption or during execution.
+   * @param callback
+   *   Optional callback that is invoked right before the error is sent to the 
observer and allows the caller
+   *   to execute additional cleanup logic.
    * @tparam V
    * @return
    */
@@ -230,7 +249,8 @@ private[connect] object ErrorUtils extends Logging {
       userId: String,
       sessionId: String,
       events: Option[ExecuteEventsManager] = None,
-      isInterrupted: Boolean = false): PartialFunction[Throwable, Unit] = {
+      isInterrupted: Boolean = false,
+      callback: Option[() => Unit] = None): PartialFunction[Throwable, Unit] = 
{
 
     // SessionHolder may not be present, e.g. if the session was already 
closed.
     // When SessionHolder is not present error details will not be available 
for FetchErrorDetails.
@@ -281,6 +301,7 @@ private[connect] object ErrorUtils extends Logging {
             executeEventsManager.postFailed(wrapped.getMessage)
           }
         }
+        callback.foreach(_.apply())
         observer.onError(wrapped)
       }
   }
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
index 2a65032beef..e681aa4726f 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/AddArtifactsHandlerSuite.scala
@@ -26,6 +26,10 @@ import scala.concurrent.duration._
 import scala.jdk.CollectionConverters._
 
 import com.google.protobuf.ByteString
+import com.google.rpc.ErrorInfo
+import io.grpc.Status.Code
+import io.grpc.StatusRuntimeException
+import io.grpc.protobuf.StatusProto
 import io.grpc.stub.StreamObserver
 
 import org.apache.spark.connect.proto
@@ -371,9 +375,15 @@ class AddArtifactsHandlerSuite extends SharedSparkSession 
with ResourceHelper {
       val name = "/absolute/path/"
       val request = createDummyArtifactRequests(name)
       request.foreach { req =>
-        intercept[IllegalArgumentException] {
+        val e = intercept[StatusRuntimeException] {
           handler.onNext(req)
         }
+        assert(e.getStatus.getCode == Code.INTERNAL)
+        val statusProto = StatusProto.fromThrowable(e)
+        assert(statusProto.getDetailsCount == 1)
+        val details = statusProto.getDetails(0)
+        val info = details.unpack(classOf[ErrorInfo])
+        assert(info.getReason.contains("java.lang.IllegalArgumentException"))
       }
       handler.onCompleted()
     } finally {
@@ -388,9 +398,15 @@ class AddArtifactsHandlerSuite extends SharedSparkSession 
with ResourceHelper {
       val names = Seq("..", "../sibling", "../nephew/directory", "a/../../b", 
"x/../y/../..")
       val request = names.flatMap(createDummyArtifactRequests)
       request.foreach { req =>
-        intercept[IllegalArgumentException] {
+        val e = intercept[StatusRuntimeException] {
           handler.onNext(req)
         }
+        assert(e.getStatus.getCode == Code.INTERNAL)
+        val statusProto = StatusProto.fromThrowable(e)
+        assert(statusProto.getDetailsCount == 1)
+        val details = statusProto.getDetails(0)
+        val info = details.unpack(classOf[ErrorInfo])
+        assert(info.getReason.contains("java.lang.IllegalArgumentException"))
       }
       handler.onCompleted()
     } finally {
diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py
index f037e968be0..e36b7d74a78 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1575,12 +1575,15 @@ class SparkConnectClient(object):
             raise SparkConnectGrpcException(str(rpc_error)) from None
 
     def add_artifacts(self, *paths: str, pyfile: bool, archive: bool, file: 
bool) -> None:
-        for path in paths:
-            for attempt in self._retrying():
-                with attempt:
-                    self._artifact_manager.add_artifacts(
-                        path, pyfile=pyfile, archive=archive, file=file
-                    )
+        try:
+            for path in paths:
+                for attempt in self._retrying():
+                    with attempt:
+                        self._artifact_manager.add_artifacts(
+                            path, pyfile=pyfile, archive=archive, file=file
+                        )
+        except Exception as error:
+            self._handle_error(error)
 
     def copy_from_local_to_fs(self, local_path: str, dest_path: str) -> None:
         for attempt in self._retrying():
diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py 
b/python/pyspark/sql/tests/connect/client/test_artifact.py
index 7fde0958e38..69a9525c65e 100644
--- a/python/pyspark/sql/tests/connect/client/test_artifact.py
+++ b/python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -20,6 +20,7 @@ import tempfile
 import unittest
 import os
 
+from pyspark.errors.exceptions.connect import SparkConnectGrpcException
 from pyspark.sql import SparkSession
 from pyspark.testing.connectutils import ReusedConnectTestCase, 
should_test_connect
 from pyspark.testing.utils import SPARK_HOME
@@ -56,6 +57,25 @@ class ArtifactTestsMixin:
             
SparkSession.builder.remote(f"sc://localhost:{ChannelBuilder.default_port()}").create()
         )
 
+    def test_artifacts_cannot_be_overwritten(self):
+        with tempfile.TemporaryDirectory() as d:
+            pyfile_path = os.path.join(d, "my_pyfile.py")
+            with open(pyfile_path, "w+") as f:
+                f.write("my_func = lambda: 10")
+
+            self.spark.addArtifacts(pyfile_path, pyfile=True)
+
+            # Writing the same file twice is fine, and should not throw.
+            self.spark.addArtifacts(pyfile_path, pyfile=True)
+
+            with open(pyfile_path, "w+") as f:
+                f.write("my_func = lambda: 11")
+
+            with self.assertRaisesRegex(
+                SparkConnectGrpcException, "\\(java.lang.RuntimeException\\) 
Duplicate Artifact"
+            ):
+                self.spark.addArtifacts(pyfile_path, pyfile=True)
+
     def check_add_zipped_package(self, spark_session):
         with tempfile.TemporaryDirectory() as d:
             package_path = os.path.join(d, "my_zipfile")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to