This is an automated email from the ASF dual-hosted git repository. kabhwan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new f53dc08230b7 [SPARK-46962][SS][PYTHON] Add interface for python streaming data source API and implement python worker to run python streaming data source f53dc08230b7 is described below commit f53dc08230b7a758227f02fd75d2b446721c139f Author: Chaoqin Li <chaoqin...@databricks.com> AuthorDate: Tue Mar 12 12:35:18 2024 +0900 [SPARK-46962][SS][PYTHON] Add interface for python streaming data source API and implement python worker to run python streaming data source ### What changes were proposed in this pull request? This is the first PR to implement support for developing streaming data sources through the Python API. It implements a Python worker that runs the Python streaming data source and communicates with the JVM through a socket, and it creates a PythonMicroBatchStream that invokes RPC function calls. This happens in the Spark driver. A long-lived Python worker process is created for each Python streaming data source instance. Inside the Python process, the Python streaming reader receives function calls and parameters from the JVM PythonMicroBatchStream and sends the results back through the socket. ### Why are the changes needed? In preparation for supporting the development of streaming data sources in Python. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit tests. They verify that PythonMicroBatchStream correctly plans offsets and partitions by invoking function calls through the socket, and that it handles errors correctly. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #45023 from chaoqin-li1123/python_table. 
Lead-authored-by: Chaoqin Li <chaoqin...@databricks.com> Co-authored-by: Hyukjin Kwon <gurwls...@gmail.com> Co-authored-by: chaoqin-li1123 <55518381+chaoqin-li1...@users.noreply.github.com> Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com> --- .../src/main/resources/error/error-classes.json | 6 + docs/sql-error-conditions.md | 6 + python/pyspark/errors/error_classes.py | 5 + python/pyspark/sql/datasource.py | 153 ++++++++++++++ .../streaming/python_streaming_source_runner.py | 167 +++++++++++++++ .../spark/sql/errors/QueryExecutionErrors.scala | 9 + .../v2/python/PythonMicroBatchStream.scala | 68 ++++++ .../datasources/v2/python/PythonScan.scala | 32 ++- .../v2/python/UserDefinedPythonDataSource.scala | 2 +- .../python/PythonStreamingSourceRunner.scala | 202 ++++++++++++++++++ .../execution/python/PythonDataSourceSuite.scala | 20 +- .../python/PythonStreamingDataSourceSuite.scala | 233 +++++++++++++++++++++ 12 files changed, 883 insertions(+), 20 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index afe81b8e9bea..3d130fdce301 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -3197,6 +3197,12 @@ ], "sqlState" : "38000" }, + "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR" : { + "message" : [ + "Failed when Python streaming data source perform <action>: <msg>" + ], + "sqlState" : "38000" + }, "RECURSIVE_PROTOBUF_SCHEMA" : { "message" : [ "Found recursive reference in Protobuf schema, which can not be processed by Spark by default: <fieldDescriptor>. try setting the option `recursive.fields.max.depth` 0 to 10. Going beyond 10 levels of recursion is not allowed." 
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index 0695ed28b7fc..2cddb6a94c14 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -1931,6 +1931,12 @@ Protobuf type not yet supported: `<protobufType>`. Failed to `<action>` Python data source `<type>`: `<msg>` +### PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR + +[SQLSTATE: 38000](sql-error-conditions-sqlstates.html#class-38-external-routine-exception) + +Failed when Python streaming data source perform `<action>`: `<msg>` + ### RECURSIVE_PROTOBUF_SCHEMA [SQLSTATE: 42K0G](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index c9a7cfbf356e..1e21ad3543e9 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -812,6 +812,11 @@ ERROR_CLASSES_JSON = ''' "Randomness of hash of string should be disabled via PYTHONHASHSEED." ] }, + "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR": { + "message": [ + "Failed when running Python streaming data source: <msg>" + ] + }, "PYTHON_VERSION_MISMATCH": { "message": [ "Python in worker has different version: <worker_version> than that in driver: <driver_version>, PySpark cannot run with different minor versions.", diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py index 1e50f8270243..b98729f04332 100644 --- a/python/pyspark/sql/datasource.py +++ b/python/pyspark/sql/datasource.py @@ -29,6 +29,7 @@ if TYPE_CHECKING: __all__ = [ "DataSource", "DataSourceReader", + "DataSourceStreamReader", "DataSourceWriter", "DataSourceRegistration", "InputPartition", @@ -159,6 +160,27 @@ class DataSource(ABC): message_parameters={"feature": "writer"}, ) + def streamReader(self, schema: StructType) -> "DataSourceStreamReader": + """ + Returns a ``DataSourceStreamReader`` instance for reading streaming data. 
+ + The implementation is required for readable streaming data sources. + + Parameters + ---------- + schema : StructType + The schema of the data to be read. + + Returns + ------- + reader : DataSourceStreamReader + A reader instance for this streaming data source. + """ + raise PySparkNotImplementedError( + error_class="NOT_IMPLEMENTED", + message_parameters={"feature": "streamReader"}, + ) + class InputPartition: """ @@ -298,6 +320,137 @@ class DataSourceReader(ABC): ... +class DataSourceStreamReader(ABC): + """ + A base class for streaming data source readers. Data source stream readers are responsible + for outputting data from a streaming data source. + + .. versionadded: 4.0.0 + """ + + def initialOffset(self) -> dict: + """ + Return the initial offset of the streaming data source. + A new streaming query starts reading data from the initial offset. + If Spark is restarting an existing query, it will restart from the check-pointed offset + rather than the initial one. + + Returns + ------- + dict + A dict or recursive dict whose key and value are primitive types, which includes + Integer, String and Boolean. + + Examples + -------- + >>> def initialOffset(self): + ... return {"parititon-1": {"index": 3, "closed": True}, "partition-2": {"index": 5}} + """ + + ... + raise PySparkNotImplementedError( + error_class="NOT_IMPLEMENTED", + message_parameters={"feature": "initialOffset"}, + ) + + def latestOffset(self) -> dict: + """ + Returns the most recent offset available. + + Returns + ------- + dict + A dict or recursive dict whose key and value are primitive types, which includes + Integer, String and Boolean. + + Examples + -------- + >>> def latestOffset(self): + ... return {"parititon-1": {"index": 3, "closed": True}, "partition-2": {"index": 5}} + """ + ... 
+ raise PySparkNotImplementedError( + error_class="NOT_IMPLEMENTED", + message_parameters={"feature": "latestOffset"}, + ) + + def partitions(self, start: dict, end: dict) -> Sequence[InputPartition]: + """ + Returns a list of InputPartition given the start and end offsets. Each InputPartition + represents a data split that can be processed by one Spark task. + + Parameters + ---------- + start : dict + The start offset of the microbatch to plan partitioning. + end : dict + The end offset of the microbatch to plan partitioning. + + Returns + ------- + Sequence[InputPartition] + A sequence of partitions for this data source. Each partition value + must be an instance of `InputPartition` or a subclass of it. + """ + ... + raise PySparkNotImplementedError( + error_class="NOT_IMPLEMENTED", + message_parameters={"feature": "partitions"}, + ) + + @abstractmethod + def read(self, partition: InputPartition) -> Iterator[Union[Tuple, Row]]: + """ + Generates data for a given partition and returns an iterator of tuples or rows. + + This method is invoked once per partition to read the data. Implementing + this method is required for stream reader. You can initialize any + non-serializable resources required for reading data from the data source + within this method. + + Notes + ----- + This method is static and stateless. You shouldn't access mutable class member + or keep in memory state between different invocations of read(). + + Parameters + ---------- + partition : InputPartition + The partition to read. It must be one of the partition values returned by + ``partitions()``. + + Returns + ------- + Iterator[Tuple] or Iterator[Row] + An iterator of tuples or rows. Each tuple or row will be converted to a row + in the final DataFrame. + """ + ... 
+ raise PySparkNotImplementedError( + error_class="NOT_IMPLEMENTED", + message_parameters={"feature": "read"}, + ) + + def commit(self, end: dict) -> None: + """ + Informs the source that Spark has completed processing all data for offsets less than or + equal to `end` and will only request offsets greater than `end` in the future. + + Parameters + ---------- + end : dict + The latest offset that the streaming query has processed for this source. + """ + ... + + def stop(self) -> None: + """ + Stop this source and free any resources it has allocated. + Invoked when the streaming query terminated. + """ + ... + + class DataSourceWriter(ABC): """ A base class for data source writers. Data source writers are responsible for saving diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py b/python/pyspark/sql/streaming/python_streaming_source_runner.py new file mode 100644 index 000000000000..8dbac431a8ba --- /dev/null +++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py @@ -0,0 +1,167 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import sys +import json +from typing import IO + +from pyspark.accumulators import _accumulatorRegistry +from pyspark.errors import IllegalArgumentException, PySparkAssertionError, PySparkRuntimeError +from pyspark.java_gateway import local_connect_and_auth +from pyspark.serializers import ( + read_int, + write_int, + write_with_length, + SpecialLengths, +) +from pyspark.sql.datasource import DataSource, DataSourceStreamReader +from pyspark.sql.types import ( + _parse_datatype_json_string, + StructType, +) +from pyspark.util import handle_worker_exception +from pyspark.worker_util import ( + check_python_version, + read_command, + pickleSer, + send_accumulator_updates, + setup_memory_limits, + setup_spark_files, + utf8_deserializer, +) + +INITIAL_OFFSET_FUNC_ID = 884 +LATEST_OFFSET_FUNC_ID = 885 +PARTITIONS_FUNC_ID = 886 +COMMIT_FUNC_ID = 887 + + +def initial_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None: + offset = reader.initialOffset() + write_with_length(json.dumps(offset).encode("utf-8"), outfile) + + +def latest_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None: + offset = reader.latestOffset() + write_with_length(json.dumps(offset).encode("utf-8"), outfile) + + +def partitions_func(reader: DataSourceStreamReader, infile: IO, outfile: IO) -> None: + start_offset = json.loads(utf8_deserializer.loads(infile)) + end_offset = json.loads(utf8_deserializer.loads(infile)) + partitions = reader.partitions(start_offset, end_offset) + # Return the serialized partition values. 
+ write_int(len(partitions), outfile) + for partition in partitions: + pickleSer._write_with_length(partition, outfile) + + +def commit_func(reader: DataSourceStreamReader, infile: IO, outfile: IO) -> None: + end_offset = json.loads(utf8_deserializer.loads(infile)) + reader.commit(end_offset) + write_int(0, outfile) + + +def main(infile: IO, outfile: IO) -> None: + try: + check_python_version(infile) + setup_spark_files(infile) + + memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1")) + setup_memory_limits(memory_limit_mb) + + _accumulatorRegistry.clear() + + # Receive the data source instance. + data_source = read_command(pickleSer, infile) + + if not isinstance(data_source, DataSource): + raise PySparkAssertionError( + error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": "a Python data source instance of type 'DataSource'", + "actual": f"'{type(data_source).__name__}'", + }, + ) + + # Receive the data source output schema. + schema_json = utf8_deserializer.loads(infile) + schema = _parse_datatype_json_string(schema_json) + if not isinstance(schema, StructType): + raise PySparkAssertionError( + error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": "an output schema of type 'StructType'", + "actual": f"'{type(schema).__name__}'", + }, + ) + + # Instantiate data source reader. + try: + reader = data_source.streamReader(schema=schema) + # Initialization succeed. 
+ write_int(0, outfile) + outfile.flush() + + # handle method call from socket + while True: + func_id = read_int(infile) + if func_id == INITIAL_OFFSET_FUNC_ID: + initial_offset_func(reader, outfile) + elif func_id == LATEST_OFFSET_FUNC_ID: + latest_offset_func(reader, outfile) + elif func_id == PARTITIONS_FUNC_ID: + partitions_func(reader, infile, outfile) + elif func_id == COMMIT_FUNC_ID: + commit_func(reader, infile, outfile) + else: + raise IllegalArgumentException( + error_class="UNSUPPORTED_OPERATION", + message_parameters={ + "operation": "Function call id not recognized by stream reader" + }, + ) + outfile.flush() + except Exception as e: + error_msg = "data source {} throw exception: {}".format(data_source.name, e) + raise PySparkRuntimeError( + error_class="PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR", + message_parameters={"error": error_msg}, + ) + finally: + reader.stop() + except BaseException as e: + handle_worker_exception(e, outfile) + sys.exit(-1) + send_accumulator_updates(outfile) + + # check end of stream + if read_int(infile) == SpecialLengths.END_OF_STREAM: + write_int(SpecialLengths.END_OF_STREAM, outfile) + else: + # write a different value to tell JVM to not reuse this worker + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + sys.exit(-1) + + +if __name__ == "__main__": + # Read information about how to connect back to the JVM from the environment. 
+ java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] + (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + main(sock_file, sock_file) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 8280826fafcb..32482015b3ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1758,6 +1758,15 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "outputMode" -> outputMode.toString())) } + def pythonStreamingDataSourceRuntimeError( + action: String, + message: String): SparkException = { + new SparkException( + errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR", + messageParameters = Map("action" -> action, "msg" -> message), + cause = null) + } + def invalidCatalogNameError(name: String): Throwable = { new SparkException( errorClass = "_LEGACY_ERROR_TEMP_2212", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala new file mode 100644 index 000000000000..4e77f33c24f0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.v2.python + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory} +import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset} +import org.apache.spark.sql.execution.python.PythonStreamingSourceRunner +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +case class PythonStreamingSourceOffset(json: String) extends Offset + +case class PythonStreamingSourcePartition(partition: Array[Byte]) extends InputPartition + +class PythonMicroBatchStream( + ds: PythonDataSourceV2, + shortName: String, + outputSchema: StructType, + options: CaseInsensitiveStringMap + ) extends MicroBatchStream with Logging { + private def createDataSourceFunc = + ds.source.createPythonFunction( + ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)).dataSource) + + private val runner: PythonStreamingSourceRunner = + new PythonStreamingSourceRunner(createDataSourceFunc, outputSchema) + runner.init() + + override def initialOffset(): Offset = PythonStreamingSourceOffset(runner.initialOffset()) + + override def latestOffset(): Offset = PythonStreamingSourceOffset(runner.latestOffset()) + + override def planInputPartitions(start: Offset, end: Offset): Array[InputPartition] = { + runner.partitions(start.asInstanceOf[PythonStreamingSourceOffset].json, + end.asInstanceOf[PythonStreamingSourceOffset].json).map(PythonStreamingSourcePartition(_)) + } + + override def 
createReaderFactory(): PartitionReaderFactory = { + // TODO(SPARK-47107): fill in the implementation. + null + } + + override def commit(end: Offset): Unit = { + runner.commit(end.asInstanceOf[PythonStreamingSourceOffset].json) + } + + override def stop(): Unit = { + runner.stop() + } + + override def deserializeOffset(json: String): Offset = PythonStreamingSourceOffset(json) +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala index 75cbe38b1397..bcddf66fc161 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2.python import org.apache.spark.JobArtifactSet import org.apache.spark.sql.connector.metric.CustomMetric import org.apache.spark.sql.connector.read._ +import org.apache.spark.sql.connector.read.streaming.MicroBatchStream import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -27,9 +28,27 @@ class PythonScan( ds: PythonDataSourceV2, shortName: String, outputSchema: StructType, - options: CaseInsensitiveStringMap) extends Batch with Scan { + options: CaseInsensitiveStringMap) extends Scan { - private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + override def toBatch: Batch = new PythonBatch(ds, shortName, outputSchema, options) + + override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = + new PythonMicroBatchStream(ds, shortName, outputSchema, options) + + override def description: String = "(Python)" + + override def readSchema(): StructType = outputSchema + + override def supportedCustomMetrics(): Array[CustomMetric] = + ds.source.createPythonMetrics() +} + +class 
PythonBatch( + ds: PythonDataSourceV2, + shortName: String, + outputSchema: StructType, + options: CaseInsensitiveStringMap) extends Batch { + private val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) private lazy val infoInPython: PythonDataSourceReadInfo = { ds.source.createReadInfoInPython( @@ -45,13 +64,4 @@ class PythonScan( new PythonPartitionReaderFactory( ds.source, readerFunc, outputSchema, jobArtifactUUID) } - - override def toBatch: Batch = this - - override def description: String = "(Python)" - - override def readSchema(): StructType = outputSchema - - override def supportedCustomMetrics(): Array[CustomMetric] = - ds.source.createPythonMetrics() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala index 0e5f359ee76f..9b8219c4dc2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala @@ -150,7 +150,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { taskMetrics.map { case (k, v) => new PythonCustomTaskMetric(k, v)}.toArray } - private def createPythonFunction(pickledFunc: Array[Byte]): PythonFunction = { + def createPythonFunction(pickledFunc: Array[Byte]): PythonFunction = { SimplePythonFunction( command = pickledFunc.toImmutableArraySeq, envVars = dataSourceCls.envVars, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala new file mode 100644 index 000000000000..4862f8220f06 --- /dev/null +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonStreamingSourceRunner.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.spark.sql.execution.python + +import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream} + +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ + +import org.apache.spark.SparkEnv +import org.apache.spark.api.python.{PythonFunction, PythonWorker, PythonWorkerFactory, PythonWorkerUtils, SpecialLengths} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.BUFFER_SIZE +import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.types.StructType + +object PythonStreamingSourceRunner { + // When the python process for python_streaming_source_runner receives one of the + // integers below, it will invoke the corresponding function of StreamReader instance. 
+ val INITIAL_OFFSET_FUNC_ID = 884 + val LATEST_OFFSET_FUNC_ID = 885 + val PARTITIONS_FUNC_ID = 886 + val COMMIT_FUNC_ID = 887 +} + +/** + * This class is a proxy to invoke methods in Python DataSourceStreamReader from JVM. + * A runner spawns a python worker process. In the main function, set up communication + * between JVM and python process through socket and create a DataSourceStreamReader instance. + * In an infinite loop, the python worker process poll information(function name and parameters) + * from the socket, invoke the corresponding method of StreamReader and send return value to JVM. + */ +class PythonStreamingSourceRunner( + func: PythonFunction, + outputSchema: StructType) extends Logging { + val workerModule = "pyspark.sql.streaming.python_streaming_source_runner" + + private val conf = SparkEnv.get.conf + private val bufferSize: Int = conf.get(BUFFER_SIZE) + private val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) + + private val envVars: java.util.Map[String, String] = func.envVars + private val pythonExec: String = func.pythonExec + private var pythonWorker: Option[PythonWorker] = None + private var pythonWorkerFactory: Option[PythonWorkerFactory] = None + private val pythonVer: String = func.pythonVer + + private var dataOut: DataOutputStream = null + private var dataIn: DataInputStream = null + + import PythonStreamingSourceRunner._ + + /** + * Initializes the Python worker for running the streaming source. 
+ */ + def init(): Unit = { + logInfo(s"Initializing Python runner pythonExec: $pythonExec") + val env = SparkEnv.get + + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") + envVars.put("SPARK_LOCAL_DIRS", localdir) + + envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) + envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) + + val workerFactory = + new PythonWorkerFactory(pythonExec, workerModule, envVars.asScala.toMap) + val (worker: PythonWorker, _) = workerFactory.createSimpleWorker(blockingMode = true) + pythonWorker = Some(worker) + pythonWorkerFactory = Some(workerFactory) + + val stream = new BufferedOutputStream( + pythonWorker.get.channel.socket().getOutputStream, bufferSize) + dataOut = new DataOutputStream(stream) + + PythonWorkerUtils.writePythonVersion(pythonVer, dataOut) + + val pythonIncludes = func.pythonIncludes.asScala.toSet + PythonWorkerUtils.writeSparkFiles(Some("streaming_job"), pythonIncludes, dataOut) + + // Send the user function to python process + PythonWorkerUtils.writePythonFunction(func, dataOut) + + // Send output schema + PythonWorkerUtils.writeUTF(outputSchema.json, dataOut) + + dataOut.flush() + + dataIn = new DataInputStream( + new BufferedInputStream(pythonWorker.get.channel.socket().getInputStream, bufferSize)) + + val initStatus = dataIn.readInt() + if (initStatus == SpecialLengths.PYTHON_EXCEPTION_THROWN) { + val msg = PythonWorkerUtils.readUTF(dataIn) + throw QueryCompilationErrors.pythonDataSourceError( + action = "plan", tpe = "initialize source", msg = msg) + } + } + + /** + * Invokes latestOffset() function of the stream reader and receive the return value. 
+ */ + def latestOffset(): String = { + dataOut.writeInt(LATEST_OFFSET_FUNC_ID) + dataOut.flush() + val len = dataIn.readInt() + if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) { + val msg = PythonWorkerUtils.readUTF(dataIn) + throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError( + action = "latestOffset", msg) + } + PythonWorkerUtils.readUTF(len, dataIn) + } + + /** + * Invokes initialOffset() function of the stream reader and receive the return value. + */ + def initialOffset(): String = { + dataOut.writeInt(INITIAL_OFFSET_FUNC_ID) + dataOut.flush() + val len = dataIn.readInt() + if (len == SpecialLengths.PYTHON_EXCEPTION_THROWN) { + val msg = PythonWorkerUtils.readUTF(dataIn) + throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError( + action = "initialOffset", msg) + } + PythonWorkerUtils.readUTF(len, dataIn) + } + + /** + * Invokes partitions(start, end) function of the stream reader and receive the return value. + */ + def partitions(start: String, end: String): Array[Array[Byte]] = { + dataOut.writeInt(PARTITIONS_FUNC_ID) + PythonWorkerUtils.writeUTF(start, dataOut) + PythonWorkerUtils.writeUTF(end, dataOut) + dataOut.flush() + // Receive the list of partitions, if any. + val pickledPartitions = ArrayBuffer.empty[Array[Byte]] + val numPartitions = dataIn.readInt() + if (numPartitions == SpecialLengths.PYTHON_EXCEPTION_THROWN) { + val msg = PythonWorkerUtils.readUTF(dataIn) + throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError( + action = "planPartitions", msg) + } + for (_ <- 0 until numPartitions) { + val pickledPartition: Array[Byte] = PythonWorkerUtils.readBytes(dataIn) + pickledPartitions.append(pickledPartition) + } + pickledPartitions.toArray + } + + /** + * Invokes commit(end) function of the stream reader and receive the return value. 
+ */ + def commit(end: String): Unit = { + dataOut.writeInt(COMMIT_FUNC_ID) + PythonWorkerUtils.writeUTF(end, dataOut) + dataOut.flush() + val status = dataIn.readInt() + if (status == SpecialLengths.PYTHON_EXCEPTION_THROWN) { + val msg = PythonWorkerUtils.readUTF(dataIn) + throw QueryExecutionErrors.pythonStreamingDataSourceRuntimeError( + action = "commitSource", msg) + } + } + + /** + * Stop the python worker process and invoke stop() on stream reader. + */ + def stop(): Unit = { + logInfo(s"Stopping streaming runner for module: $workerModule.") + try { + pythonWorkerFactory.foreach { factory => + pythonWorker.foreach { worker => + factory.stopWorker(worker) + factory.stop() + } + } + } catch { + case e: Exception => + logError("Exception when trying to kill worker", e) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index c2f5864b7345..b207afeae106 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -28,13 +28,9 @@ import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -class PythonDataSourceSuite extends QueryTest with SharedSparkSession { - import IntegratedUDFTestUtils._ - - setupTestData() +abstract class PythonDataSourceSuiteBase extends QueryTest with SharedSparkSession { - private def dataSourceName = "SimpleDataSource" - private val simpleDataSourceReaderScript: String = + protected val simpleDataSourceReaderScript: String = """ |from pyspark.sql.datasource import DataSourceReader, InputPartition |class SimpleDataSourceReader(DataSourceReader): @@ -45,8 +41,8 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { | yield (1, partition.value) | yield (2, 
partition.value) |""".stripMargin - private val staticSourceName = "custom_source" - private var tempDir: File = _ + protected val staticSourceName = "custom_source" + protected var tempDir: File = _ override def beforeAll(): Unit = { // Create a Python Data Source package before starting up the Spark Session @@ -90,6 +86,14 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { } } + setupTestData() + + protected def dataSourceName = "SimpleDataSource" +} + +class PythonDataSourceSuite extends PythonDataSourceSuiteBase { + import IntegratedUDFTestUtils._ + test("SPARK-45917: automatic registration of Python Data Source") { assume(shouldTestPandasUDFs) val df = spark.read.format(staticSourceName).load() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala new file mode 100644 index 000000000000..f022e353edd7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonStreamingDataSourceSuite.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.python + +import org.apache.spark.SparkException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.IntegratedUDFTestUtils.{createUserDefinedPythonDataSource, shouldTestPandasUDFs} +import org.apache.spark.sql.execution.datasources.v2.python.{PythonDataSourceV2, PythonMicroBatchStream, PythonStreamingSourceOffset} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class PythonStreamingDataSourceSuite extends PythonDataSourceSuiteBase { + + protected def simpleDataStreamReaderScript: String = + """ + |from pyspark.sql.datasource import DataSourceStreamReader, InputPartition + | + |class SimpleDataStreamReader(DataSourceStreamReader): + | def initialOffset(self): + | return {"offset": {"partition-1": 0}} + | def latestOffset(self): + | return {"offset": {"partition-1": 2}} + | def partitions(self, start: dict, end: dict): + | start_index = start["offset"]["partition-1"] + | end_index = end["offset"]["partition-1"] + | return [InputPartition(i) for i in range(start_index, end_index)] + | def commit(self, end: dict): + | 1 + 2 + | def read(self, partition): + | yield (0, partition.value) + | yield (1, partition.value) + | yield (2, partition.value) + |""".stripMargin + + protected def errorDataStreamReaderScript: String = + """ + |from pyspark.sql.datasource import DataSourceStreamReader, InputPartition + | + |class ErrorDataStreamReader(DataSourceStreamReader): + | def initialOffset(self): + | raise Exception("error reading initial offset") + | def latestOffset(self): + | raise Exception("error reading latest offset") + | def partitions(self, start: dict, end: dict): + | raise Exception("error planning partitions") + | def commit(self, end: dict): + | raise Exception("error committing offset") + | def read(self, partition): + | yield (0, partition.value) + | yield (1, partition.value) + | yield (2, partition.value) + |""".stripMargin + + 
private val errorDataSourceName = "ErrorDataSource" + + test("simple data stream source") { + assume(shouldTestPandasUDFs) + val dataSourceScript = + s""" + |from pyspark.sql.datasource import DataSource + |$simpleDataStreamReaderScript + | + |class $dataSourceName(DataSource): + | def streamReader(self, schema): + | return SimpleDataStreamReader() + |""".stripMargin + val inputSchema = StructType.fromDDL("input BINARY") + + val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) + spark.dataSource.registerPython(dataSourceName, dataSource) + val pythonDs = new PythonDataSourceV2 + pythonDs.setShortName("SimpleDataSource") + val stream = new PythonMicroBatchStream( + pythonDs, dataSourceName, inputSchema, CaseInsensitiveStringMap.empty()) + + val initialOffset = stream.initialOffset() + assert(initialOffset.json == "{\"offset\": {\"partition-1\": 0}}") + for (_ <- 1 to 50) { + val offset = stream.latestOffset() + assert(offset.json == "{\"offset\": {\"partition-1\": 2}}") + assert(stream.planInputPartitions(initialOffset, offset).size == 2) + stream.commit(offset) + } + stream.stop() + } + + test("Error creating stream reader") { + assume(shouldTestPandasUDFs) + val dataSourceScript = + s""" + |from pyspark.sql.datasource import DataSource + |class $dataSourceName(DataSource): + | def streamReader(self, schema): + | raise Exception("error creating stream reader") + |""".stripMargin + val dataSource = createUserDefinedPythonDataSource( + name = dataSourceName, pythonScript = dataSourceScript) + spark.dataSource.registerPython(dataSourceName, dataSource) + val pythonDs = new PythonDataSourceV2 + pythonDs.setShortName("SimpleDataSource") + val inputSchema = StructType.fromDDL("input BINARY") + val err = intercept[AnalysisException] { + new PythonMicroBatchStream( + pythonDs, dataSourceName, inputSchema, CaseInsensitiveStringMap.empty()) + } + assert(err.getErrorClass == "PYTHON_DATA_SOURCE_ERROR") + assert(err.getMessage.contains("error 
creating stream reader")) + } + + test("Method not implemented in stream reader") { + assume(shouldTestPandasUDFs) + val dataSourceScript = + s""" + |from pyspark.sql.datasource import DataSource + |from pyspark.sql.datasource import DataSourceStreamReader + |class ErrorDataStreamReader(DataSourceStreamReader): + | def read(self, partition): + | yield (0, partition.value) + | + |class $errorDataSourceName(DataSource): + | def streamReader(self, schema): + | return ErrorDataStreamReader() + |""".stripMargin + val inputSchema = StructType.fromDDL("input BINARY") + + val dataSource = createUserDefinedPythonDataSource(errorDataSourceName, dataSourceScript) + spark.dataSource.registerPython(errorDataSourceName, dataSource) + val pythonDs = new PythonDataSourceV2 + pythonDs.setShortName("ErrorDataSource") + + def testMicroBatchStreamError(action: String, msg: String) + (func: PythonMicroBatchStream => Unit): Unit = { + val stream = new PythonMicroBatchStream( + pythonDs, errorDataSourceName, inputSchema, CaseInsensitiveStringMap.empty()) + val err = intercept[SparkException] { + func(stream) + } + checkErrorMatchPVals(err, + errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR", + parameters = Map( + "action" -> action, + "msg" -> "(.|\\n)*" + )) + assert(err.getMessage.contains(msg)) + assert(err.getMessage.contains("ErrorDataSource")) + stream.stop() + } + + testMicroBatchStreamError( + "initialOffset", "[NOT_IMPLEMENTED] initialOffset is not implemented") { + stream => stream.initialOffset() + } + + testMicroBatchStreamError( + "latestOffset", "[NOT_IMPLEMENTED] latestOffset is not implemented") { + stream => stream.latestOffset() + } + + val offset = PythonStreamingSourceOffset("{\"offset\": \"2\"}") + testMicroBatchStreamError( + "planPartitions", "[NOT_IMPLEMENTED] partitions is not implemented") { + stream => stream.planInputPartitions(offset, offset) + } + } + + test("Error in stream reader") { + assume(shouldTestPandasUDFs) + val dataSourceScript = + s""" + 
|from pyspark.sql.datasource import DataSource + |$errorDataStreamReaderScript + | + |class $errorDataSourceName(DataSource): + | def streamReader(self, schema): + | return ErrorDataStreamReader() + |""".stripMargin + val inputSchema = StructType.fromDDL("input BINARY") + + val dataSource = createUserDefinedPythonDataSource(errorDataSourceName, dataSourceScript) + spark.dataSource.registerPython(errorDataSourceName, dataSource) + val pythonDs = new PythonDataSourceV2 + pythonDs.setShortName("ErrorDataSource") + val offset = PythonStreamingSourceOffset("{\"offset\": \"2\"}") + + def testMicroBatchStreamError(action: String, msg: String) + (func: PythonMicroBatchStream => Unit): Unit = { + val stream = new PythonMicroBatchStream( + pythonDs, errorDataSourceName, inputSchema, CaseInsensitiveStringMap.empty()) + val err = intercept[SparkException] { + func(stream) + } + checkErrorMatchPVals(err, + errorClass = "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR", + parameters = Map( + "action" -> action, + "msg" -> "(.|\\n)*" + )) + assert(err.getMessage.contains(msg)) + assert(err.getMessage.contains("ErrorDataSource")) + stream.stop() + } + + testMicroBatchStreamError("initialOffset", "error reading initial offset") { + stream => stream.initialOffset() + } + + testMicroBatchStreamError("latestOffset", "error reading latest offset") { + stream => stream.latestOffset() + } + + testMicroBatchStreamError("planPartitions", "error planning partitions") { + stream => stream.planInputPartitions(offset, offset) + } + + testMicroBatchStreamError("commitSource", "error committing offset") { + stream => stream.commit(offset) + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org