ueshin commented on code in PR #52317:
URL: https://github.com/apache/spark/pull/52317#discussion_r2345583649


##########
python/pyspark/worker.py:
##########
@@ -1514,10 +1514,288 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any:
             else:
                 return arg
 
+    class ArrowUDTFWithPartition:
+        """
+        Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument
+        with one or more PARTITION BY expressions.
+
+        Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row
+        objects.
+
+        Example table:
+            CREATE TABLE t (c1 INT, c2 INT) USING delta;
+
+        Example queries:
+            SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2);
+            partition_child_indexes: 0, 1.
+
+            SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4);
+            partition_child_indexes: 0, 2 (adds a projection for "c2 + 4").
+        """
+
+        def __init__(self, create_udtf: Callable, partition_child_indexes: list):
+            """
+            Create a new instance that wraps the provided Arrow UDTF with partitioning
+            logic.
+
+            Parameters
+            ----------
+            create_udtf: function
+                Function that creates a new instance of the Arrow UDTF to invoke.
+            partition_child_indexes: list
+                Zero-based indexes of input-table columns that contain projected
+                partitioning expressions.
+            """
+            self._create_udtf: Callable = create_udtf
+            self._udtf = create_udtf()
+            self._partition_child_indexes: list = partition_child_indexes
+            # Track last partition key from previous batch
+            self._last_partition_key: Optional[Tuple[Any, ...]] = None
+            self._eval_raised_skip_rest_of_input_table: bool = False
+
+        def eval(self, *args, **kwargs) -> Iterator:
+            """Handle partitioning logic for Arrow UDTFs that receive 
RecordBatch objects."""
+            import pyarrow as pa
+
+            # Get the original batch with partition columns
+            original_batch = self._get_table_arg(list(args) + list(kwargs.values()))
+            if not isinstance(original_batch, pa.RecordBatch):
+                # Arrow UDTFs with PARTITION BY must have a TABLE argument that
+                # results in a PyArrow RecordBatch
+                raise PySparkRuntimeError(
+                    errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT",
+                    messageParameters={
+                        "actual_type": str(type(original_batch))
+                        if original_batch is not None
+                        else "None"
+                    },
+                )
+
+            # Remove partition columns to get the filtered arguments
+            filtered_args = [self._remove_partition_by_exprs(arg) for arg in args]
+            filtered_kwargs = {
+                key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items()
+            }
+
+            # Get the filtered RecordBatch (without partition columns)
+            filtered_batch = self._get_table_arg(filtered_args + list(filtered_kwargs.values()))
+
+            # Process the RecordBatch by partitions
+            yield from self._process_arrow_batch_by_partitions(
+                original_batch, filtered_batch, filtered_args, filtered_kwargs
+            )
+
+        def _process_arrow_batch_by_partitions(
+            self, original_batch, filtered_batch, filtered_args, filtered_kwargs
+        ) -> Iterator:
+            """Process an Arrow RecordBatch by splitting it into partitions.
+
+            Since Catalyst guarantees that rows with the same partition key are contiguous,
+            we can use efficient boundary detection instead of group_by.
+
+            Handles two scenarios:
+            1. Multiple partitions within a single RecordBatch (using boundary detection)
+            2. Same partition key continuing from previous RecordBatch (tracking state)
+            """
+            import pyarrow as pa
+
+            if self._partition_child_indexes:
+                # Detect partition boundaries.
+                boundaries = self._detect_partition_boundaries(original_batch)
+
+                # Process each contiguous partition
+                for i in range(len(boundaries) - 1):
+                    start_idx = boundaries[i]
+                    end_idx = boundaries[i + 1]
+
+                    # Get the partition key for this segment
+                    partition_key = tuple(
+                        original_batch.column(idx)[start_idx].as_py()
+                        for idx in self._partition_child_indexes
+                    )
+
+                    # Check if this is a continuation of the previous batch's partition
+                    is_new_partition = (
+                        self._last_partition_key is not None
+                        and partition_key != self._last_partition_key
+                    )
+
+                    if is_new_partition:
+                        # Previous partition ended, call terminate
+                        if hasattr(self._udtf, "terminate"):
+                            terminate_result = self._udtf.terminate()
+                            if terminate_result is not None:
+                                for table in terminate_result:
+                                    yield table
+                        # Create new UDTF instance for new partition
+                        self._udtf = self._create_udtf()
+                        self._eval_raised_skip_rest_of_input_table = False
+
+                    # Slice the filtered batch for this partition
+                    partition_batch = filtered_batch.slice(start_idx, end_idx - start_idx)
+
+                    # Update the last partition key
+                    self._last_partition_key = partition_key
+
+                    # Update filtered args to use the partition batch
+                    partition_filtered_args = []
+                    for arg in filtered_args:
+                        if isinstance(arg, pa.RecordBatch):
+                            partition_filtered_args.append(partition_batch)
+                        else:
+                            partition_filtered_args.append(arg)
+
+                    partition_filtered_kwargs = {}
+                    for key, value in filtered_kwargs.items():
+                        if isinstance(value, pa.RecordBatch):
+                            partition_filtered_kwargs[key] = partition_batch
+                        else:
+                            partition_filtered_kwargs[key] = value
+
+                    # Call the UDTF with this partition's data
+                    if not self._eval_raised_skip_rest_of_input_table:
+                        try:
+                            result = self._udtf.eval(
+                                *partition_filtered_args, **partition_filtered_kwargs
+                            )
+                            if result is not None:
+                                for table in result:
+                                    yield table

Review Comment:
   nit
   ```suggestion
                                   yield from result
   ```



##########
python/pyspark/sql/tests/arrow/test_arrow_udtf.py:
##########
@@ -730,6 +730,496 @@ def eval(self, x: "pa.Array", y: "pa.Array") -> Iterator["pa.Table"]:
         expected_df2 = self.spark.createDataFrame([(7, 3, 10)], "x int, y int, sum int")
         assertDataFrameEqual(sql_result_df2, expected_df2)
 
+    def test_arrow_udtf_with_partition_by(self):
+        @arrow_udtf(returnType="partition_key int, sum_value int")
+        class SumUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                table = pa.table(table_data)
+                partition_key = pc.unique(table["partition_key"]).to_pylist()
+                assert (
+                    len(partition_key) == 1
+                ), f"Expected exactly one partition key, got {partition_key}"
+                sum_value = pc.sum(table["value"]).as_py()
+                result_table = pa.table(
+                    {
+                        "partition_key": pa.array([partition_key[0]], 
type=pa.int32()),
+                        "sum_value": pa.array([sum_value], type=pa.int32()),
+                    }
+                )
+                yield result_table
+
+        test_data = [
+            (1, 10),
+            (2, 5),
+            (1, 20),
+            (2, 15),
+            (1, 30),
+            (3, 100),
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("sum_udtf", SumUDTF)
+        input_df.createOrReplaceTempView("test_data")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM sum_udtf(TABLE(test_data) PARTITION BY partition_key)
+        """
+        )
+
+        expected_data = [
+            (1, 60),
+            (2, 20),
+            (3, 100),
+        ]
+        expected_df = self.spark.createDataFrame(expected_data, "partition_key int, sum_value int")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_partition_by_and_terminate(self):
+        @arrow_udtf(returnType="partition_key int, count int, sum_value int")
+        class TerminateUDTF:
+            def __init__(self):
+                self._partition_key = None
+                self._count = 0
+                self._sum = 0
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+                # Track partition key
+                partition_keys = pc.unique(table["partition_key"]).to_pylist()
+                assert len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}"
+                self._partition_key = partition_keys[0]
+
+                # Accumulate stats but don't yield here
+                self._count += table.num_rows
+                self._sum += pc.sum(table["value"]).as_py()
+                # Return empty iterator - results come from terminate
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                # Yield accumulated results for this partition
+                if self._partition_key is not None:
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([self._partition_key], 
type=pa.int32()),
+                            "count": pa.array([self._count], type=pa.int32()),
+                            "sum_value": pa.array([self._sum], 
type=pa.int32()),
+                        }
+                    )
+                    yield result_table
+
+        test_data = [
+            (3, 50),
+            (1, 10),
+            (2, 40),
+            (1, 20),
+            (2, 30),
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("terminate_udtf", TerminateUDTF)
+        input_df.createOrReplaceTempView("test_data_terminate")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM terminate_udtf(TABLE(test_data_terminate) PARTITION BY partition_key)
+            ORDER BY partition_key
+            """
+        )
+
+        expected_data = [
+            (1, 2, 30),  # partition 1: 2 rows, sum = 30
+            (2, 2, 70),  # partition 2: 2 rows, sum = 70
+            (3, 1, 50),  # partition 3: 1 row, sum = 50
+        ]
+        expected_df = self.spark.createDataFrame(
+            expected_data, "partition_key int, count int, sum_value int"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_partition_by_and_order_by(self):
+        @arrow_udtf(returnType="partition_key int, first_value int, last_value 
int")
+        class OrderByUDTF:
+            def __init__(self):
+                self._partition_key = None
+                self._first_value = None
+                self._last_value = None
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+                partition_keys = pc.unique(table["partition_key"]).to_pylist()
+                assert len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}"
+                self._partition_key = partition_keys[0]
+
+                # Track first and last values (should be ordered)
+                values = table["value"].to_pylist()
+                if values:
+                    if self._first_value is None:
+                        self._first_value = values[0]
+                    self._last_value = values[-1]
+
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                if self._partition_key is not None:
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([self._partition_key], 
type=pa.int32()),
+                            "first_value": pa.array([self._first_value], 
type=pa.int32()),
+                            "last_value": pa.array([self._last_value], 
type=pa.int32()),
+                        }
+                    )
+                    yield result_table
+
+        test_data = [
+            (1, 30),
+            (1, 10),
+            (1, 20),
+            (2, 60),
+            (2, 40),
+            (2, 50),
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("order_by_udtf", OrderByUDTF)
+        input_df.createOrReplaceTempView("test_data_order")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM order_by_udtf(
+                TABLE(test_data_order)
+                PARTITION BY partition_key
+                ORDER BY value
+            )
+            ORDER BY partition_key
+            """
+        )
+
+        expected_data = [
+            (1, 10, 30),  # partition 1: first=10 (min), last=30 (max) after ordering
+            (2, 40, 60),  # partition 2: first=40 (min), last=60 (max) after ordering
+        ]
+        expected_df = self.spark.createDataFrame(
+            expected_data, "partition_key int, first_value int, last_value int"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_partition_column_removal(self):
+        @arrow_udtf(returnType="col1_sum int, col2_sum int")
+        class PartitionColumnTestUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+
+                # When partitioning by an expression like "col1 + col2",
+                # Catalyst adds the expression result as a new column at the beginning.
+                # The ArrowUDTFWithPartition._remove_partition_by_exprs method should
+                # remove this added column, leaving only the original table columns.
+                column_names = table.column_names
+
+                # Verify we only have the original columns, not the partition expression
+                assert "col1" in column_names, f"Expected col1 in columns: {column_names}"
+                assert "col2" in column_names, f"Expected col2 in columns: {column_names}"
+                # The partition expression column should have been removed
+                assert len(column_names) == 2, (
+                    f"Expected only col1 and col2 after partition column 
removal, "
+                    f"but got: {column_names}"
+                )
+
+                col1_sum = pc.sum(table["col1"]).as_py()
+                col2_sum = pc.sum(table["col2"]).as_py()
+
+                result_table = pa.table(
+                    {
+                        "col1_sum": pa.array([col1_sum], type=pa.int32()),
+                        "col2_sum": pa.array([col2_sum], type=pa.int32()),
+                    }
+                )
+                yield result_table
+
+        test_data = [
+            (1, 1),  # partition: 1+1=2
+            (1, 2),  # partition: 1+2=3
+            (2, 0),  # partition: 2+0=2
+            (2, 1),  # partition: 2+1=3
+        ]
+        input_df = self.spark.createDataFrame(test_data, "col1 int, col2 int")
+
+        self.spark.udtf.register("partition_column_test_udtf", 
PartitionColumnTestUDTF)
+        input_df.createOrReplaceTempView("test_partition_removal")
+
+        # Partition by col1 + col2 expression
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM partition_column_test_udtf(
+                TABLE(test_partition_removal)
+                PARTITION BY col1 + col2
+            )
+            ORDER BY col1_sum, col2_sum
+            """
+        )
+
+        expected_data = [
+            (3, 1),  # partition 2: sum of col1s (1+2), sum of col2s (1+0)
+            (3, 3),  # partition 3: sum of col1s (1+2), sum of col2s (2+1)
+        ]
+        expected_df = self.spark.createDataFrame(expected_data, "col1_sum int, col2_sum int")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_partition_by_single_partition_multiple_input_partitions(self):
+        @arrow_udtf(returnType="partition_key int, count bigint, sum_value bigint")
+        class SinglePartitionUDTF:
+            def __init__(self):
+                self._partition_key = None
+                self._count = 0
+                self._sum = 0
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+
+                # All rows should have the same partition key (constant value 1)
+                partition_keys = pc.unique(table["partition_key"]).to_pylist()
+                self._partition_key = partition_keys[0]
+                self._count += table.num_rows
+                self._sum += pc.sum(table["id"]).as_py()
+
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                if self._partition_key is not None:
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([self._partition_key], 
type=pa.int32()),
+                            "count": pa.array([self._count], type=pa.int64()),
+                            "sum_value": pa.array([self._sum], 
type=pa.int64()),
+                        }
+                    )
+                    yield result_table
+
+        # Create DataFrame with 5 input partitions but all data will map to partition_key=1
+        # range(1, 10, 1, 5) creates ids from 1 to 9 with 5 partitions
+        input_df = self.spark.range(1, 10, 1, 5).selectExpr(
+            "1 as partition_key", "id"  # constant partition key
+        )
+
+        self.spark.udtf.register("single_partition_udtf", SinglePartitionUDTF)
+        input_df.createOrReplaceTempView("test_single_partition")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM single_partition_udtf(
+                TABLE(test_single_partition)
+                PARTITION BY partition_key
+            )
+            """
+        )
+
+        # All 9 rows (1 through 9) should be in a single partition with key=1
+        expected_data = [(1, 9, 45)]
+        expected_df = self.spark.createDataFrame(
+            expected_data, "partition_key int, count bigint, sum_value bigint"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_partition_by_skip_rest_of_input(self):
+        from pyspark.sql.functions import SkipRestOfInputTableException
+
+        @arrow_udtf(returnType="partition_key int, rows_processed int, 
last_value int")
+        class SkipRestUDTF:
+            def __init__(self):
+                self._partition_key = None
+                self._rows_processed = 0
+                self._last_value = None
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+                partition_keys = pc.unique(table["partition_key"]).to_pylist()
+                assert len(partition_keys) == 1, f"Expected one partition key, got {partition_keys}"
+                self._partition_key = partition_keys[0]
+
+                # Process rows one by one and stop after processing 2 rows per partition
+                values = table["value"].to_pylist()
+                for value in values:
+                    self._rows_processed += 1
+                    self._last_value = value
+
+                    # Skip rest of the partition after processing 2 rows
+                    if self._rows_processed >= 2:
+                        msg = f"Skipping partition {self._partition_key} "
+                        msg += f"after {self._rows_processed} rows"
+                        raise SkipRestOfInputTableException(msg)
+
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                if self._partition_key is not None:
+                    result_table = pa.table(
+                        {
+                            "partition_key": pa.array([self._partition_key], 
type=pa.int32()),
+                            "rows_processed": pa.array([self._rows_processed], 
type=pa.int32()),
+                            "last_value": pa.array([self._last_value], 
type=pa.int32()),
+                        }
+                    )
+                    yield result_table
+
+        # Create test data with multiple partitions, each having more than 2 rows
+        test_data = [
+            (1, 10),
+            (1, 20),
+            (1, 30),  # This should be skipped
+            (1, 40),  # This should be skipped
+            (2, 50),
+            (2, 60),
+            (2, 70),  # This should be skipped
+            (3, 80),
+            (3, 90),
+            (3, 100),  # This should be skipped
+            (3, 110),  # This should be skipped
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("skip_rest_udtf", SkipRestUDTF)
+        input_df.createOrReplaceTempView("test_skip_rest")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM skip_rest_udtf(
+                TABLE(test_skip_rest)
+                PARTITION BY partition_key
+                ORDER BY value
+            )
+            ORDER BY partition_key
+            """
+        )
+
+        # Each partition should only process 2 rows before skipping the rest
+        expected_data = [
+            (1, 2, 20),  # Processed rows 10, 20, then skipped 30, 40
+            (2, 2, 60),  # Processed rows 50, 60, then skipped 70
+            (3, 2, 90),  # Processed rows 80, 90, then skipped 100, 110
+        ]
+        expected_df = self.spark.createDataFrame(
+            expected_data, "partition_key int, rows_processed int, last_value 
int"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_partition_by_null_values(self):
+        @arrow_udtf(returnType="partition_key int, count int, non_null_sum 
int")
+        class NullPartitionUDTF:
+            def __init__(self):
+                self._partition_key = None
+                self._count = 0
+                self._non_null_sum = 0
+
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow.compute as pc
+
+                table = pa.table(table_data)
+                # Handle null partition keys
+                partition_keys = table["partition_key"]
+                unique_keys = pc.unique(partition_keys).to_pylist()
+
+                # Should have exactly one unique value (either a value or None)
+                assert len(unique_keys) == 1, f"Expected one partition key, got {unique_keys}"
+                self._partition_key = unique_keys[0]
+
+                # Count rows and sum non-null values
+                self._count += table.num_rows
+                values = table["value"]
+                # Use PyArrow compute to handle nulls properly
+                non_null_values = pc.drop_null(values)
+                if len(non_null_values) > 0:
+                    self._non_null_sum += pc.sum(non_null_values).as_py()
+
+                return iter(())
+
+            def terminate(self) -> Iterator["pa.Table"]:
+                # Return results even for null partition keys
+                result_table = pa.table(
+                    {
+                        "partition_key": pa.array([self._partition_key], 
type=pa.int32()),
+                        "count": pa.array([self._count], type=pa.int32()),
+                        "non_null_sum": pa.array([self._non_null_sum], 
type=pa.int32()),
+                    }
+                )
+                yield result_table
+
+        # Test data with null partition keys and null values
+        test_data = [
+            (1, 10),
+            (1, None),  # null value in partition 1
+            (None, 20),  # null partition key
+            (None, 30),  # null partition key
+            (2, 40),
+            (2, None),  # null value in partition 2
+            (None, None),  # both null
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("null_partition_udtf", NullPartitionUDTF)
+        input_df.createOrReplaceTempView("test_null_partitions")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM null_partition_udtf(
+                TABLE(test_null_partitions)
+                PARTITION BY partition_key
+                ORDER BY value
+            )
+            ORDER BY partition_key NULLS FIRST
+            """
+        )
+
+        # Expected: null partition gets grouped together, nulls in values are handled
+        expected_data = [
+            (None, 3, 50),  # null partition: 3 rows, sum of non-null values = 20+30 = 50
+            (1, 2, 10),  # partition 1: 2 rows, sum of non-null values = 10
+            (2, 2, 40),  # partition 2: 2 rows, sum of non-null values = 40
+        ]
+        expected_df = self.spark.createDataFrame(
+            expected_data, "partition_key int, count int, non_null_sum int"
+        )
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_with_empty_table(self):
+        @arrow_udtf(returnType="result string")
+        class EmptyTableUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                import pyarrow as pa
+
+                # This should not be called for empty tables

Review Comment:
   How about raising an error here, so the test verifies that `eval` is indeed never called for an empty table?
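   For example, something along these lines (just a rough sketch; the decorator and the `pa`/`Iterator` imports are assumed to be the same ones used elsewhere in this test module):
   ```python
   @arrow_udtf(returnType="result string")
   class EmptyTableUDTF:
       def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
           # Hypothetical: fail loudly instead of silently yielding nothing,
           # so the test catches any unexpected call for an empty input table.
           raise AssertionError(
               f"eval should not be called for an empty table, got {table_data.num_rows} rows"
           )
   ```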



##########
python/pyspark/worker.py:
##########
@@ -1514,10 +1514,288 @@ def _remove_partition_by_exprs(self, arg: Any) -> Any:
             else:
                 return arg
 
+    class ArrowUDTFWithPartition:
+        """
+        Implements logic for an Arrow UDTF (SQL_ARROW_UDTF) that accepts a TABLE argument
+        with one or more PARTITION BY expressions.
+
+        Arrow UDTFs receive data as PyArrow RecordBatch objects instead of individual Row
+        objects.
+
+        Example table:
+            CREATE TABLE t (c1 INT, c2 INT) USING delta;
+
+        Example queries:
+            SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2);
+            partition_child_indexes: 0, 1.
+
+            SELECT * FROM my_udtf(TABLE (t) PARTITION BY c1, c2 + 4);
+            partition_child_indexes: 0, 2 (adds a projection for "c2 + 4").
+        """
+
+        def __init__(self, create_udtf: Callable, partition_child_indexes: list):
+            """
+            Create a new instance that wraps the provided Arrow UDTF with partitioning
+            logic.
+
+            Parameters
+            ----------
+            create_udtf: function
+                Function that creates a new instance of the Arrow UDTF to invoke.
+            partition_child_indexes: list
+                Zero-based indexes of input-table columns that contain projected
+                partitioning expressions.
+            """
+            self._create_udtf: Callable = create_udtf
+            self._udtf = create_udtf()
+            self._partition_child_indexes: list = partition_child_indexes
+            # Track last partition key from previous batch
+            self._last_partition_key: Optional[Tuple[Any, ...]] = None
+            self._eval_raised_skip_rest_of_input_table: bool = False
+
+        def eval(self, *args, **kwargs) -> Iterator:
+            """Handle partitioning logic for Arrow UDTFs that receive 
RecordBatch objects."""
+            import pyarrow as pa
+
+            # Get the original batch with partition columns
+            original_batch = self._get_table_arg(list(args) + list(kwargs.values()))
+            if not isinstance(original_batch, pa.RecordBatch):
+                # Arrow UDTFs with PARTITION BY must have a TABLE argument that
+                # results in a PyArrow RecordBatch
+                raise PySparkRuntimeError(
+                    errorClass="INVALID_ARROW_UDTF_TABLE_ARGUMENT",
+                    messageParameters={
+                        "actual_type": str(type(original_batch))
+                        if original_batch is not None
+                        else "None"
+                    },
+                )
+
+            # Remove partition columns to get the filtered arguments
+            filtered_args = [self._remove_partition_by_exprs(arg) for arg in args]
+            filtered_kwargs = {
+                key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items()
+            }
+
+            # Get the filtered RecordBatch (without partition columns)
+            filtered_batch = self._get_table_arg(filtered_args + list(filtered_kwargs.values()))
+
+            # Process the RecordBatch by partitions
+            yield from self._process_arrow_batch_by_partitions(
+                original_batch, filtered_batch, filtered_args, filtered_kwargs
+            )
+
+        def _process_arrow_batch_by_partitions(
+            self, original_batch, filtered_batch, filtered_args, filtered_kwargs
+        ) -> Iterator:
+            """Process an Arrow RecordBatch by splitting it into partitions.
+
+            Since Catalyst guarantees that rows with the same partition key are contiguous,
+            we can use efficient boundary detection instead of group_by.
+
+            Handles two scenarios:
+            1. Multiple partitions within a single RecordBatch (using boundary detection)
+            2. Same partition key continuing from previous RecordBatch (tracking state)
+            """
+            import pyarrow as pa
+
+            if self._partition_child_indexes:
+                # Detect partition boundaries.
+                boundaries = self._detect_partition_boundaries(original_batch)
+
+                # Process each contiguous partition
+                for i in range(len(boundaries) - 1):
+                    start_idx = boundaries[i]
+                    end_idx = boundaries[i + 1]
+
+                    # Get the partition key for this segment
+                    partition_key = tuple(
+                        original_batch.column(idx)[start_idx].as_py()
+                        for idx in self._partition_child_indexes
+                    )
+
+                    # Check if this is a continuation of the previous batch's partition
+                    is_new_partition = (
+                        self._last_partition_key is not None
+                        and partition_key != self._last_partition_key
+                    )
+
+                    if is_new_partition:
+                        # Previous partition ended, call terminate
+                        if hasattr(self._udtf, "terminate"):
+                            terminate_result = self._udtf.terminate()
+                            if terminate_result is not None:
+                                for table in terminate_result:
+                                    yield table
+                        # Create new UDTF instance for new partition
+                        self._udtf = self._create_udtf()
+                        self._eval_raised_skip_rest_of_input_table = False
+
+                    # Slice the filtered batch for this partition
+                    partition_batch = filtered_batch.slice(start_idx, end_idx - start_idx)
+
+                    # Update the last partition key
+                    self._last_partition_key = partition_key
+
+                    # Update filtered args to use the partition batch
+                    partition_filtered_args = []
+                    for arg in filtered_args:
+                        if isinstance(arg, pa.RecordBatch):
+                            partition_filtered_args.append(partition_batch)
+                        else:
+                            partition_filtered_args.append(arg)
+
+                    partition_filtered_kwargs = {}
+                    for key, value in filtered_kwargs.items():
+                        if isinstance(value, pa.RecordBatch):
+                            partition_filtered_kwargs[key] = partition_batch
+                        else:
+                            partition_filtered_kwargs[key] = value
+
+                    # Call the UDTF with this partition's data
+                    if not self._eval_raised_skip_rest_of_input_table:
+                        try:
+                            result = self._udtf.eval(
+                                *partition_filtered_args, **partition_filtered_kwargs
+                            )
+                            if result is not None:
+                                for table in result:
+                                    yield table
+                        except SkipRestOfInputTableException:
+                            # Skip remaining rows in this partition
+                            self._eval_raised_skip_rest_of_input_table = True
+
+                    # Don't terminate here - let the next batch or final terminate handle it
+            else:
+                # No partitions, process the entire batch as one group
+                try:
+                    result = self._udtf.eval(*filtered_args, **filtered_kwargs)
+                    if result is not None:
+                        # result is an iterator of PyArrow Tables (for Arrow UDTFs)
+                        for table in result:
+                            yield table

Review Comment:
   ditto.
   ```suggestion
                           yield from result
   ```



##########
python/pyspark/sql/tests/arrow/test_arrow_udtf.py:
##########
@@ -730,6 +730,496 @@ def eval(self, x: "pa.Array", y: "pa.Array") -> Iterator["pa.Table"]:
         expected_df2 = self.spark.createDataFrame([(7, 3, 10)], "x int, y int, sum int")
         assertDataFrameEqual(sql_result_df2, expected_df2)
 
+    def test_arrow_udtf_with_partition_by(self):
+        @arrow_udtf(returnType="partition_key int, sum_value int")
+        class SumUDTF:
+            def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
+                table = pa.table(table_data)
+                partition_key = pc.unique(table["partition_key"]).to_pylist()
+                assert (
+                    len(partition_key) == 1
+                ), f"Expected exactly one partition key, got {partition_key}"
+                sum_value = pc.sum(table["value"]).as_py()
+                result_table = pa.table(
+                    {
+                        "partition_key": pa.array([partition_key[0]], 
type=pa.int32()),
+                        "sum_value": pa.array([sum_value], type=pa.int32()),
+                    }
+                )
+                yield result_table
+
+        test_data = [
+            (1, 10),
+            (2, 5),
+            (1, 20),
+            (2, 15),
+            (1, 30),
+            (3, 100),
+        ]
+        input_df = self.spark.createDataFrame(test_data, "partition_key int, value int")
+
+        self.spark.udtf.register("sum_udtf", SumUDTF)
+        input_df.createOrReplaceTempView("test_data")
+
+        result_df = self.spark.sql(
+            """
+            SELECT * FROM sum_udtf(TABLE(test_data) PARTITION BY partition_key)
+        """
+        )
+
+        expected_data = [
+            (1, 60),
+            (2, 20),
+            (3, 100),
+        ]

Review Comment:
   This test may be flaky.
   IIUC, `eval` can be called multiple times per partition, so this implementation can yield multiple tables per partition.
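   One way to make the expectation robust to that (only a sketch, reusing the names and imports from the test above) is to accumulate in `eval` and emit a single row per partition from `terminate`:
   ```python
   @arrow_udtf(returnType="partition_key int, sum_value int")
   class SumUDTF:
       def __init__(self):
           self._partition_key = None
           self._sum = 0

       def eval(self, table_data: "pa.RecordBatch") -> Iterator["pa.Table"]:
           table = pa.table(table_data)
           # eval may run several times for one partition; just accumulate here.
           self._partition_key = table["partition_key"][0].as_py()
           self._sum += pc.sum(table["value"]).as_py()
           return iter(())

       def terminate(self) -> Iterator["pa.Table"]:
           # Emit exactly one row per partition, regardless of how many
           # RecordBatches eval received.
           yield pa.table(
               {
                   "partition_key": pa.array([self._partition_key], type=pa.int32()),
                   "sum_value": pa.array([self._sum], type=pa.int32()),
               }
           )
   ```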



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

