This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b3d5bc0c109 [SPARK-45362][PYTHON] Project out PARTITION BY expressions before Python UDTF 'eval' method consumes them
b3d5bc0c109 is described below

commit b3d5bc0c10908aa66510844eaabc43b6764dd7c0
Author: Daniel Tenedorio <daniel.tenedo...@databricks.com>
AuthorDate: Thu Sep 28 14:02:46 2023 -0700

    [SPARK-45362][PYTHON] Project out PARTITION BY expressions before Python UDTF 'eval' method consumes them
    
    ### What changes were proposed in this pull request?
    
    This PR projects out PARTITION BY expressions before the Python UDTF 'eval' method consumes them.
    
    Before this PR, if a query included this `PARTITION BY` clause:
    
    ```
    SELECT * FROM udtf((SELECT a, b FROM TABLE t) PARTITION BY (c, d))
    ```
    
    Then the `eval` method received four columns in each row: `a, b, c, d`.
    
    After this PR, the `eval` method receives only two columns: `a, b`, as expected.
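
    For illustration, here is a minimal sketch of the resulting behavior (the UDTF below is hypothetical, not part of this patch):

    ```
    from pyspark.sql.functions import udtf

    # Hypothetical UDTF used only to show which columns reach 'eval'.
    @udtf(returnType="a: int, b: int")
    class Passthrough:
        def eval(self, row):
            # With this change, 'row' holds only the SELECT-list columns
            # (a, b); the PARTITION BY expressions (c, d) are projected out.
            assert set(row.asDict().keys()) == {"a", "b"}
            yield row["a"], row["b"]
    ```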
    
    ### Why are the changes needed?
    
    This makes the columns of the Python UDTF `TABLE` argument consistently match what the `eval` method receives, as expected.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, see above.
    
    ### How was this patch tested?
    
    This PR adds new unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43156 from dtenedor/project-out-partition-exprs.
    
    Authored-by: Daniel Tenedorio <daniel.tenedo...@databricks.com>
    Signed-off-by: Takuya UESHIN <ues...@databricks.com>
---
 python/pyspark/sql/tests/test_udtf.py | 12 ++++++++++++
 python/pyspark/worker.py              | 31 +++++++++++++++++++++++++++----
 2 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 97d5190a506..a1d82056c50 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -2009,6 +2009,10 @@ class BaseUDTFTestsMixin:
                 self._partition_col = None
 
             def eval(self, row: Row):
+                # Make sure that the PARTITION BY expressions were projected out.
+                assert len(row.asDict().items()) == 2
+                assert "partition_col" in row
+                assert "input" in row
                 self._sum += row["input"]
                 if self._partition_col is not None and self._partition_col != row["partition_col"]:
                     # Make sure that all values of the partitioning column are the same
@@ -2092,6 +2096,10 @@ class BaseUDTFTestsMixin:
                 self._partition_col = None
 
             def eval(self, row: Row, partition_col: str):
+                # Make sure that the PARTITION BY and ORDER BY expressions were projected out.
+                assert len(row.asDict().items()) == 2
+                assert "partition_col" in row
+                assert "input" in row
                 # Make sure that all values of the partitioning column are the same
                 # for each row consumed by this method for this instance of the class.
                 if self._partition_col is not None and self._partition_col != row[partition_col]:
@@ -2247,6 +2255,10 @@ class BaseUDTFTestsMixin:
                 )
 
             def eval(self, row: Row):
+                # Make sure that the PARTITION BY and ORDER BY expressions were projected out.
+                assert len(row.asDict().items()) == 2
+                assert "partition_col" in row
+                assert "input" in row
                 # Make sure that all values of the partitioning column are the same
                 # for each row consumed by this method for this instance of the class.
                 if self._partition_col is not None and self._partition_col != row["partition_col"]:
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 77481704979..4cffb02a64a 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -51,7 +51,14 @@ from pyspark.sql.pandas.serializers import (
     ApplyInPandasWithStateSerializer,
 )
 from pyspark.sql.pandas.types import to_arrow_type
-from pyspark.sql.types import BinaryType, Row, StringType, StructType, _parse_datatype_json_string
+from pyspark.sql.types import (
+    BinaryType,
+    Row,
+    StringType,
+    StructType,
+    _create_row,
+    _parse_datatype_json_string,
+)
 from pyspark.util import fail_on_stopiteration, handle_worker_exception
 from pyspark import shuffle
 from pyspark.errors import PySparkRuntimeError, PySparkTypeError
@@ -735,7 +742,12 @@ def read_udtf(pickleSer, infile, eval_type):
                             yield row
                 self._udtf = self._create_udtf()
             if self._udtf.eval is not None:
-                result = self._udtf.eval(*args, **kwargs)
+                # Filter the arguments to exclude projected PARTITION BY values added by Catalyst.
+                filtered_args = [self._remove_partition_by_exprs(arg) for arg in args]
+                filtered_kwargs = {
+                    key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items()
+                }
+                result = self._udtf.eval(*filtered_args, **filtered_kwargs)
                 if result is not None:
                     for row in result:
                         yield row
@@ -752,10 +764,9 @@ def read_udtf(pickleSer, infile, eval_type):
                 prev_table_arg = self._get_table_arg(self._prev_arguments)
                 cur_partitions_args = []
                 prev_partitions_args = []
-                for i in partition_child_indexes:
+                for i in self._partition_child_indexes:
                     cur_partitions_args.append(cur_table_arg[i])
                     prev_partitions_args.append(prev_table_arg[i])
-                self._prev_arguments = arguments
+                result = any(k != v for k, v in zip(cur_partitions_args, prev_partitions_args))
             self._prev_arguments = arguments
             return result
@@ -763,6 +774,18 @@ def read_udtf(pickleSer, infile, eval_type):
         def _get_table_arg(self, inputs: list) -> Row:
             return [x for x in inputs if type(x) is Row][0]
 
+        def _remove_partition_by_exprs(self, arg: Any) -> Any:
+            if isinstance(arg, Row):
+                new_row_keys = []
+                new_row_values = []
+                for i, (key, value) in enumerate(zip(arg.__fields__, arg)):
+                    if i not in self._partition_child_indexes:
+                        new_row_keys.append(key)
+                        new_row_values.append(value)
+                return _create_row(new_row_keys, new_row_values)
+            else:
+                return arg
+
     # Instantiate the UDTF class.
     try:
         if len(partition_child_indexes) > 0:
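
For reference, here is a self-contained sketch of the new row-filtering step in `_remove_partition_by_exprs` (the field names below are hypothetical; `_create_row` is the internal helper that the patch imports):

```
from pyspark.sql.types import Row, _create_row

def remove_partition_by_exprs(row: Row, partition_child_indexes: list) -> Row:
    # Rebuild the Row, dropping fields whose positional index is one of the
    # projected PARTITION BY child indexes.
    keys, values = [], []
    for i, (key, value) in enumerate(zip(row.__fields__, row)):
        if i not in partition_child_indexes:
            keys.append(key)
            values.append(value)
    return _create_row(keys, values)

row = Row(partition_col=7, input=42, partition_expr=99)
print(remove_partition_by_exprs(row, [2]))  # Row(partition_col=7, input=42)
```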


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
