This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 10b448402b9 [SPARK-41533][CONNECT] Proper Error Handling for Spark Connect Server / Client 10b448402b9 is described below commit 10b448402b9e142db4a8d8c7989478a0c5d04315 Author: Martin Grund <martin.gr...@databricks.com> AuthorDate: Thu Dec 29 09:04:21 2022 +0900 [SPARK-41533][CONNECT] Proper Error Handling for Spark Connect Server / Client ### What changes were proposed in this pull request? This PR improves the error handling on both the Spark Connect server and client sides. First, this patch moves the error handling logic on the server into a common error-handler partial function that differentiates between internal Spark errors and other runtime errors. For custom Spark exceptions, the actual internal error is wrapped in a Google RPC Status and sent as trailing metadata to the client. Similarly, on the client side, the error handling is moved into a common function. All GRPC errors are wrapped in custom exceptions to avoid presenting the user with confusing GRPC errors. If available, the attached RPC status is extracted and added to the exception. Lastly, this patch adds basic logging functionality that can be enabled by setting the environment variable `SPARK_CONNECT_LOG_LEVEL` to `info`, `warn`, `error`, or `debug`. ### Why are the changes needed? Usability ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? UT Closes #39212 from grundprinzip/SPARK-41533. 
Authored-by: Martin Grund <martin.gr...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .github/workflows/build_and_test.yml | 2 +- connector/connect/README.md | 2 +- .../sql/connect/service/SparkConnectService.scala | 87 ++++++-- .../service/SparkConnectStreamHandler.scala | 9 +- .../connect/planner/SparkConnectServiceSuite.scala | 4 +- dev/create-release/spark-rm/Dockerfile | 2 +- dev/infra/Dockerfile | 2 +- dev/lint-python | 2 + dev/requirements.txt | 4 + python/docs/source/getting_started/install.rst | 20 +- python/mypy.ini | 1 + python/pyspark/sql/connect/client.py | 234 ++++++++++++++++++--- python/pyspark/sql/connect/dataframe.py | 2 +- .../sql/tests/connect/test_connect_basic.py | 29 ++- .../sql/tests/connect/test_connect_function.py | 20 +- python/pyspark/testing/connectutils.py | 25 ++- python/setup.py | 5 +- 17 files changed, 364 insertions(+), 86 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 5bd2fef9b0c..443fbf47942 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -589,7 +589,7 @@ jobs: # See also https://issues.apache.org/jira/browse/SPARK-38279. 
python3.9 -m pip install 'sphinx<3.1.0' mkdocs pydata_sphinx_theme ipython nbsphinx numpydoc 'jinja2<3.0.0' 'markupsafe==2.0.1' 'pyzmq<24.0.0' python3.9 -m pip install ipython_genutils # See SPARK-38517 - python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'grpcio==1.48.1' 'protobuf==3.19.5' 'mypy-protobuf==3.3.0' + python3.9 -m pip install sphinx_plotly_directive 'numpy>=1.20.0' pyarrow pandas 'plotly>=4.8' 'grpcio==1.48.1' 'protobuf==3.19.5' 'mypy-protobuf==3.3.0' 'grpc-stubs==1.24.11' 'googleapis-common-protos-stubs==2.2.0' python3.9 -m pip install 'docutils<0.18.0' # See SPARK-39421 apt-get update -y apt-get install -y ruby ruby-dev diff --git a/connector/connect/README.md b/connector/connect/README.md index d30f65ffb5d..d5cc767c744 100644 --- a/connector/connect/README.md +++ b/connector/connect/README.md @@ -90,7 +90,7 @@ To use the release version of Spark Connect: ### Generate proto generated files for the Python client 1. Install `buf version 1.11.0`: https://docs.buf.build/installation -2. Run `pip install grpcio==1.48.1 protobuf==3.19.5 mypy-protobuf==3.3.0` +2. Run `pip install grpcio==1.48.1 protobuf==3.19.5 mypy-protobuf==3.3.0 googleapis-common-protos==1.56.4 grpcio-status==1.48.1` 3. Run `./connector/connect/dev/generate_protos.sh` 4. 
Optional Check `./dev/check-codegen-python.py` diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index bfcea3d2252..61f035630f7 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -20,19 +20,23 @@ package org.apache.spark.sql.connect.service import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ +import scala.util.control.NonFatal import com.google.common.base.Ticker import com.google.common.cache.CacheBuilder +import com.google.protobuf.{Any => ProtoAny} +import com.google.rpc.{Code => RPCCode, ErrorInfo, Status => RPCStatus} import io.grpc.{Server, Status} import io.grpc.netty.NettyServerBuilder +import io.grpc.protobuf.StatusProto import io.grpc.protobuf.services.ProtoReflectionService import io.grpc.stub.StreamObserver -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkEnv, SparkThrowable} import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{AnalyzePlanRequest, AnalyzePlanResponse, ExecutePlanRequest, ExecutePlanResponse, SparkConnectServiceGrpc} import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Dataset, SparkSession} +import org.apache.spark.sql.{AnalysisException, Dataset, SparkSession} import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_BINDING_PORT import org.apache.spark.sql.connect.planner.{DataTypeProtoConverter, SparkConnectPlanner} import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExplainMode, ExtendedMode, FormattedMode, SimpleMode} @@ -49,6 +53,71 @@ class SparkConnectService(debug: Boolean) extends SparkConnectServiceGrpc.SparkConnectServiceImplBase with Logging { + private def buildStatusFromThrowable[A 
<: Throwable with SparkThrowable](st: A): RPCStatus = { + val t = Option(st.getCause).getOrElse(st) + RPCStatus + .newBuilder() + .setCode(RPCCode.INTERNAL_VALUE) + .addDetails( + ProtoAny.pack( + ErrorInfo + .newBuilder() + .setReason(t.getClass.getName) + .setDomain("org.apache.spark") + .build())) + .setMessage(t.getLocalizedMessage) + .build() + } + + /** + * Common exception handling function for the Analysis and Execution methods. Closes the stream + * after the error has been sent. + * + * @param opType + * String value indicating the operation type (analysis, execution) + * @param observer + * The GRPC response observer. + * @tparam V + * @return + */ + private def handleError[V]( + opType: String, + observer: StreamObserver[V]): PartialFunction[Throwable, Unit] = { + case ae: AnalysisException => + logError(s"Error during: $opType", ae) + val status = RPCStatus + .newBuilder() + .setCode(RPCCode.INTERNAL_VALUE) + .addDetails( + ProtoAny.pack( + ErrorInfo + .newBuilder() + .setReason(ae.getClass.getName) + .setDomain("org.apache.spark") + .putMetadata("message", ae.getSimpleMessage) + .putMetadata("plan", Option(ae.plan).flatten.map(p => s"$p").getOrElse("")) + .build())) + .setMessage(ae.getLocalizedMessage) + .build() + observer.onError(StatusProto.toStatusRuntimeException(status)) + case st: SparkThrowable => + logError(s"Error during: $opType", st) + val status = buildStatusFromThrowable(st) + observer.onError(StatusProto.toStatusRuntimeException(status)) + case NonFatal(nf) => + logError(s"Error during: $opType", nf) + val status = RPCStatus + .newBuilder() + .setCode(RPCCode.INTERNAL_VALUE) + .setMessage(nf.getLocalizedMessage) + .build() + observer.onError(StatusProto.toStatusRuntimeException(status)) + case e: Throwable => + logError(s"Error during: $opType", e) + observer.onError( + Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException()) + } + /** * This is the main entry method for Spark Connect and all calls to 
execute a plan. * @@ -64,12 +133,7 @@ class SparkConnectService(debug: Boolean) responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { try { new SparkConnectStreamHandler(responseObserver).handle(request) - } catch { - case e: Throwable => - log.error("Error executing plan.", e) - responseObserver.onError( - Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException()) - } + } catch handleError("execute", observer = responseObserver) } /** @@ -114,12 +178,7 @@ class SparkConnectService(debug: Boolean) response.setClientId(request.getClientId) responseObserver.onNext(response.build()) responseObserver.onCompleted() - } catch { - case e: Throwable => - log.error("Error analyzing plan.", e) - responseObserver.onError( - Status.UNKNOWN.withCause(e).withDescription(e.getLocalizedMessage).asRuntimeException()) - } + } catch handleError("analyze", observer = responseObserver) } def handleAnalyzePlanRequest( diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index 9631b93f6e9..9c1a8ca4dc4 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connect.service import scala.collection.JavaConverters._ -import scala.util.control.NonFatal import com.google.protobuf.ByteString import io.grpc.stub.StreamObserver @@ -128,12 +127,8 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[ExecutePlanResp } partitions(currentPartitionId) = null - error.foreach { - case NonFatal(e) => - responseObserver.onError(e) - logError("Error while processing query.", e) - return - case other => throw other + error.foreach { case other => + 
throw other } part } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 6dcce0926dc..9c5df253aea 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.connect.planner import scala.collection.mutable +import io.grpc.StatusRuntimeException import io.grpc.stub.StreamObserver import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.{BigIntVector, Float8Vector} import org.apache.arrow.vector.ipc.ArrowStreamReader -import org.apache.spark.SparkException import org.apache.spark.connect.proto import org.apache.spark.sql.connect.dsl.MockRemoteSession import org.apache.spark.sql.connect.dsl.plans._ @@ -185,7 +185,7 @@ class SparkConnectServiceSuite extends SharedSparkSession { } override def onError(throwable: Throwable): Unit = { - assert(throwable.isInstanceOf[SparkException]) + assert(throwable.isInstanceOf[StatusRuntimeException]) } override def onCompleted(): Unit = { diff --git a/dev/create-release/spark-rm/Dockerfile b/dev/create-release/spark-rm/Dockerfile index c65a0e1c759..38c64601882 100644 --- a/dev/create-release/spark-rm/Dockerfile +++ b/dev/create-release/spark-rm/Dockerfile @@ -42,7 +42,7 @@ ARG APT_INSTALL="apt-get install --no-install-recommends -y" # We should use the latest Sphinx version once this is fixed. # TODO(SPARK-35375): Jinja2 3.0.0+ causes error when building with Sphinx. # See also https://issues.apache.org/jira/browse/SPARK-35375. 
-ARG PIP_PKGS="sphinx==3.0.4 mkdocs==1.1.2 numpy==1.19.4 pydata_sphinx_theme==0.4.1 ipython==7.19.0 nbsphinx==0.8.0 numpydoc==1.1.0 jinja2==2.11.3 twine==3.4.1 sphinx-plotly-directive==0.1.3 pandas==1.1.5 pyarrow==3.0.0 plotly==5.4.0 markupsafe==2.0.1 docutils<0.17 grpcio==1.48.1 protobuf==4.21.6" +ARG PIP_PKGS="sphinx==3.0.4 mkdocs==1.1.2 numpy==1.19.4 pydata_sphinx_theme==0.4.1 ipython==7.19.0 nbsphinx==0.8.0 numpydoc==1.1.0 jinja2==2.11.3 twine==3.4.1 sphinx-plotly-directive==0.1.3 pandas==1.1.5 pyarrow==3.0.0 plotly==5.4.0 markupsafe==2.0.1 docutils<0.17 grpcio==1.48.1 protobuf==4.21.6 grpcio-status==1.48.1 googleapis-common-protos==1.56.4" ARG GEM_PKGS="bundler:2.2.9" # Install extra needed repos and refresh. diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index 1c326d437c7..92cb75360c1 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -68,4 +68,4 @@ RUN pypy3 -m pip install numpy 'pandas<=1.5.2' scipy coverage matplotlib RUN python3.9 -m pip install numpy pyarrow 'pandas<=1.5.2' scipy unittest-xml-reporting plotly>=4.8 sklearn 'mlflow>=1.0' coverage matplotlib openpyxl 'memory-profiler==0.60.0' # Add Python deps for Spark Connect. -RUN python3.9 -m pip install grpcio protobuf +RUN python3.9 -m pip install grpcio protobuf googleapis-common-protos grpcio-status diff --git a/dev/lint-python b/dev/lint-python index 806b7572dc6..59ce71980d9 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -69,6 +69,7 @@ function mypy_annotation_test { echo "starting mypy annotations test..." MYPY_REPORT=$( ($MYPY_BUILD \ + --namespace-packages \ --config-file python/mypy.ini \ --cache-dir /tmp/.mypy_cache/ \ python/pyspark) 2>&1) @@ -127,6 +128,7 @@ function mypy_examples_test { echo "starting mypy examples test..." 
MYPY_REPORT=$( (MYPYPATH=python $MYPY_BUILD \ + --namespace-packages \ --config-file python/mypy.ini \ --exclude "mllib/*" \ examples/src/main/python/) 2>&1) diff --git a/dev/requirements.txt b/dev/requirements.txt index f91e2fed713..c3911b57eb9 100644 --- a/dev/requirements.txt +++ b/dev/requirements.txt @@ -50,7 +50,11 @@ black==22.6.0 # Spark Connect (required) grpcio==1.48.1 +grpcio-status==1.48.1 protobuf==3.19.5 +googleapis-common-protos==1.56.4 # Spark Connect python proto generation plugin (optional) mypy-protobuf==3.3.0 +googleapis-common-protos-stubs==2.2.0 +grpc-stubs==1.24.11 diff --git a/python/docs/source/getting_started/install.rst b/python/docs/source/getting_started/install.rst index d3b24be3d49..eddee8e30e1 100644 --- a/python/docs/source/getting_started/install.rst +++ b/python/docs/source/getting_started/install.rst @@ -153,15 +153,17 @@ To install PySpark from source, refer to |building_spark|_. Dependencies ------------ -============= ========================= ====================================================================================== -Package Minimum supported version Note -============= ========================= ====================================================================================== -`py4j` 0.10.9.7 Required -`pandas` 1.0.5 Required for pandas API on Spark and Spark Connect; Optional for Spark SQL -`pyarrow` 1.0.0 Required for pandas API on Spark and Spark Connect; Optional for Spark SQL -`numpy` 1.15 Required for pandas API on Spark and MLLib DataFrame-based API; Optional for Spark SQL -`grpc` 1.48.1 Required for Spark Connect -============= ========================= ====================================================================================== +========================== ========================= ====================================================================================== +Package Minimum supported version Note +========================== ========================= 
====================================================================================== +`py4j` 0.10.9.7 Required +`pandas` 1.0.5 Required for pandas API on Spark and Spark Connect; Optional for Spark SQL +`pyarrow` 1.0.0 Required for pandas API on Spark and Spark Connect; Optional for Spark SQL +`numpy` 1.15 Required for pandas API on Spark and MLLib DataFrame-based API; Optional for Spark SQL +`grpc` 1.48.1 Required for Spark Connect +`grpcio-status` 1.48.1 Required for Spark Connect +`googleapis-common-protos` 1.56.4 Required for Spark Connect +========================== ========================= ====================================================================================== Note that PySpark requires Java 8 or later with ``JAVA_HOME`` properly set. If using JDK 11, set ``-Dio.netty.tryReflectionSetAccessible=true`` for Arrow related features and refer diff --git a/python/mypy.ini b/python/mypy.ini index 603647bd3cd..dd1c1cd4875 100644 --- a/python/mypy.ini +++ b/python/mypy.ini @@ -22,6 +22,7 @@ disallow_untyped_defs = True show_error_codes = True warn_unused_ignores = True warn_redundant_casts = True +namespace_packages = True [mypy-pyspark.sql.connect.proto.*] ignore_errors = True diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py index e258dbd92b4..e78c4de0f70 100644 --- a/python/pyspark/sql/connect/client.py +++ b/python/pyspark/sql/connect/client.py @@ -15,14 +15,19 @@ # limitations under the License. 
# +import logging import os import urllib.parse import uuid -from typing import Iterable, Optional, Any, Union, List, Tuple, Dict +from typing import Iterable, Optional, Any, Union, List, Tuple, Dict, NoReturn, cast +import google.protobuf.message +from grpc_status import rpc_status import grpc import pandas +from google.protobuf import text_format import pyarrow as pa +from google.rpc import error_details_pb2 import pyspark.sql.connect.proto as pb2 import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib @@ -36,6 +41,50 @@ from pyspark.sql.types import ( ) +def _configure_logging() -> logging.Logger: + """Configure logging for the Spark Connect clients.""" + logger = logging.getLogger(__name__) + handler = logging.StreamHandler() + handler.setFormatter( + logging.Formatter(fmt="%(asctime)s %(process)d %(levelname)s %(funcName)s %(message)s") + ) + logger.addHandler(handler) + + # Check the environment variables for log levels: + if "SPARK_CONNECT_LOG_LEVEL" in os.environ: + logger.setLevel(os.getenv("SPARK_CONNECT_LOG_LEVEL", "error").upper()) + else: + logger.disabled = True + return logger + + +# Instantiate the logger based on the environment configuration. 
+logger = _configure_logging() + + +class SparkConnectException(Exception): + def __init__(self, message: str, reason: Optional[str] = None) -> None: + super(SparkConnectException, self).__init__(message) + self._reason = reason + self._message = message + + def __str__(self) -> str: + if self._reason is not None: + return f"({self._reason}) {self._message}" + else: + return self._message + + +class SparkConnectAnalysisException(SparkConnectException): + def __init__(self, reason: str, message: str, plan: str) -> None: + self._reason = reason + self._message = message + self._plan = plan + + def __str__(self) -> str: + return f"{self._message}\nPlan: {self._plan}" + + class ChannelBuilder: """ This is a helper class that is used to create a GRPC channel based on the given @@ -60,7 +109,18 @@ class ChannelBuilder: DEFAULT_PORT = 15002 - def __init__(self, url: str) -> None: + def __init__(self, url: str, channelOptions: Optional[List[Tuple[str, Any]]] = None) -> None: + """ + Constructs a new channel builder. This is used to create the proper GRPC channel from + the connection string. + + Parameters + ---------- + url : str + Spark Connect connection string + channelOptions: list of tuple, optional + Additional options that can be passed to the GRPC channel construction. + """ # Explicitly check the scheme of the URL. if url[:5] != "sc://": raise AttributeError("URL scheme must be set to `sc`.") @@ -74,6 +134,7 @@ class ChannelBuilder: f"Path component for connection URI must be empty: {self.url.path}" ) self._extract_attributes() + self._channel_options = channelOptions def _extract_attributes(self) -> None: if len(self.url.params) > 0: @@ -159,7 +220,8 @@ class ChannelBuilder: def toChannel(self) -> grpc.Channel: """ Applies the parameters of the connection string and creates a new - GRPC channel according to the configuration. + GRPC channel according to the configuration. Passes optional channel options to + construct the channel. 
Returns ------- @@ -176,7 +238,7 @@ class ChannelBuilder: use_secure = False if not use_secure: - return grpc.insecure_channel(destination) + return grpc.insecure_channel(destination, options=self._channel_options) else: # Default SSL Credentials. opt_token = self.params.get(ChannelBuilder.PARAM_TOKEN, None) @@ -186,9 +248,15 @@ class ChannelBuilder: composite_creds = grpc.composite_channel_credentials( ssl_creds, grpc.access_token_call_credentials(opt_token) ) - return grpc.secure_channel(destination, credentials=composite_creds) + return grpc.secure_channel( + destination, credentials=composite_creds, options=self._channel_options + ) else: - return grpc.secure_channel(destination, credentials=grpc.ssl_channel_credentials()) + return grpc.secure_channel( + destination, + credentials=grpc.ssl_channel_credentials(), + options=self._channel_options, + ) class MetricValue: @@ -272,7 +340,12 @@ class AnalyzeResult: class SparkConnectClient(object): """Conceptually the remote spark session that communicates with the server""" - def __init__(self, connectionString: str, userId: Optional[str] = None): + def __init__( + self, + connectionString: str, + userId: Optional[str] = None, + channelOptions: Optional[List[Tuple[str, Any]]] = None, + ): """ Creates a new SparkSession for the Spark Connect interface. @@ -288,7 +361,7 @@ class SparkConnectClient(object): takes precedence. """ # Parse the connection string. - self._builder = ChannelBuilder(connectionString) + self._builder = ChannelBuilder(connectionString, channelOptions) self._user_id = None # Generate a unique session ID for this client. This UUID must be unique to allow # concurrent Spark sessions of the same user. If the channel is closed, creating @@ -303,6 +376,7 @@ class SparkConnectClient(object): self._channel = self._builder.toChannel() self._stub = grpc_lib.SparkConnectServiceStub(self._channel) + # Configure logging for the SparkConnect client. 
def register_udf( self, function: Any, return_type: Union[str, pyspark.sql.types.DataType] @@ -312,6 +386,7 @@ class SparkConnectClient(object): name = f"fun_{uuid.uuid4().hex}" fun = pb2.CreateScalarFunction() fun.parts.append(name) + logger.info(f"Registering UDF: {self._proto_to_string(fun)}") fun.serialized_function = cloudpickle.dumps((function, return_type)) req = self._execute_plan_request_with_metadata() @@ -331,7 +406,8 @@ class SparkConnectClient(object): for x in metrics.metrics ] - def _to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame": + def to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame": + logger.info(f"Executing plan {self._proto_to_string(plan)}") req = self._execute_plan_request_with_metadata() req.plan.CopyFrom(plan) return self._execute_and_fetch(req) @@ -339,7 +415,22 @@ class SparkConnectClient(object): def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType: return types.proto_schema_to_pyspark_data_type(schema) + def _proto_to_string(self, p: google.protobuf.message.Message) -> str: + """ + Helper method to generate a one line string representation of the plan. + Parameters + ---------- + p : google.protobuf.message.Message + Generic Message type + + Returns + ------- + Single line string of the serialized proto message. + """ + return text_format.MessageToString(p, as_one_line=True) + def schema(self, plan: pb2.Plan) -> StructType: + logger.info(f"Schema for plan: {self._proto_to_string(plan)}") proto_schema = self._analyze(plan).schema # Server side should populate the struct field which is the schema. 
assert proto_schema.HasField("struct") @@ -355,10 +446,12 @@ class SparkConnectClient(object): return StructType(fields) def explain_string(self, plan: pb2.Plan, explain_mode: str = "extended") -> str: + logger.info(f"Explain (mode={explain_mode}) for plan {self._proto_to_string(plan)}") result = self._analyze(plan, explain_mode) return result.explain_string def execute_command(self, command: pb2.Command) -> None: + logger.info(f"Execute command for command {self._proto_to_string(command)}") req = self._execute_plan_request_with_metadata() if self._user_id: req.user_context.user_id = self._user_id @@ -386,6 +479,20 @@ class SparkConnectClient(object): return req def _analyze(self, plan: pb2.Plan, explain_mode: str = "extended") -> AnalyzeResult: + """ + Call the analyze RPC of Spark Connect. + + Parameters + ---------- + plan : :class:`pyspark.sql.connect.proto.Plan` + Proto representation of the plan. + explain_mode : str + Explain mode + + Returns + ------- + The result of the analyze call. 
+ """ req = self._analyze_plan_request_with_metadata() req.plan.CopyFrom(plan) if explain_mode not in ["simple", "extended", "codegen", "cost", "formatted"]: @@ -406,36 +513,64 @@ class SparkConnectClient(object): else: # formatted req.explain.explain_mode = pb2.Explain.ExplainMode.FORMATTED - resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata()) - if resp.client_id != self._session_id: - raise ValueError("Received incorrect session identifier for request.") - return AnalyzeResult.fromProto(resp) + try: + resp = self._stub.AnalyzePlan(req, metadata=self._builder.metadata()) + if resp.client_id != self._session_id: + raise SparkConnectException("Received incorrect session identifier for request.") + return AnalyzeResult.fromProto(resp) + except grpc.RpcError as rpc_error: + self._handle_error(rpc_error) def _process_batch(self, arrow_batch: pb2.ExecutePlanResponse.ArrowBatch) -> "pandas.DataFrame": with pa.ipc.open_stream(arrow_batch.data) as rd: return rd.read_pandas() def _execute(self, req: pb2.ExecutePlanRequest) -> None: - for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): - if b.client_id != self._session_id: - raise ValueError("Received incorrect session identifier for request.") - continue - return + """ + Execute the passed request `req` and drop all results. + + Parameters + ---------- + req : pb2.ExecutePlanRequest + Proto representation of the plan. + + """ + logger.info("Execute") + try: + for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): + if b.client_id != self._session_id: + raise SparkConnectException( + "Received incorrect session identifier for request." 
+ ) + continue + except grpc.RpcError as rpc_error: + self._handle_error(rpc_error) def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> "pandas.DataFrame": + logger.info("ExecuteAndFetch") import pandas as pd m: Optional[pb2.ExecutePlanResponse.Metrics] = None result_dfs = [] - for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): - if b.client_id != self._session_id: - raise ValueError("Received incorrect session identifier for request.") - if b.metrics is not None: - m = b.metrics - if b.HasField("arrow_batch"): - pb = self._process_batch(b.arrow_batch) - result_dfs.append(pb) + try: + for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()): + if b.client_id != self._session_id: + raise SparkConnectException( + "Received incorrect session identifier for request." + ) + if b.metrics is not None: + logger.debug("Received metric batch.") + m = b.metrics + if b.HasField("arrow_batch"): + logger.debug( + f"Received arrow batch rows={b.arrow_batch.row_count} " + f"size={len(b.arrow_batch.data)}" + ) + pb = self._process_batch(b.arrow_batch) + result_dfs.append(pb) + except grpc.RpcError as rpc_error: + self._handle_error(rpc_error) assert len(result_dfs) > 0 @@ -451,3 +586,50 @@ class SparkConnectClient(object): if m is not None: df.attrs["metrics"] = self._build_metrics(m) return df + + def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn: + """ + Error handling helper for dealing with GRPC Errors. On the server side, certain + exceptions are enriched with additional RPC Status information. These are + unpacked in this function and put into the exception. + + To avoid overloading the user with GRPC errors, this message explicitly + swallows the error context from the call. This GRPC Error is logged however, + and can be enabled. + + Parameters + ---------- + rpc_error : grpc.RpcError + RPC Error containing the details of the exception. + + Returns + ------- + Throws the appropriate internal Python exception. 
+ """ + logger.exception("GRPC Error received") + # We have to cast the value here because, a RpcError is a Call as well. + # https://grpc.github.io/grpc/python/grpc.html#grpc.UnaryUnaryMultiCallable.__call__ + status = rpc_status.from_call(cast(grpc.Call, rpc_error)) + if status: + for d in status.details: + if d.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): + info = error_details_pb2.ErrorInfo() + d.Unpack(info) + if info.reason == "org.apache.spark.sql.AnalysisException": + raise SparkConnectAnalysisException( + info.reason, info.metadata["message"], info.metadata["plan"] + ) from None + else: + raise SparkConnectException(status.message, info.reason) from None + + raise SparkConnectException(status.message) from None + else: + raise SparkConnectException(str(rpc_error)) from None + + +__all__ = [ + "ChannelBuilder", + "SparkConnectClient", + "SparkConnectException", + "SparkConnectAnalysisException", +] diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py index 313d46ceb2f..08db6b61871 100644 --- a/python/pyspark/sql/connect/dataframe.py +++ b/python/pyspark/sql/connect/dataframe.py @@ -907,7 +907,7 @@ class DataFrame: if self._session is None: raise Exception("Cannot collect on empty session.") query = self._plan.to_proto(self._session.client) - return self._session.client._to_pandas(query) + return self._session.client.to_pandas(query) toPandas.__doc__ = PySparkDataFrame.toPandas.__doc__ diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py index 302a5a96010..da6d5afd1cd 100644 --- a/python/pyspark/sql/tests/connect/test_connect_basic.py +++ b/python/pyspark/sql/tests/connect/test_connect_basic.py @@ -33,6 +33,7 @@ import pyspark.sql.functions from pyspark.testing.utils import ReusedPySparkTestCase from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.pandasutils import PandasOnSparkTestCase 
+from pyspark.sql.connect.client import SparkConnectException, SparkConnectAnalysisException if should_test_connect: import grpc @@ -116,6 +117,12 @@ class SparkConnectSQLTestCase(PandasOnSparkTestCase, ReusedPySparkTestCase, SQLT class SparkConnectTests(SparkConnectSQLTestCase): + def test_error_handling(self): + # SPARK-41533 Proper error handling for Spark Connect + df = self.connect.range(10).select("id2") + with self.assertRaises(SparkConnectAnalysisException): + df.collect() + def test_simple_read(self): df = self.connect.read.table(self.tbl_name) data = df.limit(10).toPandas() @@ -262,12 +269,12 @@ class SparkConnectTests(SparkConnectSQLTestCase): ): self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"]) - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): self.connect.createDataFrame( data, "col1 magic_type, col2 int, col3 int, col4 int" ).show() - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): self.connect.createDataFrame(data, "col1 int, col2 int, col3 int").show() def test_with_local_list(self): @@ -299,12 +306,12 @@ class SparkConnectTests(SparkConnectSQLTestCase): ): self.connect.createDataFrame(data, ["a", "b", "c", "d", "e"]) - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): self.connect.createDataFrame( data, "col1 magic_type, col2 int, col3 int, col4 int" ).show() - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): self.connect.createDataFrame(data, "col1 int, col2 int, col3 int").show() def test_with_atom_type(self): @@ -457,7 +464,7 @@ class SparkConnectTests(SparkConnectSQLTestCase): ] ) - with self.assertRaises(grpc.RpcError) as context: + with self.assertRaises(SparkConnectException) as context: self.connect.read.table(self.tbl_name).to(schema).toPandas() self.assertIn( """Column or field `name` is of type "STRING" while it's required to be "INT".""", @@ -679,7 +686,7 @@ class 
SparkConnectTests(SparkConnectSQLTestCase): # Test when creating a view which is already exists but self.assertTrue(self.spark.catalog.tableExists("global_temp.view_1")) - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): self.connect.sql("SELECT 1 AS X LIMIT 0").createGlobalTempView("view_1") def test_create_session_local_temp_view(self): @@ -691,7 +698,7 @@ class SparkConnectTests(SparkConnectSQLTestCase): self.assertEqual(self.connect.sql("SELECT * FROM view_local_temp").count(), 0) # Test when creating a view which is already exists but - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): self.connect.sql("SELECT 1 AS X LIMIT 0").createTempView("view_local_temp") def test_to_pandas(self): @@ -876,7 +883,7 @@ class SparkConnectTests(SparkConnectSQLTestCase): self.connect.sql(query).replace({None: 1}, subset="a").toPandas() self.assertTrue("Mixed type replacements are not supported" in str(context.exception)) - with self.assertRaises(grpc.RpcError) as context: + with self.assertRaises(SparkConnectException) as context: self.connect.sql(query).replace({1: 2, 3: -1}, subset=("a", "x")).toPandas() self.assertIn( """Cannot resolve column name "x" among (a, b, c)""", str(context.exception) @@ -957,7 +964,7 @@ class SparkConnectTests(SparkConnectSQLTestCase): ) # Hint with unsupported parameter values - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): self.connect.read.table(self.tbl_name).hint("REPARTITION", "id+1").toPandas() # Hint with unsupported parameter types @@ -965,7 +972,7 @@ class SparkConnectTests(SparkConnectSQLTestCase): self.connect.read.table(self.tbl_name).hint("REPARTITION", 1.1).toPandas() # Hint with wrong combination - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): self.connect.read.table(self.tbl_name).hint("REPARTITION", "id", 3).toPandas() def test_empty_dataset(self): @@ -1084,7 +1091,7 @@ 
class SparkConnectTests(SparkConnectSQLTestCase): ) self.assertEqual("name", col0) - with self.assertRaises(grpc.RpcError) as exc: + with self.assertRaises(SparkConnectException) as exc: self.connect.range(1, 10).select(col("id").alias("this", "is", "not")).collect() self.assertIn("(this, is, not)", str(exc.exception)) diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py index c9d770f1399..edf58947712 100644 --- a/python/pyspark/sql/tests/connect/test_connect_function.py +++ b/python/pyspark/sql/tests/connect/test_connect_function.py @@ -22,9 +22,9 @@ from pyspark.testing.pandasutils import PandasOnSparkTestCase from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.utils import ReusedPySparkTestCase from pyspark.testing.sqlutils import SQLTestUtils +from pyspark.sql.connect.client import SparkConnectException if should_test_connect: - import grpc from pyspark.sql.connect.session import SparkSession as RemoteSparkSession @@ -818,7 +818,7 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase): cdf.select(CF.rank().over(cdf.a)) # invalid window function - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): cdf.select(cdf.b.over(CW.orderBy("b"))).show() # invalid window frame @@ -832,34 +832,34 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase): CF.lead("c", 1), CF.ntile(1), ]: - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): cdf.select( ccol.over(CW.orderBy("b").rowsBetween(CW.currentRow, CW.currentRow + 123)) ).show() - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): cdf.select( ccol.over(CW.orderBy("b").rangeBetween(CW.currentRow, CW.currentRow + 123)) ).show() - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): cdf.select( 
ccol.over(CW.orderBy("b").rangeBetween(CW.unboundedPreceding, CW.currentRow)) ).show() # Function 'cume_dist' requires Windowframe(RangeFrame, UnboundedPreceding, CurrentRow) ccol = CF.cume_dist() - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): cdf.select( ccol.over(CW.orderBy("b").rangeBetween(CW.currentRow, CW.currentRow + 123)) ).show() - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): cdf.select( ccol.over(CW.orderBy("b").rowsBetween(CW.currentRow, CW.currentRow + 123)) ).show() - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): cdf.select( ccol.over(CW.orderBy("b").rowsBetween(CW.unboundedPreceding, CW.currentRow)) ).show() @@ -1964,11 +1964,11 @@ class SparkConnectFunctionTests(SparkConnectFuncTestCase): sdf = self.spark.sql(query) # test assert_true - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): cdf.select(CF.assert_true(cdf.a > 0, "a should be positive!")).show() # test raise_error - with self.assertRaises(grpc.RpcError): + with self.assertRaises(SparkConnectException): cdf.select(CF.raise_error("a should be positive!")).show() # test crc32 diff --git a/python/pyspark/testing/connectutils.py b/python/pyspark/testing/connectutils.py index dcbc09f2210..bec116b5f79 100644 --- a/python/pyspark/testing/connectutils.py +++ b/python/pyspark/testing/connectutils.py @@ -34,8 +34,29 @@ except ImportError as e: grpc_requirement_message = str(e) have_grpc = grpc_requirement_message is None + +grpc_status_requirement_message = None +try: + import grpc_status +except ImportError as e: + grpc_status_requirement_message = str(e) +have_grpc_status = grpc_status_requirement_message is None + +googleapis_common_protos_requirement_message = None +try: + from google.rpc import error_details_pb2 +except ImportError as e: + googleapis_common_protos_requirement_message = str(e) +have_googleapis_common_protos = 
googleapis_common_protos_requirement_message is None + connect_not_compiled_message = None -if have_pandas and have_pyarrow and have_grpc: +if ( + have_pandas + and have_pyarrow + and have_grpc + and have_grpc_status + and have_googleapis_common_protos +): from pyspark.sql.connect import DataFrame from pyspark.sql.connect.plan import Read, Range, SQL from pyspark.testing.utils import search_jar @@ -62,6 +83,8 @@ connect_requirement_message = ( or pyarrow_requirement_message or grpc_requirement_message or connect_not_compiled_message + or googleapis_common_protos_requirement_message + or grpc_status_requirement_message ) should_test_connect: str = typing.cast(str, connect_requirement_message is None) diff --git a/python/setup.py b/python/setup.py index 4ba2740246a..54115359a60 100755 --- a/python/setup.py +++ b/python/setup.py @@ -114,6 +114,7 @@ if (in_spark): _minimum_pandas_version = "1.0.5" _minimum_pyarrow_version = "1.0.0" _minimum_grpc_version = "1.48.1" +_minimum_googleapis_common_protos_version = "1.56.4" class InstallCommand(install): @@ -280,7 +281,9 @@ try: 'connect': [ 'pandas>=%s' % _minimum_pandas_version, 'pyarrow>=%s' % _minimum_pyarrow_version, - 'grpc>=%s' % _minimum_grpc_version, + 'grpcio>=%s' % _minimum_grpc_version, + 'grpcio-status>=%s' % _minimum_grpc_version, + 'googleapis-common-protos>=%s' % _minimum_googleapis_common_protos_version, 'numpy>=1.15', ], }, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org