Bastian Beranek created SPARK-56635:
---------------------------------------
Summary: Python DataSource API: support declaring output
partitioning to avoid redundant write-side sort
Key: SPARK-56635
URL: https://issues.apache.org/jira/browse/SPARK-56635
Project: Spark
Issue Type: Improvement
Components: PySpark, SQL
Affects Versions: 4.1.1
Environment: Databricks Runtime 17.3 LTS
Reporter: Bastian Beranek
*Disclaimer*
The following was drafted with help from the Databricks Genie assistant. I
can't vouch for the correctness of the references to Spark code, but it stems
from debugging a real-world problem that I think could be solved as described
below.
*Problem*
The Python DataSource API (pyspark.sql.datasource.DataSource /
DataSourceReader) has no way to declare the output partitioning of the data it
produces. As a result, Spark always assumes UnknownPartitioning, which can add
up to two unnecessary operations when writing to a Hive-style partitioned table:
# A ShuffleExchangeExec (if EnsureRequirements determines the distribution is
unsatisfied)
# A SortExec on the partition columns (injected by the V1Writes rule)
The Java/Scala V2 DataSource API supports SupportsReportPartitioning, which
allows sources to declare their output partitioning so the planner can elide
the shuffle and sort. The Python DataSource API has no equivalent.
This is a concrete performance problem. For DataSources that already partition
data internally (e.g., each InputPartition holds data for exactly one partition
of the target table), the sort materializes all records through
UnsafeExternalSorter and spills to disk, even though every record in the
partition carries the same key. On memory-constrained clusters this can cause
executor OOM crashes.
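To illustrate why the sort is redundant in this case, here is a plain-Python sketch (no Spark involved; the row values and the helper name is_sorted_by_key are made up for illustration). It checks that a partition whose rows all share one (year, month) key is already in sorted order with respect to that key, so sorting it again does no useful work.

```python
# Illustrative only: plain Python, not Spark. Rows and helper names are hypothetical.

def is_sorted_by_key(rows, key_columns):
    """Return True if rows are already in non-decreasing order of the key columns."""
    keys = [tuple(row[c] for c in key_columns) for row in rows]
    return all(a <= b for a, b in zip(keys, keys[1:]))


# One partition, as produced by a source that maps each InputPartition
# to exactly one (year, month) group. Non-key values are in arbitrary order.
partition_rows = [
    {"year": 2026, "month": 1, "value": 30},
    {"year": 2026, "month": 1, "value": 10},
    {"year": 2026, "month": 1, "value": 20},
]

# Every row carries the same partition key ...
distinct_keys = {(r["year"], r["month"]) for r in partition_rows}

# ... so the rows are trivially ordered by that key already.
already_sorted = is_sorted_by_key(partition_rows, ["year", "month"])

print(len(distinct_keys), already_sorted)
```

The same check fails as soon as a partition mixes keys in the wrong order, which is the case the write-side sort exists for.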
*Reproduction*
{noformat}
%python
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition


class MyReader(DataSourceReader):
    def __init__(self, schema, options):
        self.schema = schema
        self.options = options

    def partitions(self):
        # Each InputPartition contains data for exactly one (year, month) group
        return [
            InputPartition({"year": 2026, "month": 1, "files": [...]}),
            InputPartition({"year": 2026, "month": 2, "files": [...]}),
            # ...
        ]

    def read(self, partition: InputPartition):
        # All rows yielded here belong to a single (year, month) group.
        # process() is a placeholder for the source-specific row generator.
        for row in process(partition.value):
            yield row


class MySource(DataSource):
    @classmethod
    def name(cls):
        # Short name used in spark.read.format(...)
        return "my_source"

    def reader(self, schema):
        return MyReader(schema, self.options)


spark.dataSource.register(MySource)

# This write adds an unnecessary Sort + potentially a Shuffle,
# because Spark cannot know the data is already partitioned.
spark.read.format("my_source").load() \
    .write.partitionBy("year", "month") \
    .format("delta").mode("append").save("/path/to/table")
{noformat}
*Proposal*
Add an optional method to DataSourceReader that lets Python DataSources declare
their output partitioning:
{noformat}
%python
class DataSourceReader:
    def partitions(self) -> list[InputPartition]:
        ...

    def outputPartitioning(self) -> PartitioningInfo | None:
        """Optional: declare how the output is partitioned.

        If provided, the planner can skip the shuffle/sort when
        the declared partitioning satisfies the write requirement.
        Returns None (default) for UnknownPartitioning.
        """
        return None
{noformat}
A concrete DataSource would override it:
{noformat}
%python
class MyReader(DataSourceReader):
    def outputPartitioning(self):
        return ClusteredPartitioning(columns=["year", "month"])
{noformat}
Under the hood, the Python DataSource runner would translate this into the
corresponding Partitioning object on the JVM side, the same way
SupportsReportPartitioning works for Java/Scala V2 sources. The
EnsureRequirements and V1Writes rules would then see the partitioning is
already satisfied and skip the exchange and sort.
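The decision described above can be modeled in a few lines of plain Python. This is a simplified sketch, not Spark's actual planner logic: ClusteredPartitioning is the class proposed in this issue (it does not exist in PySpark today), and write_needs_exchange is a hypothetical helper. The assumed rule is that a declared clustering satisfies the write requirement when its columns are a non-empty subset of the required partition columns.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class ClusteredPartitioning:
    """Hypothetical Python-side declaration, as proposed in this issue."""
    columns: Sequence[str]


def write_needs_exchange(declared: Optional[ClusteredPartitioning],
                         required_columns: Sequence[str]) -> bool:
    """Simplified model: True if the data would still have to be repartitioned."""
    if declared is None or not declared.columns:
        # None models UnknownPartitioning, which is the current default.
        return True
    # Rows that agree on all required columns also agree on any subset of them,
    # so clustering on a subset already co-locates each required group.
    return not set(declared.columns) <= set(required_columns)


required = ["year", "month"]

# Today: nothing declared, so the planner must assume the worst.
print(write_needs_exchange(None, required))

# With the proposal: the source declares clustering on the partition columns.
print(write_needs_exchange(ClusteredPartitioning(columns=["year", "month"]), required))
```

A declaration on columns outside the write's partition columns (for example a hypothetical "day" column) would not satisfy the requirement in this model, so the exchange would stay in the plan.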
*Scope*
* Extend pyspark.sql.datasource.DataSourceReader with an optional
outputPartitioning() method
* Add a Python-side ClusteredPartitioning (and optionally OrderedPartitioning)
class
* Wire the declared partitioning through to the JVM-side physical plan so
existing rules (EnsureRequirements, V1Writes) respect it
* Default behavior is unchanged (returns None → UnknownPartitioning)
*Related*
* SupportsReportPartitioning in the Java/Scala V2 DataSource API (already
supports this for JVM-based sources)
* Delta Lake write-side sort in TransactionalWrite
(https://github.com/delta-io/delta/issues/6676)
* SPARK-44076 (Python DataSource API initial implementation)