gaogaotiantian commented on code in PR #55222:
URL: https://github.com/apache/spark/pull/55222#discussion_r3061546329


##########
python/pyspark/worker.py:
##########
@@ -2575,6 +2510,35 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
         for i in range(num_udfs)
     ]
 
+    def extract_key_value_indexes(grouped_arg_offsets):
+        """
+        Helper function to extract the key and value indexes from arg_offsets for the grouped and
+        cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code.
+
+        Parameters
+        ----------
+        grouped_arg_offsets:  list
+            List containing the key and value indexes of columns of the
+            DataFrames to be passed to the udf. It consists of n repeating groups where n is the
+            number of DataFrames.  Each group has the following format:
+                group[0]: length of group
+                group[1]: length of key indexes
+                group[2.. group[1] +2]: key attributes

Review Comment:
   We can make this real Python, right? `group[2: group[1] + 2]`? Also, let's document the return type too.
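A sketch of how the docstring could read with real Python slices and a `Returns` section. The `(key indexes, value indexes)` tuple shape is an assumption, since the hunk is cut off before the function's `return`. Going by the loop in the diff, `group[0]` counts the entries that follow it, so the value slice ends at `group[0] + 1`; the body below slices each group exactly as the docstring states.

```python
from typing import List, Tuple


def extract_key_value_indexes(
    grouped_arg_offsets: List[int],
) -> List[Tuple[List[int], List[int]]]:
    """
    Helper function to extract the key and value indexes from arg_offsets for
    the grouped and cogrouped pandas UDFs.

    Parameters
    ----------
    grouped_arg_offsets : list of int
        n consecutive groups, where n is the number of DataFrames. Each group
        has the following format:
            group[0]: number of entries that follow group[0]
            group[1]: number of key indexes
            group[2 : group[1] + 2]: key indexes
            group[group[1] + 2 : group[0] + 1]: value indexes

    Returns
    -------
    list of (list of int, list of int)
        One (key indexes, value indexes) pair per DataFrame (assumed shape).
    """
    parsed = []
    idx = 0
    while idx < len(grouped_arg_offsets):
        # Take one whole group, including its length prefix at group[0].
        group = grouped_arg_offsets[idx : idx + grouped_arg_offsets[idx] + 1]
        parsed.append((group[2 : group[1] + 2], group[group[1] + 2 : group[0] + 1]))
        # Move past the length prefix and the entries it counts.
        idx += group[0] + 1
    return parsed
```

For example, `[4, 1, 0, 1, 2]` (one DataFrame, key index 0, value indexes 1 and 2) parses to `[([0], [1, 2])]`.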



##########
python/pyspark/worker.py:
##########
@@ -2575,6 +2510,35 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
         for i in range(num_udfs)
     ]
 
+    def extract_key_value_indexes(grouped_arg_offsets):
+        """
+        Helper function to extract the key and value indexes from arg_offsets for the grouped and
+        cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code.
+
+        Parameters
+        ----------
+        grouped_arg_offsets:  list
+            List containing the key and value indexes of columns of the
+            DataFrames to be passed to the udf. It consists of n repeating groups where n is the
+            number of DataFrames.  Each group has the following format:
+                group[0]: length of group
+                group[1]: length of key indexes
+                group[2.. group[1] +2]: key attributes
+                group[group[1] +3 group[0]]: value attributes
+        """
+        parsed = []
+        idx = 0
+        while idx < len(grouped_arg_offsets):
+            offsets_len = grouped_arg_offsets[idx]
+            idx += 1
+            offsets = grouped_arg_offsets[idx : idx + offsets_len]
+            split_index = offsets[0] + 1
+            offset_keys = offsets[1:split_index]
+            offset_values = offsets[split_index:]

Review Comment:
   I don't think temporary variables like `split_index` make the code easier to read. Since the docstring above already describes the layout, we could just slice the same way it does, right? That keeps the code and the comment consistent.
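One way to confirm that dropping `split_index` preserves behavior is a small equivalence check. In this sketch, `split_with_temp` mirrors the loop body in the diff (it receives `offsets`, i.e. the group without its length prefix), and `split_like_docstring` slices the full group the way the docstring describes; both helper names are hypothetical. The check enumerates small key and value counts and asserts that both splits agree.

```python
import itertools


def split_with_temp(offsets):
    # Loop body as written in the diff: offsets excludes the length field.
    split_index = offsets[0] + 1
    return offsets[1:split_index], offsets[split_index:]


def split_like_docstring(group):
    # Same split on the group with its length prefix, no temporary variable.
    return group[2 : group[1] + 2], group[group[1] + 2 : group[0] + 1]


def check_equivalence(max_keys=3, max_values=3):
    # Build every group shape up to the given sizes and compare both splits.
    for n_keys, n_values in itertools.product(range(max_keys + 1), range(max_values + 1)):
        keys = list(range(n_keys))
        values = list(range(100, 100 + n_values))
        offsets = [n_keys] + keys + values  # what the loop in the diff slices out
        group = [len(offsets)] + offsets  # the same data with its length prefix
        assert split_with_temp(offsets) == split_like_docstring(group) == (keys, values)
    return True
```

The exhaustive loop is cheap here because groups are tiny; it covers the edge case of zero keys or zero values as well.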



##########
python/pyspark/worker.py:
##########
@@ -2575,6 +2510,35 @@ def read_udfs(pickleSer, infile, eval_type, runner_conf, eval_conf):
         for i in range(num_udfs)
     ]
 
+    def extract_key_value_indexes(grouped_arg_offsets):
+        """
+        Helper function to extract the key and value indexes from arg_offsets for the grouped and
+        cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code.
+
+        Parameters
+        ----------
+        grouped_arg_offsets:  list
+            List containing the key and value indexes of columns of the
+            DataFrames to be passed to the udf. It consists of n repeating groups where n is the
+            number of DataFrames.  Each group has the following format:
+                group[0]: length of group
+                group[1]: length of key indexes
+                group[2.. group[1] +2]: key attributes

Review Comment:
   I just realized this is copied directly. You can either keep it as is or make some changes while moving it. I'm fine either way.



##########
python/pyspark/worker.py:
##########
@@ -2833,6 +2797,144 @@ def grouped_func(
         # profiling is not supported for UDF
         return grouped_func, None, ser, ser
 
+    if eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF:
+        import pyarrow as pa
+
+        assert num_udfs == 1, "One GROUPED_MAP_ARROW UDF expected here."
+        grouped_udf, arg_offsets, return_type, num_udf_args = udfs[0]
+        parsed_offsets = extract_key_value_indexes(arg_offsets)
+
+        arrow_return_type = to_arrow_type(
+            return_type, timezone="UTC", prefers_large_types=runner_conf.use_large_var_types
+        )
+        if runner_conf.assign_cols_by_name:
+            expected_cols_and_types = {
+                col.name: to_arrow_type(col.dataType, timezone="UTC") for col in return_type.fields
+            }
+        else:
+            expected_cols_and_types = [
+                (col.name, to_arrow_type(col.dataType, timezone="UTC"))
+                for col in return_type.fields
+            ]
+

Review Comment:
   Let's also add an assertion about the length of `parsed_offsets` (`0` probably?)
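The comment leaves the exact expected value open. Going by the docstring, which describes one group per DataFrame, a grouped map UDF over a single DataFrame would yield exactly one parsed entry, so this sketch asserts a length of 1; treat that number as an assumption to confirm against the Scala side. It follows the style of the existing `assert num_udfs == 1` check, and `expect_single_group` is a hypothetical helper name.

```python
from typing import List, Tuple

OffsetPair = Tuple[List[int], List[int]]


def expect_single_group(parsed_offsets: List[OffsetPair]) -> OffsetPair:
    # Assumption: a single-DataFrame grouped map UDF has n == 1 groups in
    # grouped_arg_offsets, hence exactly one parsed (keys, values) entry.
    assert len(parsed_offsets) == 1, (
        "Expected key/value offsets for exactly one DataFrame, "
        f"got {len(parsed_offsets)}."
    )
    return parsed_offsets[0]


# Usage: unpack the single group's key and value indexes.
keys, values = expect_single_group([([0], [1, 2])])
```

As with the existing check, a plain `assert` is skipped under `python -O`; an explicit `raise` would be needed if the check must always run.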



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
