alex-balikov commented on code in PR #37893: URL: https://github.com/apache/spark/pull/37893#discussion_r976868040
########## python/pyspark/sql/pandas/group_ops.py: ########## @@ -216,6 +218,125 @@ def applyInPandas( jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.session) + def applyInPandasWithState( + self, + func: "PandasGroupedMapFunctionWithState", + outputStructType: Union[StructType, str], + stateStructType: Union[StructType, str], + outputMode: str, + timeoutConf: str, + ) -> DataFrame: + """ + Applies the given function to each group of data, while maintaining a user-defined + per-group state. The result Dataset will represent the flattened record returned by the + function. + + For a streaming Dataset, the function will be invoked first for all input groups and then + for all timed out states where the input data is set to be empty. Updates to each group's + state will be saved across invocations. + + The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and + returns another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple Review Comment: nit: "returns another ..." should be "return another ..." so that it agrees with "should take". ########## python/pyspark/worker.py: ########## @@ -207,6 +209,89 @@ def wrapped(key_series, value_series): return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] +def wrap_grouped_map_pandas_udf_with_state(f, return_type): + """ + Provides a new lambda instance wrapping user function of applyInPandasWithState. + + The lambda instance receives (key series, iterator of value series, state) and performs + some conversion to be adapted with the signature of user function. + + See the function doc of inner function `wrapped` for more details on what adapter does. + See the function doc of `mapper` function for + `eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE` for more details on + the input parameters of lambda function. + + Along with the returned iterator, the lambda instance will also produce the return_type as + converted to the arrow schema.
+ """ + + def wrapped(key_series, value_series_gen, state): + """ + Provide an adapter of the user function performing below: + + - Extract the first value of all columns in key series and produce as a tuple. + - If the state has timed out, call the user function with empty pandas DataFrame. + - If not, construct a new generator which converts each element of value series to + pandas DataFrame (lazy evaluation), and call the user function with the generator + - Verify each element of returned iterator to check the schema of pandas DataFrame. + """ + import pandas as pd + + key = tuple(s[0] for s in key_series) + + if state.hasTimedOut: + # Timeout processing pass empty iterator. Here we return an empty DataFrame instead. + values = [ + pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns), + ] + else: + values = (pd.concat(x, axis=1) for x in value_series_gen) + + result_iter = f(key, values, state) + + def verify_element(result): + if not isinstance(result, pd.DataFrame): + raise TypeError( + "The type of element in return iterator of the user-defined function " + "should be pandas.DataFrame, but is {}".format(type(result)) + ) + # the number of columns of result have to match the return type + # but it is fine for result to have no columns at all if it is empty + if not ( + len(result.columns) == len(return_type) or len(result.columns) == 0 and result.empty Review Comment: Maybe it is just me, but I would suggest adding parentheses so we do not rely on and/or precedence, e.g. `len(result.columns) == len(return_type) or (len(result.columns) == 0 and result.empty)`. ########## python/pyspark/sql/pandas/serializers.py: ########## @@ -371,3 +373,354 @@ def load_stream(self, stream): raise ValueError( "Invalid number of pandas.DataFrames in group {0}".format(dataframes_in_group) ) + + +class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): + """ + Serializer used by Python worker to evaluate UDF for applyInPandasWithState.
+ + Parameters + ---------- + timezone : str + A timezone to respect when handling timestamp values + safecheck : bool + If True, conversion from Arrow to Pandas checks for overflow/truncation + assign_cols_by_name : bool + If True, then Pandas DataFrames will get columns by name + state_object_schema : StructType + The type of state object represented as Spark SQL type + arrow_max_records_per_batch : int + Limit of the number of records that can be written to a single ArrowRecordBatch in memory. + """ + + def __init__( + self, + timezone, + safecheck, + assign_cols_by_name, + state_object_schema, + arrow_max_records_per_batch, + ): + super(ApplyInPandasWithStateSerializer, self).__init__( + timezone, safecheck, assign_cols_by_name + ) + self.pickleSer = CPickleSerializer() + self.utf8_deserializer = UTF8Deserializer() + self.state_object_schema = state_object_schema + + self.result_state_df_type = StructType( + [ + StructField("properties", StringType()), + StructField("keyRowAsUnsafe", BinaryType()), + StructField("object", BinaryType()), + StructField("oldTimeoutTimestamp", LongType()), + ] + ) + + self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type) + self.arrow_max_records_per_batch = arrow_max_records_per_batch + + def load_stream(self, stream): + """ + Read ArrowRecordBatches from stream, deserialize them to populate a list of pair + (data chunk, state), and convert the data into a list of pandas.Series. + + Please refer the doc of inner function `gen_data_and_state` for more details how + this function works in overall. + + In addition, this function further groups the return of `gen_data_and_state` by the state + instance (same semantic as grouping by grouping key) and produces an iterator of data + chunks for each group, so that the caller can lazily materialize the data chunk. 
+ """ + + import pyarrow as pa + import json + from itertools import groupby + from pyspark.sql.streaming.state import GroupState + + def construct_state(state_info_col): + """ + Construct state instance from the value of state information column. + """ + + state_info_col_properties = state_info_col["properties"] + state_info_col_key_row = state_info_col["keyRowAsUnsafe"] + state_info_col_object = state_info_col["object"] + + state_properties = json.loads(state_info_col_properties) + if state_info_col_object: + state_object = self.pickleSer.loads(state_info_col_object) + else: + state_object = None + state_properties["optionalValue"] = state_object + + return GroupState( + keyAsUnsafe=state_info_col_key_row, + valueSchema=self.state_object_schema, + **state_properties, + ) + + def gen_data_and_state(batches): + """ + Deserialize ArrowRecordBatches and return a generator of + `(a list of pandas.Series, state)`. + + The logic on deserialization is following: + + 1. Read the entire data part from Arrow RecordBatch. + 2. Read the entire state information part from Arrow RecordBatch. + 3. Loop through each state information: + 3.A. Extract the data out from entire data via the information of data range. + 3.B. Construct a new state instance if the state information is the first occurrence + for the current grouping key. + 3.C. Leverage existing new state instance if the state instance is already available Review Comment: Leverage the existing state instance if it is already available for the current grouping key... -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org