This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 4fcd5bfe003 [SPARK-45525][SQL][PYTHON] Support for Python data source write using DSv2
4fcd5bfe003 is described below

commit 4fcd5bfe003bb546ca888efaf1d39c15c9685673
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Fri Dec 22 09:28:47 2023 +0800

[SPARK-45525][SQL][PYTHON] Support for Python data source write using DSv2

### What changes were proposed in this pull request?

This PR adds initial support for Python data source write by implementing the DSv2 `SupportsWrite` interface for `PythonTableProvider`. Note that this PR only supports the `def write(self, iterator)` API; `commit` and `abort` will be supported in [SPARK-45914](https://issues.apache.org/jira/browse/SPARK-45914).

### Why are the changes needed?

To support Python data source APIs. For instance:

```python
class SimpleWriter(DataSourceWriter):
    def write(self, iterator: Iterator[Row]) -> WriterCommitMessage:
        for row in iterator:
            print(row)
        return WriterCommitMessage()

class SimpleDataSource(DataSource):
    def writer(self, schema, overwrite):
        return SimpleWriter()

# Register the Python data source
spark.dataSource.register(SimpleDataSource)

spark.range(10).write.format("SimpleDataSource").mode("append").save()
```

### Does this PR introduce _any_ user-facing change?

Yes, this PR supports writing data into a Python data source.

### How was this patch tested?

New unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #43791 from allisonwang-db/spark-45525-data-source-write.

Authored-by: allisonwang-db <allison.w...@databricks.com>
Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../src/main/resources/error/error-classes.json | 6 +
 docs/sql-error-conditions.md | 6 +
 python/pyspark/errors/error_classes.py | 5 +
 python/pyspark/sql/tests/test_python_datasource.py | 36 ++-
 .../pyspark/sql/worker/write_into_data_source.py | 233 ++++++++++++++++++
 .../spark/sql/errors/QueryExecutionErrors.scala | 6 +
 .../python/UserDefinedPythonDataSource.scala | 269 +++++++++++++++++----
 .../execution/python/PythonDataSourceSuite.scala | 95 ++++++++
 8 files changed, 612 insertions(+), 44 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index df223f3298e..8970045d4ab 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -2513,6 +2513,12 @@
     ],
     "sqlState" : "42601"
   },
