This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 36ed5ee4a0d7 [SPARK-53429][PYTHON] Support Direct Passthrough
Partitioning in the PySpark Dataframe API
36ed5ee4a0d7 is described below
commit 36ed5ee4a0d7c0ea33140d1a4e9479e396efe139
Author: Shujing Yang <[email protected]>
AuthorDate: Mon Sep 22 14:13:37 2025 +0800
[SPARK-53429][PYTHON] Support Direct Passthrough Partitioning in the
PySpark Dataframe API
### What changes were proposed in this pull request?
This PR implements the repartitionById method for PySpark DataFrames
### Why are the changes needed?
Support Direct Passthrough Partitioning in the PySpark DataFrame API.
### Does this PR introduce _any_ user-facing change?
Yes
### How was this patch tested?
New unit tests.
### Was this patch authored or co-authored using generative AI tooling?
Closes #52295 from shujingyang-db/direct-passthrough-pyspark-api.
Lead-authored-by: Shujing Yang <[email protected]>
Co-authored-by: Shujing Yang
<[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/classic/dataframe.py | 24 ++++++
python/pyspark/sql/connect/dataframe.py | 33 ++++++++
python/pyspark/sql/connect/expressions.py | 21 +++++
python/pyspark/sql/dataframe.py | 61 ++++++++++++++
python/pyspark/sql/tests/test_repartition.py | 118 ++++++++++++++++++++++++++-
5 files changed, 255 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/sql/classic/dataframe.py
b/python/pyspark/sql/classic/dataframe.py
index 05ec586dc8f7..7774164f7d43 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -569,6 +569,30 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin,
PandasConversionMixin):
},
)
+ def repartitionById(
+ self, numPartitions: int, partitionIdCol: "ColumnOrName"
+ ) -> ParentDataFrame:
+ if not isinstance(numPartitions, (int, bool)):
+ raise PySparkTypeError(
+ errorClass="NOT_INT",
+ messageParameters={
+ "arg_name": "numPartitions",
+ "arg_type": type(numPartitions).__name__,
+ },
+ )
+ if numPartitions <= 0:
+ raise PySparkValueError(
+ errorClass="VALUE_NOT_POSITIVE",
+ messageParameters={
+ "arg_name": "numPartitions",
+ "arg_value": str(numPartitions),
+ },
+ )
+ return DataFrame(
+ self._jdf.repartitionById(numPartitions,
_to_java_column(partitionIdCol)),
+ self.sparkSession,
+ )
+
def distinct(self) -> ParentDataFrame:
return DataFrame(self._jdf.distinct(), self.sparkSession)
diff --git a/python/pyspark/sql/connect/dataframe.py
b/python/pyspark/sql/connect/dataframe.py
index aeafd8552dd0..ab7fdc90ba3c 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -82,6 +82,7 @@ from pyspark.sql.connect.streaming.readwriter import
DataStreamWriter
from pyspark.sql.column import Column
from pyspark.sql.connect.expressions import (
ColumnReference,
+ DirectShufflePartitionID,
SubqueryExpression,
UnresolvedRegex,
UnresolvedStar,
@@ -443,6 +444,38 @@ class DataFrame(ParentDataFrame):
res._cached_schema = self._cached_schema
return res
+ def repartitionById(
+ self, numPartitions: int, partitionIdCol: "ColumnOrName"
+ ) -> ParentDataFrame:
+ from pyspark.sql.connect.column import Column as ConnectColumn
+
+ if not isinstance(numPartitions, int) or isinstance(numPartitions,
bool):
+ raise PySparkTypeError(
+ errorClass="NOT_INT",
+ messageParameters={
+ "arg_name": "numPartitions",
+ "arg_type": type(numPartitions).__name__,
+ },
+ )
+ if numPartitions <= 0:
+ raise PySparkValueError(
+ errorClass="VALUE_NOT_POSITIVE",
+ messageParameters={
+ "arg_name": "numPartitions",
+ "arg_value": str(numPartitions),
+ },
+ )
+
+ partition_connect_col = cast(ConnectColumn, F._to_col(partitionIdCol))
+ direct_partition_expr =
DirectShufflePartitionID(partition_connect_col._expr)
+ direct_partition_col = ConnectColumn(direct_partition_expr)
+ res = DataFrame(
+ plan.RepartitionByExpression(self._plan, numPartitions,
[direct_partition_col]),
+ self._session,
+ )
+ res._cached_schema = self._cached_schema
+ return res
+
def dropDuplicates(self, subset: Optional[List[str]] = None) ->
ParentDataFrame:
if subset is not None and not isinstance(subset, (list, tuple)):
raise PySparkTypeError(
diff --git a/python/pyspark/sql/connect/expressions.py
b/python/pyspark/sql/connect/expressions.py
index 624599aac9e3..b397aa2121cc 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -1343,3 +1343,24 @@ class SubqueryExpression(Expression):
repr_parts.append(f"values={self._in_subquery_values}")
return f"SubqueryExpression({', '.join(repr_parts)})"
+
+
class DirectShufflePartitionID(Expression):
    """
    Passes a partition ID value straight through for shuffle partitioning.

    Used with RepartitionByExpression so callers can name the exact target
    partition for each row instead of relying on hash partitioning.
    """

    def __init__(self, child: Expression):
        super().__init__()
        # isinstance(None, Expression) is False, so this also rejects None.
        assert isinstance(child, Expression)
        self._child = child

    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
        plan_expr = self._create_proto_expression()
        child_plan = self._child.to_plan(session)
        plan_expr.direct_shuffle_partition_id.child.CopyFrom(child_plan)
        return plan_expr

    def __repr__(self) -> str:
        return f"DirectShufflePartitionID(child={self._child})"
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3d8dc970ba43..ca33539df960 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1887,6 +1887,67 @@ class DataFrame:
"""
...
    @dispatch_df_method
    def repartitionById(self, numPartitions: int, partitionIdCol: "ColumnOrName") -> "DataFrame":
        """
        Returns a new :class:`DataFrame` partitioned by the given partition ID expression.
        Each row's target partition is determined directly by the value of the
        partition ID column.

        .. versionadded:: 4.1.0

        .. versionchanged:: 4.1.0
            Supports Spark Connect.

        Parameters
        ----------
        numPartitions : int
            target number of partitions
        partitionIdCol : str or :class:`Column`
            column expression that evaluates to the target partition ID for each row.
            Must be an integer type. Values are taken modulo numPartitions to determine
            the final partition. Null values are sent to partition 0.

        Returns
        -------
        :class:`DataFrame`
            Repartitioned DataFrame.

        See Also
        --------
        DataFrame.repartition : hash-based repartitioning.
        DataFrame.repartitionByRange : range-based repartitioning.

        Notes
        -----
        The partition ID expression must evaluate to an integer type.
        Partition IDs are taken modulo numPartitions, so values outside the range
        [0, numPartitions) are automatically mapped to valid partition IDs. If the
        partition ID expression evaluates to a NULL value, the row is sent to
        partition 0.

        This method provides direct control over partition placement, similar to
        RDD's partitionBy with custom partitioners, but at the DataFrame level.

        Examples
        --------
        Partition rows based on a computed partition ID:

        >>> from pyspark.sql import functions as sf
        >>> from pyspark.sql.functions import col
        >>> df = spark.range(10).withColumn("partition_id", (col("id") % 3).cast("int"))
        >>> repartitioned = df.repartitionById(3, "partition_id")
        >>> repartitioned.select("id", "partition_id", sf.spark_partition_id()).orderBy("id").show()
        +---+------------+--------------------+
        | id|partition_id|SPARK_PARTITION_ID()|
        +---+------------+--------------------+
        |  0|           0|                   0|
        |  1|           1|                   1|
        |  2|           2|                   2|
        |  3|           0|                   0|
        |  4|           1|                   1|
        |  5|           2|                   2|
        |  6|           0|                   0|
        |  7|           1|                   1|
        |  8|           2|                   2|
        |  9|           0|                   0|
        +---+------------+--------------------+
        """
        ...
+
@dispatch_df_method
def distinct(self) -> "DataFrame":
"""Returns a new :class:`DataFrame` containing the distinct rows in
this :class:`DataFrame`.
diff --git a/python/pyspark/sql/tests/test_repartition.py
b/python/pyspark/sql/tests/test_repartition.py
index 058861e9c161..862f14cb50b8 100644
--- a/python/pyspark/sql/tests/test_repartition.py
+++ b/python/pyspark/sql/tests/test_repartition.py
@@ -17,7 +17,7 @@
import unittest
-from pyspark.sql.functions import spark_partition_id
+from pyspark.sql.functions import spark_partition_id, col, lit, when
from pyspark.sql.types import (
StringType,
IntegerType,
@@ -25,7 +25,7 @@ from pyspark.sql.types import (
StructType,
StructField,
)
-from pyspark.errors import PySparkTypeError
+from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -84,6 +84,120 @@ class DataFrameRepartitionTestsMixin:
messageParameters={"arg_name": "numPartitions", "arg_type":
"list"},
)
+ def test_repartition_by_id(self):
+ # Test basic partition ID passthrough behavior
+ numPartitions = 10
+ df = self.spark.range(100).withColumn("expected_p_id", col("id") %
numPartitions)
+ repartitioned = df.repartitionById(numPartitions,
col("expected_p_id").cast("int"))
+ result = repartitioned.withColumn("actual_p_id", spark_partition_id())
+
+ # All rows should be in their expected partitions
+ self.assertEqual(result.filter(col("expected_p_id") !=
col("actual_p_id")).count(), 0)
+
+ def test_repartition_by_id_negative_values(self):
+ df = self.spark.range(10).toDF("id")
+ repartitioned = df.repartitionById(10, (col("id") - 5).cast("int"))
+ result = repartitioned.withColumn("actual_p_id",
spark_partition_id()).collect()
+
+ for row in result:
+ actualPartitionId = row["actual_p_id"]
+ id_val = row["id"]
+ expectedPartitionId = int((id_val - 5) % 10)
+ self.assertEqual(
+ actualPartitionId,
+ expectedPartitionId,
+ f"Row with id={id_val} should be in partition
{expectedPartitionId}, "
+ f"but was in partition {actualPartitionId}",
+ )
+
+ def test_repartition_by_id_null_values(self):
+ # Test that null partition ids go to partition 0
+ df = self.spark.range(10).toDF("id")
+ partitionExpr = when(col("id") < 5,
col("id")).otherwise(lit(None)).cast("int")
+ repartitioned = df.repartitionById(10, partitionExpr)
+ result = repartitioned.withColumn("actual_p_id",
spark_partition_id()).collect()
+
+ nullRows = [row for row in result if row["id"] >= 5]
+ self.assertTrue(len(nullRows) > 0, "Should have rows with null
partition expression")
+ for row in nullRows:
+ self.assertEqual(
+ row["actual_p_id"],
+ 0,
+ f"Row with null partition id should go to partition 0, "
+ f"but went to partition {row['actual_p_id']}",
+ )
+
+ nonNullRows = [row for row in result if row["id"] < 5]
+ for row in nonNullRows:
+ id_val = row["id"]
+ actualPartitionId = row["actual_p_id"]
+ expectedPartitionId = id_val % 10
+ self.assertEqual(
+ actualPartitionId,
+ expectedPartitionId,
+ f"Row with id={id_val} should be in partition
{expectedPartitionId}, "
+ f"but was in partition {actualPartitionId}",
+ )
+
+ def test_repartition_by_id_error_non_int_type(self):
+ # Test error for non-integer partition column type
+ df = self.spark.range(5).withColumn("s", lit("a"))
+ with self.assertRaises(Exception): # Should raise analysis exception
+ df.repartitionById(5, col("s")).collect()
+
+ def test_repartition_by_id_error_invalid_num_partitions(self):
+ df = self.spark.range(5)
+
+ with self.assertRaises(PySparkTypeError) as pe:
+ df.repartitionById("5", col("id").cast("int"))
+ self.check_error(
+ exception=pe.exception,
+ errorClass="NOT_INT",
+ messageParameters={"arg_name": "numPartitions", "arg_type": "str"},
+ )
+
+ with self.assertRaises(PySparkValueError) as pe:
+ df.repartitionById(0, col("id").cast("int"))
+ self.check_error(
+ exception=pe.exception,
+ errorClass="VALUE_NOT_POSITIVE",
+ messageParameters={"arg_name": "numPartitions", "arg_value": "0"},
+ )
+
+ # Test negative numPartitions
+ with self.assertRaises(PySparkValueError) as pe:
+ df.repartitionById(-1, col("id").cast("int"))
+ self.check_error(
+ exception=pe.exception,
+ errorClass="VALUE_NOT_POSITIVE",
+ messageParameters={"arg_name": "numPartitions", "arg_value": "-1"},
+ )
+
+ def test_repartition_by_id_out_of_range(self):
+ numPartitions = 10
+ df = self.spark.range(20).toDF("id")
+ repartitioned = df.repartitionById(numPartitions,
col("id").cast("int"))
+ result = repartitioned.collect()
+
+ self.assertEqual(len(result), 20)
+ # Skip RDD partition count check for Connect mode since RDD is not
available
+ try:
+ self.assertEqual(repartitioned.rdd.getNumPartitions(),
numPartitions)
+ except Exception:
+ # Connect mode doesn't support RDD operations, so we skip this
check
+ pass
+
+ def test_repartition_by_id_string_column_name(self):
+ numPartitions = 5
+ df = self.spark.range(25).withColumn(
+ "partition_id", (col("id") % numPartitions).cast("int")
+ )
+ repartitioned = df.repartitionById(numPartitions, "partition_id")
+ result = repartitioned.withColumn("actual_p_id", spark_partition_id())
+
+ mismatches = result.filter(col("partition_id") !=
col("actual_p_id")).count()
+ self.assertEqual(mismatches, 0)
+
class DataFrameRepartitionTests(
DataFrameRepartitionTestsMixin,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]