alex-balikov commented on code in PR #37893:
URL: https://github.com/apache/spark/pull/37893#discussion_r976868040


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -216,6 +218,125 @@ def applyInPandas(
         jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
         return DataFrame(jdf, self.session)
 
+    def applyInPandasWithState(
+        self,
+        func: "PandasGroupedMapFunctionWithState",
+        outputStructType: Union[StructType, str],
+        stateStructType: Union[StructType, str],
+        outputMode: str,
+        timeoutConf: str,
+    ) -> DataFrame:
+        """
+        Applies the given function to each group of data, while maintaining a user-defined
+        per-group state. The result Dataset will represent the flattened record returned by the
+        function.
+
+        For a streaming Dataset, the function will be invoked first for all input groups and then
+        for all timed out states where the input data is set to be empty. Updates to each group's
+        state will be saved across invocations.
+
+        The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and
+        returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple

Review Comment:
   return another ...
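
Illustration (not part of this PR): a rough sketch of how a user function for applyInPandasWithState might look, assuming the (key, Iterator[pandas.DataFrame], state) signature described in the docstring above. The function name `count_events`, the column names, and the schema/option strings in the commented-out call site are made up for this example.

```python
import pandas as pd
from typing import Iterator, Tuple

from pyspark.sql.streaming.state import GroupState


def count_events(
    key: Tuple[str], pdfs: Iterator[pd.DataFrame], state: GroupState
) -> Iterator[pd.DataFrame]:
    if state.hasTimedOut:
        # Timed-out invocation: the input is empty, so emit the final count and drop the state.
        (count,) = state.get
        state.remove()
        yield pd.DataFrame({"id": [key[0]], "count": [count]})
    else:
        count = state.get[0] if state.exists else 0
        for pdf in pdfs:
            count += len(pdf)
        state.update((count,))
        state.setTimeoutDuration(60 * 1000)  # expire the state after a minute of inactivity
        yield pd.DataFrame({"id": [key[0]], "count": [count]})


# Hypothetical call site (argument values are illustrative only):
# df.groupBy("id").applyInPandasWithState(
#     count_events,
#     outputStructType="id string, count long",
#     stateStructType="count long",
#     outputMode="update",
#     timeoutConf="ProcessingTimeTimeout",
# )
```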



##########
python/pyspark/worker.py:
##########
@@ -207,6 +209,89 @@ def wrapped(key_series, value_series):
     return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))]
 
 
+def wrap_grouped_map_pandas_udf_with_state(f, return_type):
+    """
+    Provides a new lambda instance wrapping user function of applyInPandasWithState.
+
+    The lambda instance receives (key series, iterator of value series, state) and performs
+    some conversion to be adapted with the signature of user function.
+
+    See the function doc of inner function `wrapped` for more details on what adapter does.
+    See the function doc of `mapper` function for
+    `eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE` for more details on
+    the input parameters of lambda function.
+
+    Along with the returned iterator, the lambda instance will also produce the return_type as
+    converted to the arrow schema.
+    """
+
+    def wrapped(key_series, value_series_gen, state):
+        """
+        Provide an adapter of the user function performing below:
+
+        - Extract the first value of all columns in key series and produce as a tuple.
+        - If the state has timed out, call the user function with empty pandas DataFrame.
+        - If not, construct a new generator which converts each element of value series to
+          pandas DataFrame (lazy evaluation), and call the user function with the generator
+        - Verify each element of returned iterator to check the schema of pandas DataFrame.
+        """
+        import pandas as pd
+
+        key = tuple(s[0] for s in key_series)
+
+        if state.hasTimedOut:
+            # Timeout processing pass empty iterator. Here we return an empty DataFrame instead.
+            values = [
+                pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns),
+            ]
+        else:
+            values = (pd.concat(x, axis=1) for x in value_series_gen)
+
+        result_iter = f(key, values, state)
+
+        def verify_element(result):
+            if not isinstance(result, pd.DataFrame):
+                raise TypeError(
+                    "The type of element in return iterator of the user-defined function "
+                    "should be pandas.DataFrame, but is {}".format(type(result))
+                )
+            # the number of columns of result have to match the return type
+            # but it is fine for result to have no columns at all if it is empty
+            if not (
+                len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty

Review Comment:
   Maybe it is just me, but I would suggest adding parentheses so we do not rely on and/or precedence.
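
For illustration, the check with the suggested explicit grouping could read as below; the helper name `_has_matching_columns` is made up here and not part of the PR, and the behavior is unchanged since `and` already binds tighter than `or` in Python.

```python
import pandas as pd
from pyspark.sql.types import StructType


def _has_matching_columns(result: pd.DataFrame, return_type: StructType) -> bool:
    # The number of columns must match the return type, but an empty result
    # with no columns at all is also accepted.
    return (
        len(result.columns) == len(return_type)
        or (len(result.columns) == 0 and result.empty)
    )
```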



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -371,3 +373,354 @@ def load_stream(self, stream):
                 raise ValueError(
                     "Invalid number of pandas.DataFrames in group 
{0}".format(dataframes_in_group)
                 )
+
+
+class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer):
+    """
+    Serializer used by Python worker to evaluate UDF for applyInPandasWithState.
+
+    Parameters
+    ----------
+    timezone : str
+        A timezone to respect when handling timestamp values
+    safecheck : bool
+        If True, conversion from Arrow to Pandas checks for overflow/truncation
+    assign_cols_by_name : bool
+        If True, then Pandas DataFrames will get columns by name
+    state_object_schema : StructType
+        The type of state object represented as Spark SQL type
+    arrow_max_records_per_batch : int
+        Limit of the number of records that can be written to a single ArrowRecordBatch in memory.
+    """
+
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        assign_cols_by_name,
+        state_object_schema,
+        arrow_max_records_per_batch,
+    ):
+        super(ApplyInPandasWithStateSerializer, self).__init__(
+            timezone, safecheck, assign_cols_by_name
+        )
+        self.pickleSer = CPickleSerializer()
+        self.utf8_deserializer = UTF8Deserializer()
+        self.state_object_schema = state_object_schema
+
+        self.result_state_df_type = StructType(
+            [
+                StructField("properties", StringType()),
+                StructField("keyRowAsUnsafe", BinaryType()),
+                StructField("object", BinaryType()),
+                StructField("oldTimeoutTimestamp", LongType()),
+            ]
+        )
+
+        self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type)
+        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+
+    def load_stream(self, stream):
+        """
+        Read ArrowRecordBatches from stream, deserialize them to populate a list of pair
+        (data chunk, state), and convert the data into a list of pandas.Series.
+
+        Please refer the doc of inner function `gen_data_and_state` for more details how
+        this function works in overall.
+
+        In addition, this function further groups the return of `gen_data_and_state` by the state
+        instance (same semantic as grouping by grouping key) and produces an iterator of data
+        chunks for each group, so that the caller can lazily materialize the data chunk.
+        """
+
+        import pyarrow as pa
+        import json
+        from itertools import groupby
+        from pyspark.sql.streaming.state import GroupState
+
+        def construct_state(state_info_col):
+            """
+            Construct state instance from the value of state information column.
+            """
+
+            state_info_col_properties = state_info_col["properties"]
+            state_info_col_key_row = state_info_col["keyRowAsUnsafe"]
+            state_info_col_object = state_info_col["object"]
+
+            state_properties = json.loads(state_info_col_properties)
+            if state_info_col_object:
+                state_object = self.pickleSer.loads(state_info_col_object)
+            else:
+                state_object = None
+            state_properties["optionalValue"] = state_object
+
+            return GroupState(
+                keyAsUnsafe=state_info_col_key_row,
+                valueSchema=self.state_object_schema,
+                **state_properties,
+            )
+
+        def gen_data_and_state(batches):
+            """
+            Deserialize ArrowRecordBatches and return a generator of
+            `(a list of pandas.Series, state)`.
+
+            The logic on deserialization is following:
+
+            1. Read the entire data part from Arrow RecordBatch.
+            2. Read the entire state information part from Arrow RecordBatch.
+            3. Loop through each state information:
+               3.A. Extract the data out from entire data via the information of data range.
+               3.B. Construct a new state instance if the state information is the first occurrence
+                    for the current grouping key.
+               3.C. Leverage existing new state instance if the state instance is already available

Review Comment:
   Leverage the existing state instance if it is already available for the current grouping key...
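
As a rough standalone illustration (not the PR's code) of the grouping described in the `load_stream` docstring above, `itertools.groupby` can fold consecutive (data chunk, state) pairs that share the same state into one group whose data chunks are still consumed lazily; the placeholder strings below are made up for this example.

```python
from itertools import groupby

# Stand-ins for deserialized (data chunk, state) pairs; consecutive pairs with
# the same state belong to the same grouping key.
pairs = [("chunk-1", "state-A"), ("chunk-2", "state-A"), ("chunk-3", "state-B")]

for state, group in groupby(pairs, key=lambda pair: pair[1]):
    data_chunks = (data for data, _ in group)  # lazily yields this group's chunks
    print(state, list(data_chunks))

# state-A ['chunk-1', 'chunk-2']
# state-B ['chunk-3']
```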



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