+  "INVALID_WRITER_COMMIT_MESSAGE" : {
+    "message" : [
+      "The data source writer has generated an invalid number of commit messages. Expected exactly one writer commit message from each task, but received <detail>."
+    ],
+    "sqlState" : "42KDE"
+  },
   "INVALID_WRITE_DISTRIBUTION" : {
     "message" : [
       "The requested write distribution is invalid."
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index a1af6863913..0722cae5815 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -1398,6 +1398,12 @@ Rewrite the query to avoid window functions, aggregate functions, and generator
 
 Cannot specify ORDER BY or a window frame for `<aggFunc>`.
 
+### INVALID_WRITER_COMMIT_MESSAGE
+
+[SQLSTATE: 42KDE](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+The data source writer has generated an invalid number of commit messages. Expected exactly one writer commit message from each task, but received `<detail>`.
+ ### [INVALID_WRITE_DISTRIBUTION](sql-error-conditions-invalid-write-distribution-error-class.html) [SQLSTATE: 42000](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index bb278481262..2200b73dffc 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -772,6 +772,11 @@ ERROR_CLASSES_JSON = """ "Expected <expected>, but got <actual>." ] }, + "PYTHON_DATA_SOURCE_WRITE_ERROR" : { + "message" : [ + "Unable to write to the Python data source: <error>." + ] + }, "PYTHON_HASH_SEED_NOT_SET" : { "message" : [ "Randomness of hash of string should be disabled via PYTHONHASHSEED." diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py index 74ef6a87458..b1bba584d85 100644 --- a/python/pyspark/sql/tests/test_python_datasource.py +++ b/python/pyspark/sql/tests/test_python_datasource.py @@ -15,11 +15,18 @@ # limitations under the License. # import os +import tempfile import unittest from typing import Callable, Union from pyspark.errors import PythonException -from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition +from pyspark.sql.datasource import ( + DataSource, + DataSourceReader, + InputPartition, + DataSourceWriter, + WriterCommitMessage, +) from pyspark.sql.types import Row, StructType from pyspark.testing import assertDataFrameEqual from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -235,6 +242,17 @@ class BasePythonDataSourceTestsMixin: data = json.loads(line) yield data.get("name"), data.get("age") + class JsonDataSourceWriter(DataSourceWriter): + def __init__(self, options): + self.options = options + + def write(self, iterator): + path = self.options.get("path") + with open(path, "w") as file: + for row in iterator: + file.write(json.dumps(row.asDict()) + "\n") + return WriterCommitMessage() + class JsonDataSource(DataSource): @classmethod def name(cls): @@ -246,7 +264,11 @@ class BasePythonDataSourceTestsMixin: def reader(self, schema) -> "DataSourceReader": return JsonDataSourceReader(self.options) + def writer(self, schema, overwrite): + return JsonDataSourceWriter(self.options) + self.spark.dataSource.register(JsonDataSource) + # Test data source read. path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json") path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json") assertDataFrameEqual( @@ -257,6 +279,18 @@ class BasePythonDataSourceTestsMixin: self.spark.read.format("my-json").load(path2), [Row(name="Jonathan", age=None)], ) + # Test data source write. + df = self.spark.read.json(path1) + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "res.json") + df.write.format("my-json").mode("append").save(path) + with open(path, "r") as file: + text = file.read() + assert text == ( + '{"age": null, "name": "Michael"}\n' + '{"age": 30, "name": "Andy"}\n' + '{"age": 19, "name": "Justin"}\n' + ) class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase): diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py new file mode 100644 index 00000000000..9c311dad033 --- /dev/null +++ b/python/pyspark/sql/worker/write_into_data_source.py @@ -0,0 +1,233 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +import os +import sys +from typing import IO, Iterable, Iterator + +from pyspark.accumulators import _accumulatorRegistry +from pyspark.sql.connect.conversion import ArrowTableToRowsConversion +from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError +from pyspark.java_gateway import local_connect_and_auth +from pyspark.serializers import ( + read_int, + write_int, + SpecialLengths, +) +from pyspark.sql import Row +from pyspark.sql.datasource import DataSource, WriterCommitMessage +from pyspark.sql.types import ( + _parse_datatype_json_string, + StructType, + BinaryType, + _create_row, +) +from pyspark.util import handle_worker_exception +from pyspark.worker_util import ( + check_python_version, + read_command, + pickleSer, + send_accumulator_updates, + setup_broadcasts, + setup_memory_limits, + setup_spark_files, + utf8_deserializer, +) + + +def main(infile: IO, outfile: IO) -> None: + """ + Main method for saving into a Python data source. + + This process is invoked from the `SaveIntoPythonDataSourceRunner.runInPython` method + in the optimizer rule `PythonDataSourceWrites` in JVM. This process is responsible for + creating a `DataSource` object and a DataSourceWriter instance, and send information + needed back to the JVM. + + The JVM sends the following information to this process: + - a `DataSource` class representing the data source to be created. + - a provider name in string. + - a schema in json string. + - a dictionary of options in string. + + This process first creates a `DataSource` instance and then a `DataSourceWriter` + instance and send a function using the writer instance that can be used + in mapInPandas/mapInArrow back to the JVM. + """ + try: + check_python_version(infile) + + memory_limit_mb = int(os.environ.get("PYSPARK_PLANNER_MEMORY_MB", "-1")) + setup_memory_limits(memory_limit_mb) + + setup_spark_files(infile) + setup_broadcasts(infile) + + _accumulatorRegistry.clear() + + # Receive the data source class. + data_source_cls = read_command(pickleSer, infile) + if not (isinstance(data_source_cls, type) and issubclass(data_source_cls, DataSource)): + raise PySparkAssertionError( + error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": "a subclass of DataSource", + "actual": f"'{type(data_source_cls).__name__}'", + }, + ) + + # Check the name method is a class method. + if not inspect.ismethod(data_source_cls.name): + raise PySparkTypeError( + error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": "'name()' method to be a classmethod", + "actual": f"'{type(data_source_cls.name).__name__}'", + }, + ) + + # Receive the provider name. + provider = utf8_deserializer.loads(infile) + + # Check if the provider name matches the data source's name. 
+ if provider.lower() != data_source_cls.name().lower(): + raise PySparkAssertionError( + error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": f"provider with name {data_source_cls.name()}", + "actual": f"'{provider}'", + }, + ) + + # Receive the input schema + schema = _parse_datatype_json_string(utf8_deserializer.loads(infile)) + if not isinstance(schema, StructType): + raise PySparkAssertionError( + error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": "the schema to be a 'StructType'", + "actual": f"'{type(data_source_cls).__name__}'", + }, + ) + + # Receive the return type + return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile)) + if not isinstance(return_type, StructType): + raise PySparkAssertionError( + error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH", + message_parameters={ + "expected": "a return type of type 'StructType'", + "actual": f"'{type(return_type).__name__}'", + }, + ) + assert len(return_type) == 1 and isinstance(return_type[0].dataType, BinaryType), ( + "The output schema of Python data source write should contain only one column of type " + f"'BinaryType', but got '{return_type}'" + ) + return_col_name = return_type[0].name + + # Receive the options. + options = dict() + num_options = read_int(infile) + for _ in range(num_options): + key = utf8_deserializer.loads(infile) + value = utf8_deserializer.loads(infile) + options[key] = value + + # Receive the save mode. + save_mode = utf8_deserializer.loads(infile) + + # Instantiate a data source. + try: + data_source = data_source_cls(options=options) + except Exception as e: + raise PySparkRuntimeError( + error_class="PYTHON_DATA_SOURCE_CREATE_ERROR", + message_parameters={"type": "instance", "error": str(e)}, + ) + + # Instantiate the data source writer. + try: + writer = data_source.writer(schema, save_mode) + except Exception as e: + raise PySparkRuntimeError( + error_class="PYTHON_DATA_SOURCE_CREATE_ERROR", + message_parameters={"type": "writer", "error": str(e)}, + ) + + # Create a function that can be used in mapInArrow. + import pyarrow as pa + + converters = [ + ArrowTableToRowsConversion._create_converter(f.dataType) for f in schema.fields + ] + fields = schema.fieldNames() + + def data_source_write_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]: + def batch_to_rows() -> Iterator[Row]: + for batch in iterator: + columns = [column.to_pylist() for column in batch.columns] + for row in range(0, batch.num_rows): + values = [ + converters[col](columns[col][row]) for col in range(batch.num_columns) + ] + yield _create_row(fields=fields, values=values) + + res = writer.write(batch_to_rows()) + + # Check the commit message has the right type. + if not isinstance(res, WriterCommitMessage): + raise PySparkRuntimeError( + error_class="PYTHON_DATA_SOURCE_WRITE_ERROR", + message_parameters={ + "error": f"return type of the `write` method must be " + f"an instance of WriterCommitMessage, but got {type(res)}" + }, + ) + + # Serialize the commit message and return it. + pickled = pickleSer.dumps(res) + + # Return the commit message. + messages = pa.array([pickled]) + yield pa.record_batch([messages], names=[return_col_name]) + + # Return the pickled write UDF. 
+ command = (data_source_write_func, return_type) + pickleSer._write_with_length(command, outfile) + + except BaseException as e: + handle_worker_exception(e, outfile) + sys.exit(-1) + + send_accumulator_updates(outfile) + + # check end of stream + if read_int(infile) == SpecialLengths.END_OF_STREAM: + write_int(SpecialLengths.END_OF_STREAM, outfile) + else: + # write a different value to tell JVM to not reuse this worker + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + sys.exit(-1) + + +if __name__ == "__main__": + # Read information about how to connect back to the JVM from the environment. + java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] + (sock_file, _) = local_connect_and_auth(java_port, auth_secret) + main(sock_file, sock_file) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 113f995968a..b0eaf84fe6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -2771,4 +2771,10 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "parameter" -> toSQLId("charset"), "charset" -> charset)) } + + def invalidWriterCommitMessageError(details: String): Throwable = { + new SparkRuntimeException( + errorClass = "INVALID_WRITER_COMMIT_MESSAGE", + messageParameters = Map("details" -> details)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala index d31b3135d65..00974a7e297 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala @@ -24,19 +24,20 @@ import scala.jdk.CollectionConverters._ import net.razorvine.pickle.Pickler -import org.apache.spark.JobArtifactSet +import org.apache.spark.{JobArtifactSet, SparkException} import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonFunction, PythonWorkerUtils, SimplePythonFunction, SpecialLengths} -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{SaveMode, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.PythonUDF import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} -import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ +import org.apache.spark.sql.connector.catalog.{SupportsRead, SupportsWrite, Table, TableCapability, TableProvider} +import org.apache.spark.sql.connector.catalog.TableCapability.{BATCH_READ, BATCH_WRITE} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric} import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} -import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, LogicalWriteInfo, PhysicalWriteInfo, Write, 
WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BinaryType, DataType, StructType} @@ -72,11 +73,11 @@ class PythonTableProvider extends TableProvider { partitioning: Array[Transform], properties: java.util.Map[String, String]): Table = { val outputSchema = schema - new Table with SupportsRead { + new Table with SupportsRead with SupportsWrite { override def name(): String = shortName override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of( - BATCH_READ) + BATCH_READ, BATCH_WRITE) override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new ScanBuilder with Batch with Scan { @@ -103,6 +104,7 @@ class PythonTableProvider extends TableProvider { new PythonPartitionReaderFactory( source, readerFunc, outputSchema, jobArtifactUUID) } + override def description: String = "(Python)" override def supportedCustomMetrics(): Array[CustomMetric] = @@ -111,6 +113,38 @@ class PythonTableProvider extends TableProvider { } override def schema(): StructType = outputSchema + + override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = { + new WriteBuilder { + override def build(): Write = new Write { + + override def toBatch: BatchWrite = new BatchWrite { + + override def createBatchWriterFactory( + physicalInfo: PhysicalWriteInfo): DataWriterFactory = { + + val writeInfo = source.createWriteInfoInPython( + shortName, + info.schema(), + info.options(), + SaveMode.Append) + PythonBatchWriterFactory(source, writeInfo.func, info.schema(), jobArtifactUUID) + } + + // TODO(SPARK-45914): Support commit protocol + override def commit(messages: Array[WriterCommitMessage]): Unit = {} + + // TODO(SPARK-45914): Support commit protocol + override def abort(messages: Array[WriterCommitMessage]): Unit = {} + } + + override def description: String = "(Python)" + + override def supportedCustomMetrics(): Array[CustomMetric] = + source.createPythonMetrics() + } + } + } } } @@ -124,27 +158,26 @@ class PythonPartitionReaderFactory( pickledReadFunc: Array[Byte], outputSchema: StructType, jobArtifactUUID: Option[String]) - extends PartitionReaderFactory { + extends PartitionReaderFactory with PythonDataSourceSQLMetrics { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { new PartitionReader[InternalRow] { - // Dummy SQLMetrics. The result is manually reported via DSv2 interface - // via passing the value to `CustomTaskMetric`. Note that `pythonOtherMetricsDesc` - // is not used when it is reported. It is to reuse existing Python runner. - // See also `UserDefinedPythonDataSource.createPythonMetrics`. 
- private[this] val metrics: Map[String, SQLMetric] = { - PythonSQLMetrics.pythonSizeMetricsDesc.keys - .map(_ -> new SQLMetric("size", -1)).toMap ++ - PythonSQLMetrics.pythonOtherMetricsDesc.keys - .map(_ -> new SQLMetric("sum", -1)).toMap - } - private val outputIter = source.createPartitionReadIteratorInPython( - partition.asInstanceOf[PythonInputPartition], - pickledReadFunc, - outputSchema, - metrics, - jobArtifactUUID) + private[this] val metrics: Map[String, SQLMetric] = pythonMetrics + + private val outputIter = { + val evaluatorFactory = source.createMapInBatchEvaluatorFactory( + pickledReadFunc, + "read_from_data_source", + UserDefinedPythonDataSource.readInputSchema, + outputSchema, + metrics, + jobArtifactUUID) + + val part = partition.asInstanceOf[PythonInputPartition] + evaluatorFactory.createEvaluator().eval( + part.index, Iterator.single(InternalRow(part.pickedPartition))) + } override def next(): Boolean = outputIter.hasNext @@ -159,9 +192,75 @@ class PythonPartitionReaderFactory( } } +case class PythonWriterCommitMessage(pickledMessage: Array[Byte]) extends WriterCommitMessage + +private case class PythonBatchWriterFactory( + source: UserDefinedPythonDataSource, + pickledWriteFunc: Array[Byte], + inputSchema: StructType, + jobArtifactUUID: Option[String]) extends DataWriterFactory with PythonDataSourceSQLMetrics { + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + new DataWriter[InternalRow] { + + private[this] val metrics: Map[String, SQLMetric] = pythonMetrics + + private var commitMessage: PythonWriterCommitMessage = _ + + override def writeAll(records: java.util.Iterator[InternalRow]): Unit = { + val evaluatorFactory = source.createMapInBatchEvaluatorFactory( + pickledWriteFunc, + "write_to_data_source", + inputSchema, + UserDefinedPythonDataSource.writeOutputSchema, + metrics, + jobArtifactUUID) + val outputIter = evaluatorFactory.createEvaluator().eval(partitionId, records.asScala) + outputIter.foreach { row => + if (commitMessage == null) { + commitMessage = PythonWriterCommitMessage(row.getBinary(0)) + } else { + throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "more than one") + } + } + if (commitMessage == null) { + throw QueryExecutionErrors.invalidWriterCommitMessageError(details = "zero") + } + } + + override def write(record: InternalRow): Unit = + SparkException.internalError("write method for Python data source should not be called.") + + override def commit(): WriterCommitMessage = { + commitMessage.asInstanceOf[WriterCommitMessage] + } + + override def abort(): Unit = {} + + override def close(): Unit = {} + + override def currentMetricsValues(): Array[CustomTaskMetric] = { + source.createPythonTaskMetrics(metrics.map { case (k, v) => k -> v.value }) + } + } + } +} + +trait PythonDataSourceSQLMetrics { + // Dummy SQLMetrics. The result is manually reported via DSv2 interface + // via passing the value to `CustomTaskMetric`. Note that `pythonOtherMetricsDesc` + // is not used when it is reported. It is to reuse existing Python runner. + // See also `UserDefinedPythonDataSource.createPythonMetrics`. 
+ protected lazy val pythonMetrics: Map[String, SQLMetric] = { + PythonSQLMetrics.pythonSizeMetricsDesc.keys + .map(_ -> new SQLMetric("size", -1)).toMap ++ + PythonSQLMetrics.pythonOtherMetricsDesc.keys + .map(_ -> new SQLMetric("sum", -1)).toMap + } +} + class PythonCustomMetric( - override val name: String, - override val description: String) extends CustomMetric { + override val name: String, + override val description: String) extends CustomMetric { // To allow the aggregation can be called. See `SQLAppStatusListener.aggregateMetrics` def this() = this(null, null) @@ -182,8 +281,6 @@ class PythonCustomTaskMetric( */ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { - private val inputSchema: StructType = new StructType().add("partition", BinaryType) - /** * (Driver-side) Run Python process, and get the pickled Python Data Source * instance and its schema. @@ -207,26 +304,44 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { pythonResult: PythonDataSourceCreationResult, outputSchema: StructType): PythonDataSourceReadInfo = { new UserDefinedPythonDataSourceReadRunner( - createPythonFunction( - pythonResult.dataSource), inputSchema, outputSchema).runInPython() + createPythonFunction(pythonResult.dataSource), + UserDefinedPythonDataSource.readInputSchema, + outputSchema).runInPython() + } + + /** + * (Driver-side) Run Python process and get pickled write function. + */ + def createWriteInfoInPython( + provider: String, + inputSchema: StructType, + options: CaseInsensitiveStringMap, + mode: SaveMode): PythonDataSourceWriteInfo = { + new UserDefinedPythonDataSourceWriteRunner( + dataSourceCls, + provider, + inputSchema, + options.asCaseSensitiveMap().asScala.toMap, + mode).runInPython() } /** - * (Executor-side) Create an iterator that reads the input partitions. + * (Executor-side) Create an iterator that execute the Python function. 
*/ - def createPartitionReadIteratorInPython( - partition: PythonInputPartition, - pickledReadFunc: Array[Byte], + def createMapInBatchEvaluatorFactory( + pickledFunc: Array[Byte], + funcName: String, + inputSchema: StructType, outputSchema: StructType, metrics: Map[String, SQLMetric], - jobArtifactUUID: Option[String]): Iterator[InternalRow] = { - val readerFunc = createPythonFunction(pickledReadFunc) + jobArtifactUUID: Option[String]): MapInBatchEvaluatorFactory = { + val pythonFunc = createPythonFunction(pickledFunc) val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF val pythonUDF = PythonUDF( - name = "read_from_data_source", - func = readerFunc, + name = funcName, + func = pythonFunc, dataType = outputSchema, children = toAttributes(inputSchema), evalType = pythonEvalType, @@ -235,7 +350,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { val conf = SQLConf.get val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf) - val evaluatorFactory = new MapInBatchEvaluatorFactory( + new MapInBatchEvaluatorFactory( toAttributes(outputSchema), Seq(ChainedPythonFunctions(Seq(pythonUDF.func))), inputSchema, @@ -246,10 +361,6 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { pythonRunnerConf, metrics, jobArtifactUUID) - - val part = partition - evaluatorFactory.createEvaluator().eval( - part.index, Iterator.single(InternalRow(part.pickedPartition))) } def createPythonMetrics(): Array[CustomMetric] = { @@ -275,6 +386,18 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) { } } +object UserDefinedPythonDataSource { + /** + * The schema of the input to the Python data source read function. + */ + val readInputSchema: StructType = new StructType().add("partition", BinaryType) + + /** + * The schema of the output to the Python data source write function. + */ + val writeOutputSchema: StructType = new StructType().add("message", BinaryType) +} + /** * Used to store the result of creating a Python data source in the Python process. */ @@ -402,3 +525,63 @@ class UserDefinedPythonDataSourceReadRunner( partitions = pickledPartitions.toSeq) } } + +/** + * Hold the results of running [[UserDefinedPythonDataSourceWriteRunner]]. + */ +case class PythonDataSourceWriteInfo(func: Array[Byte]) + +/** + * A runner that creates a Python data source writer instance and returns a Python function + * to be used to write data into the data source. + */ +class UserDefinedPythonDataSourceWriteRunner( + dataSourceCls: PythonFunction, + provider: String, + inputSchema: StructType, + options: Map[String, String], + mode: SaveMode) extends PythonPlannerRunner[PythonDataSourceWriteInfo](dataSourceCls) { + + override val workerModule: String = "pyspark.sql.worker.write_into_data_source" + + override protected def writeToPython(dataOut: DataOutputStream, pickler: Pickler): Unit = { + // Send the Python data source class. 
+ PythonWorkerUtils.writePythonFunction(dataSourceCls, dataOut) + + // Send the provider name + PythonWorkerUtils.writeUTF(provider, dataOut) + + // Send the input schema + PythonWorkerUtils.writeUTF(inputSchema.json, dataOut) + + // Send the return type + PythonWorkerUtils.writeUTF(UserDefinedPythonDataSource.writeOutputSchema.json, dataOut) + + // Send the options + dataOut.writeInt(options.size) + options.iterator.foreach { case (key, value) => + PythonWorkerUtils.writeUTF(key, dataOut) + PythonWorkerUtils.writeUTF(value, dataOut) + } + + // Send the mode + PythonWorkerUtils.writeUTF(mode.toString, dataOut) + } + + override protected def receiveFromPython( + dataIn: DataInputStream): PythonDataSourceWriteInfo = { + + // Receive the picked UDF or an exception raised in Python worker. + val length = dataIn.readInt() + if (length == SpecialLengths.PYTHON_EXCEPTION_THROWN) { + val msg = PythonWorkerUtils.readUTF(dataIn) + throw QueryCompilationErrors.failToPlanDataSourceError( + action = "plan", tpe = "write", msg = msg) + } + + // Receive the pickled data source. + val writeUdf: Array[Byte] = PythonWorkerUtils.readBytes(length, dataIn) + + PythonDataSourceWriteInfo(func = writeUdf) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala index e8a46449ac2..b04569ae554 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.python +import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row} import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2ScanRelation} import org.apache.spark.sql.test.SharedSparkSession @@ -25,6 +26,8 @@ import org.apache.spark.sql.types.StructType class PythonDataSourceSuite extends QueryTest with SharedSparkSession { import IntegratedUDFTestUtils._ + setupTestData() + private def dataSourceName = "SimpleDataSource" private def simpleDataSourceReaderScript: String = """ @@ -453,4 +456,96 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession { assert(metrics.contains(pythonDataReceived.id)) assert(metrics(pythonDataReceived.id).asInstanceOf[String].endsWith("B")) } + + test("simple data source write") { + assume(shouldTestPandasUDFs) + val dataSourceScript = + s""" + |import json + |from pyspark import TaskContext + |from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage + | + |class SimpleDataSourceWriter(DataSourceWriter): + | def __init__(self, options): + | self.options = options + | + | def write(self, iterator): + | context = TaskContext.get() + | partition_id = context.partitionId() + | path = self.options.get("path") + | assert path is not None + | output_path = f"{path}/{partition_id}.json" + | cnt = 0 + | with open(output_path, "w") as file: + | for row in iterator: + | file.write(json.dumps(row.asDict()) + "\\n") + | cnt += 1 + | return WriterCommitMessage() + | + |class SimpleDataSource(DataSource): + | def writer(self, schema, saveMode): + | return SimpleDataSourceWriter(self.options) + |""".stripMargin + val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript) + spark.dataSource.registerPython(dataSourceName, dataSource) + Seq( + "SELECT * FROM 
range(0, 5, 1, 3)", + "SELECT * FROM testData LIMIT 5", + "SELECT * FROM testData3", + "SELECT * FROM arrayData" + ).foreach { query => + withTempDir { dir => + val df = sql(query) + val path = dir.getAbsolutePath + df.write.format(dataSourceName).mode("append").save(path) + val df2 = spark.read.json(path) + checkAnswer(df, df2) + } + } + } + + test("data source write - error cases") { + assume(shouldTestPandasUDFs) + val dataSourceScript = + s""" + |from pyspark.sql.datasource import DataSource, DataSourceWriter + | + |class SimpleDataSourceWriter(DataSourceWriter): + | def write(self, iterator): + | num_rows = 0 + | for row in iterator: + | num_rows += 1 + | if num_rows > 2: + | raise Exception("something is wrong") + | + |class SimpleDataSource(DataSource): + | def writer(self, schema, saveMode): + | return SimpleDataSourceWriter() + |""".stripMargin + spark.dataSource.registerPython(dataSourceName, + createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)) + + withClue("user error") { + val error = intercept[SparkException] { + spark.range(10).write.format(dataSourceName).mode("append").save() + } + assert(error.getMessage.contains("something is wrong")) + } + + withClue("no commit message") { + val error = intercept[SparkException] { + spark.range(1).write.format(dataSourceName).mode("append").save() + } + assert(error.getMessage.contains("PYTHON_DATA_SOURCE_WRITE_ERROR")) + } + + withClue("without mode") { + val error = intercept[AnalysisException] { + spark.range(1).write.format(dataSourceName).save() + } + // TODO: improve this error message. + assert(error.getMessage.contains("TableProvider implementation SimpleDataSource " + + "cannot be written with ErrorIfExists mode, please use Append or Overwrite modes instead.")) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org