This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 36ed5ee4a0d7 [SPARK-53429][PYTHON] Support Direct Passthrough Partitioning in the PySpark Dataframe API
36ed5ee4a0d7 is described below

commit 36ed5ee4a0d7c0ea33140d1a4e9479e396efe139
Author: Shujing Yang <[email protected]>
AuthorDate: Mon Sep 22 14:13:37 2025 +0800

    [SPARK-53429][PYTHON] Support Direct Passthrough Partitioning in the PySpark Dataframe API
    
    ### What changes were proposed in this pull request?
    
    This PR implements the repartitionById method for PySpark DataFrames, covering both the classic and Spark Connect implementations.
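
    A minimal usage sketch (assuming a SparkSession named `spark`; the toy DataFrame and the "pid" column name are illustrative only, modeled on the docstring example in the diff below):

        from pyspark.sql import functions as sf

        # "pid" holds the target partition ID for each row (illustrative column name).
        df = spark.range(10).withColumn("pid", (sf.col("id") % 3).cast("int"))
        # Each row is routed to the partition whose ID equals its "pid" value
        # (values are taken modulo numPartitions; NULLs go to partition 0).
        df.repartitionById(3, "pid").select("id", "pid", sf.spark_partition_id()).show()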
    
    ### Why are the changes needed?
    
    To support Direct Passthrough Partitioning in the PySpark DataFrame API.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this adds a new DataFrame.repartitionById method.
    
    ### How was this patch tested?
    
    New unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Closes #52295 from shujingyang-db/direct-passthrough-pyspark-api.
    
    Lead-authored-by: Shujing Yang <[email protected]>
    Co-authored-by: Shujing Yang <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/classic/dataframe.py      |  24 ++++++
 python/pyspark/sql/connect/dataframe.py      |  33 ++++++++
 python/pyspark/sql/connect/expressions.py    |  21 +++++
 python/pyspark/sql/dataframe.py              |  61 ++++++++++++++
 python/pyspark/sql/tests/test_repartition.py | 118 ++++++++++++++++++++++++++-
 5 files changed, 255 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/classic/dataframe.py b/python/pyspark/sql/classic/dataframe.py
index 05ec586dc8f7..7774164f7d43 100644
--- a/python/pyspark/sql/classic/dataframe.py
+++ b/python/pyspark/sql/classic/dataframe.py
@@ -569,6 +569,30 @@ class DataFrame(ParentDataFrame, PandasMapOpsMixin, PandasConversionMixin):
                 },
             )
 
+    def repartitionById(
+        self, numPartitions: int, partitionIdCol: "ColumnOrName"
+    ) -> ParentDataFrame:
+        if not isinstance(numPartitions, (int, bool)):
+            raise PySparkTypeError(
+                errorClass="NOT_INT",
+                messageParameters={
+                    "arg_name": "numPartitions",
+                    "arg_type": type(numPartitions).__name__,
+                },
+            )
+        if numPartitions <= 0:
+            raise PySparkValueError(
+                errorClass="VALUE_NOT_POSITIVE",
+                messageParameters={
+                    "arg_name": "numPartitions",
+                    "arg_value": str(numPartitions),
+                },
+            )
+        return DataFrame(
+            self._jdf.repartitionById(numPartitions, _to_java_column(partitionIdCol)),
+            self.sparkSession,
+        )
+
     def distinct(self) -> ParentDataFrame:
         return DataFrame(self._jdf.distinct(), self.sparkSession)
 
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index aeafd8552dd0..ab7fdc90ba3c 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -82,6 +82,7 @@ from pyspark.sql.connect.streaming.readwriter import DataStreamWriter
 from pyspark.sql.column import Column
 from pyspark.sql.connect.expressions import (
     ColumnReference,
+    DirectShufflePartitionID,
     SubqueryExpression,
     UnresolvedRegex,
     UnresolvedStar,
@@ -443,6 +444,38 @@ class DataFrame(ParentDataFrame):
         res._cached_schema = self._cached_schema
         return res
 
+    def repartitionById(
+        self, numPartitions: int, partitionIdCol: "ColumnOrName"
+    ) -> ParentDataFrame:
+        from pyspark.sql.connect.column import Column as ConnectColumn
+
+        if not isinstance(numPartitions, int) or isinstance(numPartitions, bool):
+            raise PySparkTypeError(
+                errorClass="NOT_INT",
+                messageParameters={
+                    "arg_name": "numPartitions",
+                    "arg_type": type(numPartitions).__name__,
+                },
+            )
+        if numPartitions <= 0:
+            raise PySparkValueError(
+                errorClass="VALUE_NOT_POSITIVE",
+                messageParameters={
+                    "arg_name": "numPartitions",
+                    "arg_value": str(numPartitions),
+                },
+            )
+
+        partition_connect_col = cast(ConnectColumn, F._to_col(partitionIdCol))
+        direct_partition_expr = DirectShufflePartitionID(partition_connect_col._expr)
+        direct_partition_col = ConnectColumn(direct_partition_expr)
+        res = DataFrame(
+            plan.RepartitionByExpression(self._plan, numPartitions, [direct_partition_col]),
+            self._session,
+        )
+        res._cached_schema = self._cached_schema
+        return res
+
     def dropDuplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
         if subset is not None and not isinstance(subset, (list, tuple)):
             raise PySparkTypeError(
diff --git a/python/pyspark/sql/connect/expressions.py b/python/pyspark/sql/connect/expressions.py
index 624599aac9e3..b397aa2121cc 100644
--- a/python/pyspark/sql/connect/expressions.py
+++ b/python/pyspark/sql/connect/expressions.py
@@ -1343,3 +1343,24 @@ class SubqueryExpression(Expression):
             repr_parts.append(f"values={self._in_subquery_values}")
 
         return f"SubqueryExpression({', '.join(repr_parts)})"
+
+
+class DirectShufflePartitionID(Expression):
+    """
+    Expression that takes a partition ID value and passes it through directly for use in
+    shuffle partitioning. This is used with RepartitionByExpression to allow users to
+    directly specify target partition IDs.
+    """
+
+    def __init__(self, child: Expression):
+        super().__init__()
+        assert child is not None and isinstance(child, Expression)
+        self._child = child
+
+    def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
+        expr = self._create_proto_expression()
+        expr.direct_shuffle_partition_id.child.CopyFrom(self._child.to_plan(session))
+        return expr
+
+    def __repr__(self) -> str:
+        return f"DirectShufflePartitionID(child={self._child})"
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 3d8dc970ba43..ca33539df960 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1887,6 +1887,67 @@ class DataFrame:
         """
         ...
 
+    @dispatch_df_method
+    def repartitionById(self, numPartitions: int, partitionIdCol: "ColumnOrName") -> "DataFrame":
+        """
+        Returns a new :class:`DataFrame` partitioned by the given partition ID expression.
+        Each row's target partition is determined directly by the value of the partition ID column.
+
+        .. versionadded:: 4.1.0
+
+        .. versionchanged:: 4.1.0
+            Supports Spark Connect.
+
+        Parameters
+        ----------
+        numPartitions : int
+            target number of partitions
+        partitionIdCol : str or :class:`Column`
+            column expression that evaluates to the target partition ID for each row.
+            Must be an integer type. Values are taken modulo numPartitions to determine
+            the final partition. Null values are sent to partition 0.
+
+        Returns
+        -------
+        :class:`DataFrame`
+            Repartitioned DataFrame.
+
+        Notes
+        -----
+        The partition ID expression must evaluate to an integer type.
+        Partition IDs are taken modulo numPartitions, so values outside the range [0, numPartitions)
+        are automatically mapped to valid partition IDs. If the partition ID expression evaluates to
+        a NULL value, the row is sent to partition 0.
+
+        This method provides direct control over partition placement, similar to RDD's
+        partitionBy with custom partitioners, but at the DataFrame level.
+
+        Examples
+        --------
+        Partition rows based on a computed partition ID:
+
+        >>> from pyspark.sql import functions as sf
+        >>> from pyspark.sql.functions import col
+        >>> df = spark.range(10).withColumn("partition_id", (col("id") % 3).cast("int"))
+        >>> repartitioned = df.repartitionById(3, "partition_id")
+        >>> repartitioned.select("id", "partition_id", sf.spark_partition_id()).orderBy("id").show()
+        +---+------------+--------------------+
+        | id|partition_id|SPARK_PARTITION_ID()|
+        +---+------------+--------------------+
+        |  0|           0|                   0|
+        |  1|           1|                   1|
+        |  2|           2|                   2|
+        |  3|           0|                   0|
+        |  4|           1|                   1|
+        |  5|           2|                   2|
+        |  6|           0|                   0|
+        |  7|           1|                   1|
+        |  8|           2|                   2|
+        |  9|           0|                   0|
+        +---+------------+--------------------+
+        """
+        ...
+
     @dispatch_df_method
     def distinct(self) -> "DataFrame":
         """Returns a new :class:`DataFrame` containing the distinct rows in 
this :class:`DataFrame`.
diff --git a/python/pyspark/sql/tests/test_repartition.py b/python/pyspark/sql/tests/test_repartition.py
index 058861e9c161..862f14cb50b8 100644
--- a/python/pyspark/sql/tests/test_repartition.py
+++ b/python/pyspark/sql/tests/test_repartition.py
@@ -17,7 +17,7 @@
 
 import unittest
 
-from pyspark.sql.functions import spark_partition_id
+from pyspark.sql.functions import spark_partition_id, col, lit, when
 from pyspark.sql.types import (
     StringType,
     IntegerType,
@@ -25,7 +25,7 @@ from pyspark.sql.types import (
     StructType,
     StructField,
 )
-from pyspark.errors import PySparkTypeError
+from pyspark.errors import PySparkTypeError, PySparkValueError
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
@@ -84,6 +84,120 @@ class DataFrameRepartitionTestsMixin:
             messageParameters={"arg_name": "numPartitions", "arg_type": 
"list"},
         )
 
+    def test_repartition_by_id(self):
+        # Test basic partition ID passthrough behavior
+        numPartitions = 10
+        df = self.spark.range(100).withColumn("expected_p_id", col("id") % numPartitions)
+        repartitioned = df.repartitionById(numPartitions, col("expected_p_id").cast("int"))
+        result = repartitioned.withColumn("actual_p_id", spark_partition_id())
+
+        # All rows should be in their expected partitions
+        self.assertEqual(result.filter(col("expected_p_id") != 
col("actual_p_id")).count(), 0)
+
+    def test_repartition_by_id_negative_values(self):
+        df = self.spark.range(10).toDF("id")
+        repartitioned = df.repartitionById(10, (col("id") - 5).cast("int"))
+        result = repartitioned.withColumn("actual_p_id", 
spark_partition_id()).collect()
+
+        for row in result:
+            actualPartitionId = row["actual_p_id"]
+            id_val = row["id"]
+            expectedPartitionId = int((id_val - 5) % 10)
+            self.assertEqual(
+                actualPartitionId,
+                expectedPartitionId,
+                f"Row with id={id_val} should be in partition 
{expectedPartitionId}, "
+                f"but was in partition {actualPartitionId}",
+            )
+
+    def test_repartition_by_id_null_values(self):
+        # Test that null partition ids go to partition 0
+        df = self.spark.range(10).toDF("id")
+        partitionExpr = when(col("id") < 5, 
col("id")).otherwise(lit(None)).cast("int")
+        repartitioned = df.repartitionById(10, partitionExpr)
+        result = repartitioned.withColumn("actual_p_id", 
spark_partition_id()).collect()
+
+        nullRows = [row for row in result if row["id"] >= 5]
+        self.assertTrue(len(nullRows) > 0, "Should have rows with null partition expression")
+        for row in nullRows:
+            self.assertEqual(
+                row["actual_p_id"],
+                0,
+                f"Row with null partition id should go to partition 0, "
+                f"but went to partition {row['actual_p_id']}",
+            )
+
+        nonNullRows = [row for row in result if row["id"] < 5]
+        for row in nonNullRows:
+            id_val = row["id"]
+            actualPartitionId = row["actual_p_id"]
+            expectedPartitionId = id_val % 10
+            self.assertEqual(
+                actualPartitionId,
+                expectedPartitionId,
+                f"Row with id={id_val} should be in partition 
{expectedPartitionId}, "
+                f"but was in partition {actualPartitionId}",
+            )
+
+    def test_repartition_by_id_error_non_int_type(self):
+        # Test error for non-integer partition column type
+        df = self.spark.range(5).withColumn("s", lit("a"))
+        with self.assertRaises(Exception):  # Should raise analysis exception
+            df.repartitionById(5, col("s")).collect()
+
+    def test_repartition_by_id_error_invalid_num_partitions(self):
+        df = self.spark.range(5)
+
+        with self.assertRaises(PySparkTypeError) as pe:
+            df.repartitionById("5", col("id").cast("int"))
+        self.check_error(
+            exception=pe.exception,
+            errorClass="NOT_INT",
+            messageParameters={"arg_name": "numPartitions", "arg_type": "str"},
+        )
+
+        with self.assertRaises(PySparkValueError) as pe:
+            df.repartitionById(0, col("id").cast("int"))
+        self.check_error(
+            exception=pe.exception,
+            errorClass="VALUE_NOT_POSITIVE",
+            messageParameters={"arg_name": "numPartitions", "arg_value": "0"},
+        )
+
+        # Test negative numPartitions
+        with self.assertRaises(PySparkValueError) as pe:
+            df.repartitionById(-1, col("id").cast("int"))
+        self.check_error(
+            exception=pe.exception,
+            errorClass="VALUE_NOT_POSITIVE",
+            messageParameters={"arg_name": "numPartitions", "arg_value": "-1"},
+        )
+
+    def test_repartition_by_id_out_of_range(self):
+        numPartitions = 10
+        df = self.spark.range(20).toDF("id")
+        repartitioned = df.repartitionById(numPartitions, col("id").cast("int"))
+        result = repartitioned.collect()
+
+        self.assertEqual(len(result), 20)
+        # Skip RDD partition count check for Connect mode since RDD is not available
+        try:
+            self.assertEqual(repartitioned.rdd.getNumPartitions(), numPartitions)
+        except Exception:
+            # Connect mode doesn't support RDD operations, so we skip this check
+            pass
+
+    def test_repartition_by_id_string_column_name(self):
+        numPartitions = 5
+        df = self.spark.range(25).withColumn(
+            "partition_id", (col("id") % numPartitions).cast("int")
+        )
+        repartitioned = df.repartitionById(numPartitions, "partition_id")
+        result = repartitioned.withColumn("actual_p_id", spark_partition_id())
+
+        mismatches = result.filter(col("partition_id") != 
col("actual_p_id")).count()
+        self.assertEqual(mismatches, 0)
+
 
 class DataFrameRepartitionTests(
     DataFrameRepartitionTestsMixin,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
