This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 603dc509821 [SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark
603dc509821 is described below

commit 603dc5098217d9580f611873165d25392f41cdfe
Author: Jungtaek Lim <kabhwan.opensou...@gmail.com>
AuthorDate: Thu Sep 22 12:35:07 2022 +0900

[SPARK-40434][SS][PYTHON] Implement applyInPandasWithState in PySpark

### What changes were proposed in this pull request?

This PR proposes to introduce the new API `applyInPandasWithState` in PySpark, which provides the functionality to perform arbitrary stateful processing in Structured Streaming.

It pairs with applyInPandas: applyInPandas in PySpark covers the use case of flatMapGroups in the Scala/Java API, and applyInPandasWithState in PySpark covers the use case of flatMapGroupsWithState in the Scala/Java API.

The signature of the API is as follows:

```
# call this function after groupBy
def applyInPandasWithState(
    self,
    func: "PandasGroupedMapFunctionWithState",
    outputStructType: Union[StructType, str],
    stateStructType: Union[StructType, str],
    outputMode: str,
    timeoutConf: str,
) -> DataFrame
```

and the signature of the user function is as follows:

```
def func(
    key: Tuple,
    pdf_iter: Iterator[pandas.DataFrame],
    state: GroupStateImpl
) -> Iterator[pandas.DataFrame]
```

(Please refer to the code diff for the function doc of the new function.)

Major design choices which differ from existing APIs:

1. The new API is untyped, while flatMapGroupsWithState is a typed API.

This follows from the nature of the Python language: it is duck-typed, and type definitions are just hints. We don't have an implementation of a typed API for PySpark DataFrame.

This leads us to design the API to be untyped, meaning that all types for (input, state, output) should be Row-compatible. While we don't require end users to deal with `Row` directly, the model they use for state and output must be convertible to Row with the default encoder.
If they want to use a Python type for state that is not compatible with Row (e.g. a custom class), they need to pickle it and store it as BinaryType.

This requires end users to specify the types of state and output via a Spark SQL schema in the method.

Note that this helps to ensure compatibility of state data across Spark versions, as long as the encoders for 1) Python type -> Python Row and 2) Python Row -> UnsafeRow are not changed. We won't change the underlying data layout of UnsafeRow, as doing so would break all existing stateful queries.

2. The new API passes pandas DataFrames to the user function, while flatMapGroupsWithState passes an iterator of rows.

We decided to follow the user experience applyInPandas provides, for both consistency and performance (Arrow batching, vectorization, etc.). This leads us to design the user function to work with pandas DataFrames rather than an iterator of rows. While this makes the UX inconsistent with the Scala/Java API, we don't think it will be a problem, since pandas is considered the de facto standard for Python data scientists.

3. The new API passes an iterator of pandas DataFrames to the user function and also requires the function to return an iterator of pandas DataFrames, to address scalability.

applyInPandas has a known scalability limitation: it requires all data in a specific group to fit into memory. During the design phase of the new API, we decided to address scalability rather than inherit the limitation.

To address scalability, we tweak the user function to receive an iterator (generator) of pandas DataFrames instead of a single pandas DataFrame, and also to return an iterator (generator) of pandas DataFrames. We think this does not hurt the UX too much, as a for-each loop and yield are enough to deal with the iterator.
From the implementation perspective, we split the data in a specific group into multiple chunks, where each chunk is stored and sent as "an" Arrow RecordBatch and then finally materialized to "a" pandas DataFrame. This way, as long as end users don't materialize many pandas DataFrames from the iterator at the same time, only one chunk is materialized in memory at a time, which is scalable. Similar logic applies to the output of the user function, hence it is scalable as well.

4. The new API also bin-packs the data of multiple groups into "an" Arrow RecordBatch.

Given that the API is mainly used for streaming workloads, it is highly likely that the volume of data in a specific group is not large enough to benefit from Arrow columnar batching, which would hurt performance. To address this, we also do the opposite of what we do for scalability: bin-packing. That is, an Arrow RecordBatch can contain data for multiple groups, as well as a part of the data for a specific group. This addresses both aspects of concerns together, scalabilit [...]

Note that we are not implementing all of the features the Scala/Java API provides in the initial phase. For example, support for batch queries and support for initial state are left as TODO.

### Why are the changes needed?

PySpark users don't have a way to perform arbitrary stateful processing in Structured Streaming and are forced to use either Java or Scala, which is unacceptable for users in many cases. This PR enables PySpark users to do it without moving to the Java/Scala world.

### Does this PR introduce _any_ user-facing change?

Yes. We are exposing a new public API in PySpark which performs arbitrary stateful processing.

### How was this patch tested?

N/A. We will make sure test suites are constructed in an E2E manner under [SPARK-40431](https://issues.apache.org/jira/browse/SPARK-40431) - #37894

Closes #37893 from HeartSaVioR/SPARK-40434-on-top-of-SPARK-40433-SPARK-40432.
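
The chunking and bin-packing described in design choices 3 and 4 can be illustrated with a small pure-Python sketch. This is not the Spark implementation: plain lists stand in for Arrow RecordBatches, `bin_pack`, `unpack`, and `chunks_per_group` are hypothetical helper names, and the per-chunk metadata tuple only mirrors the `startOffset`, `numRows`, and `isLastChunk` properties that the serializer in this diff reads. The sketch also regroups chunks by key, whereas the diff groups by the state instance.

```python
from itertools import groupby
from typing import Dict, Iterator, List, Tuple

# Hypothetical per-chunk metadata: (key, start_offset, num_rows, is_last_chunk).
ChunkInfo = Tuple[str, int, int, bool]
Batch = Tuple[List[int], List[ChunkInfo]]


def bin_pack(groups: Dict[str, List[int]], max_records: int) -> Iterator[Batch]:
    """Pack rows of several groups into batches of at most max_records rows.

    A large group is split across batches (scalability), and small groups
    share a batch (bin-packing).
    """
    rows: List[int] = []
    infos: List[ChunkInfo] = []
    for key, values in groups.items():
        pos = 0
        while True:
            room = max_records - len(rows)
            take = values[pos:pos + room]
            pos += len(take)
            is_last = pos >= len(values)
            # Record where this chunk sits inside the current batch.
            infos.append((key, len(rows), len(take), is_last))
            rows.extend(take)
            if len(rows) >= max_records:
                yield rows, infos
                rows, infos = [], []
            if is_last:
                break
    if infos:
        yield rows, infos


def unpack(batches: Iterator[Batch]) -> Iterator[Tuple[str, List[int]]]:
    """Yield (key, chunk) pairs; chunks of the same key arrive back to back."""
    for rows, infos in batches:
        for key, start, num_rows, _is_last in infos:
            yield key, rows[start:start + num_rows]


def chunks_per_group(batches: Iterator[Batch]) -> Iterator[Tuple[str, List[List[int]]]]:
    """Regroup adjacent chunks by key, as a reader of the batches would."""
    for key, pairs in groupby(unpack(batches), key=lambda kc: kc[0]):
        yield key, [chunk for _, chunk in pairs]
```

Because chunks for the same key are always emitted sequentially, the reader can regroup them with `itertools.groupby` without buffering more than one batch at a time.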
Lead-authored-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
Co-authored-by: Hyukjin Kwon <gurwls...@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/api/python/PythonRunner.scala |   2 +
 python/pyspark/rdd.py                              |   2 +
 python/pyspark/sql/pandas/_typing/__init__.pyi     |   6 +
 python/pyspark/sql/pandas/functions.py             |   2 +
 python/pyspark/sql/pandas/group_ops.py             | 125 +++++++-
 python/pyspark/sql/pandas/serializers.py           | 355 ++++++++++++++++++++-
 python/pyspark/sql/streaming/state.py              |  55 +++-
 python/pyspark/sql/udf.py                          |   9 +-
 python/pyspark/worker.py                           | 143 +++++++++
 .../analysis/UnsupportedOperationChecker.scala     |  62 ++++
 .../plans/logical/pythonLogicalOperators.scala     |  34 ++
 .../spark/sql/RelationalGroupedDataset.scala       |  45 +++
 .../spark/sql/execution/SparkStrategies.scala      |  23 ++
 .../spark/sql/execution/arrow/ArrowWriter.scala    |  16 +-
 .../ApplyInPandasWithStatePythonRunner.scala       | 223 +++++++++++++
 .../python/ApplyInPandasWithStateWriter.scala      | 276 ++++++++++++++++
 .../python/FlatMapCoGroupsInPandasExec.scala       |   4 +-
 .../python/FlatMapGroupsInPandasExec.scala         |   2 +-
 .../FlatMapGroupsInPandasWithStateExec.scala       | 214 +++++++++++++
 .../sql/execution/python/PandasGroupUtils.scala    |   7 +-
 .../sql/execution/python/PythonArrowInput.scala    |   1 -
 .../execution/streaming/IncrementalExecution.scala |   9 +
 .../sql/execution/streaming/state/package.scala    |   2 +-
 23 files changed, 1599 insertions(+), 18 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 5a13674e8bf..7b31fa93c32 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -53,6 +53,7 @@ private[spark] object PythonEvalType {
   val SQL_MAP_PANDAS_ITER_UDF = 205
   val SQL_COGROUPED_MAP_PANDAS_UDF = 206
   val SQL_MAP_ARROW_ITER_UDF = 207
+  val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208
 
   def
toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -65,6 +66,7 @@ private[spark] object PythonEvalType { case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF" case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF" case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF" + case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE" } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7ef0014ae75..5f4f4d494e1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -105,6 +105,7 @@ if TYPE_CHECKING: PandasMapIterUDFType, PandasCogroupedMapUDFType, ArrowMapIterUDFType, + PandasGroupedMapUDFWithStateType, ) from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import AtomicType, StructType @@ -147,6 +148,7 @@ class PythonEvalType: SQL_MAP_PANDAS_ITER_UDF: "PandasMapIterUDFType" = 205 SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206 SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207 + SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208 def portable_hash(x: Hashable) -> int: diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi index 27ac64a7238..acca8c00f2a 100644 --- a/python/pyspark/sql/pandas/_typing/__init__.pyi +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -30,6 +30,7 @@ from typing_extensions import Protocol, Literal from types import FunctionType from pyspark.sql._typing import LiteralType +from pyspark.sql.streaming.state import GroupState from pandas.core.frame import DataFrame as PandasDataFrame from pandas.core.series import Series as PandasSeries from numpy import ndarray as NDArray @@ -51,6 +52,7 @@ PandasScalarIterUDFType = Literal[204] PandasMapIterUDFType = Literal[205] PandasCogroupedMapUDFType = Literal[206] ArrowMapIterUDFType = Literal[207] +PandasGroupedMapUDFWithStateType = Literal[208] class 
PandasVariadicScalarToScalarFunction(Protocol): def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ... @@ -256,6 +258,10 @@ PandasGroupedMapFunction = Union[ Callable[[Any, DataFrameLike], DataFrameLike], ] +PandasGroupedMapFunctionWithState = Callable[ + [Any, Iterable[DataFrameLike], GroupState], Iterable[DataFrameLike] +] + class PandasVariadicGroupedAggFunction(Protocol): def __call__(self, *_: SeriesLike) -> LiteralType: ... diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index 94fabdbb295..d0f81e2f633 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -369,6 +369,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, None, ]: # None means it should infer the type from type hints. @@ -402,6 +403,7 @@ def _create_pandas_udf(f, returnType, evalType): PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, ]: # In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered # at `apply` instead. diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 6178433573e..0945c0078a2 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -15,18 +15,20 @@ # limitations under the License. 
# import sys -from typing import List, Union, TYPE_CHECKING +from typing import List, Union, TYPE_CHECKING, cast import warnings from pyspark.rdd import PythonEvalType from pyspark.sql.column import Column from pyspark.sql.dataframe import DataFrame -from pyspark.sql.types import StructType +from pyspark.sql.streaming.state import GroupStateTimeout +from pyspark.sql.types import StructType, _parse_datatype_string if TYPE_CHECKING: from pyspark.sql.pandas._typing import ( GroupedMapPandasUserDefinedFunction, PandasGroupedMapFunction, + PandasGroupedMapFunctionWithState, PandasCogroupedMapFunction, ) from pyspark.sql.group import GroupedData @@ -216,6 +218,125 @@ class PandasGroupedOpsMixin: jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) return DataFrame(jdf, self.session) + def applyInPandasWithState( + self, + func: "PandasGroupedMapFunctionWithState", + outputStructType: Union[StructType, str], + stateStructType: Union[StructType, str], + outputMode: str, + timeoutConf: str, + ) -> DataFrame: + """ + Applies the given function to each group of data, while maintaining a user-defined + per-group state. The result Dataset will represent the flattened record returned by the + function. + + For a streaming Dataset, the function will be invoked first for all input groups and then + for all timed out states where the input data is set to be empty. Updates to each group's + state will be saved across invocations. + + The function should take parameters (key, Iterator[`pandas.DataFrame`], state) and + return another Iterator[`pandas.DataFrame`]. The grouping key(s) will be passed as a tuple + of numpy data types, e.g., `numpy.int32` and `numpy.float64`. The state will be passed as + :class:`pyspark.sql.streaming.state.GroupState`. + + For each group, all columns are passed together as `pandas.DataFrame` to the user-function, + and the returned `pandas.DataFrame` across all invocations are combined as a + :class:`DataFrame`. 
Note that the user function should not make a guess of the number of + elements in the iterator. To process all data, the user function needs to iterate all + elements and process them. On the other hand, the user function is not strictly required to + iterate through all elements in the iterator if it intends to read a part of data. + + The `outputStructType` should be a :class:`StructType` describing the schema of all + elements in the returned value, `pandas.DataFrame`. The column labels of all elements in + returned `pandas.DataFrame` must either match the field names in the defined schema if + specified as strings, or match the field data types by position if not strings, + e.g. integer indices. + + The `stateStructType` should be :class:`StructType` describing the schema of the + user-defined state. The value of the state will be presented as a tuple, as well as the + update should be performed with the tuple. The corresponding Python types for + :class:DataType are supported. Please refer to the page + https://spark.apache.org/docs/latest/sql-ref-datatypes.html (python tab). + + The size of each DataFrame in both the input and output can be arbitrary. The number of + DataFrames in both the input and output can also be arbitrary. + + .. versionadded:: 3.4.0 + + Parameters + ---------- + func : function + a Python native function to be called on every group. It should take parameters + (key, Iterator[`pandas.DataFrame`], state) and return Iterator[`pandas.DataFrame`]. + Note that the type of the key is tuple and the type of the state is + :class:`pyspark.sql.streaming.state.GroupState`. + outputStructType : :class:`pyspark.sql.types.DataType` or str + the type of the output records. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + stateStructType : :class:`pyspark.sql.types.DataType` or str + the type of the user-defined state. 
The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + outputMode : str + the output mode of the function. + timeoutConf : str + timeout configuration for groups that do not receive data for a while. valid values + are defined in :class:`pyspark.sql.streaming.state.GroupStateTimeout`. + + Examples + -------- + >>> import pandas as pd # doctest: +SKIP + >>> from pyspark.sql.streaming.state import GroupStateTimeout + >>> def count_fn(key, pdf_iter, state): + ... assert isinstance(state, GroupStateImpl) + ... total_len = 0 + ... for pdf in pdf_iter: + ... total_len += len(pdf) + ... state.update((total_len,)) + ... yield pd.DataFrame({"id": [key[0]], "countAsString": [str(total_len)]}) + >>> df.groupby("id").applyInPandasWithState( + ... count_fn, outputStructType="id long, countAsString string", + ... stateStructType="len long", outputMode="Update", + ... timeoutConf=GroupStateTimeout.NoTimeout) # doctest: +SKIP + + Notes + ----- + This function requires a full shuffle. + + This API is experimental. 
+ """ + + from pyspark.sql import GroupedData + from pyspark.sql.functions import pandas_udf + + assert isinstance(self, GroupedData) + assert timeoutConf in [ + GroupStateTimeout.NoTimeout, + GroupStateTimeout.ProcessingTimeTimeout, + GroupStateTimeout.EventTimeTimeout, + ] + + if isinstance(outputStructType, str): + outputStructType = cast(StructType, _parse_datatype_string(outputStructType)) + if isinstance(stateStructType, str): + stateStructType = cast(StructType, _parse_datatype_string(stateStructType)) + + udf = pandas_udf( + func, # type: ignore[call-overload] + returnType=outputStructType, + functionType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + ) + df = self._df + udf_column = udf(*[df[col] for col in df.columns]) + jdf = self._jgd.applyInPandasWithState( + udf_column._jc.expr(), + self.session._jsparkSession.parseDataType(outputStructType.json()), + self.session._jsparkSession.parseDataType(stateStructType.json()), + outputMode, + timeoutConf, + ) + return DataFrame(jdf, self.session) + def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps": """ Cogroups this group with another group so that we can run cogrouped operations. diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 992e82b403a..ca249c75ea5 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -19,7 +19,9 @@ Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details. 
""" -from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer +from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer +from pyspark.sql.pandas.types import to_arrow_type +from pyspark.sql.types import StringType, StructType, BinaryType, StructField, LongType class SpecialLengths: @@ -371,3 +373,354 @@ class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer): raise ValueError( "Invalid number of pandas.DataFrames in group {0}".format(dataframes_in_group) ) + + +class ApplyInPandasWithStateSerializer(ArrowStreamPandasUDFSerializer): + """ + Serializer used by Python worker to evaluate UDF for applyInPandasWithState. + + Parameters + ---------- + timezone : str + A timezone to respect when handling timestamp values + safecheck : bool + If True, conversion from Arrow to Pandas checks for overflow/truncation + assign_cols_by_name : bool + If True, then Pandas DataFrames will get columns by name + state_object_schema : StructType + The type of state object represented as Spark SQL type + arrow_max_records_per_batch : int + Limit of the number of records that can be written to a single ArrowRecordBatch in memory. 
+ """ + + def __init__( + self, + timezone, + safecheck, + assign_cols_by_name, + state_object_schema, + arrow_max_records_per_batch, + ): + super(ApplyInPandasWithStateSerializer, self).__init__( + timezone, safecheck, assign_cols_by_name + ) + self.pickleSer = CPickleSerializer() + self.utf8_deserializer = UTF8Deserializer() + self.state_object_schema = state_object_schema + + self.result_state_df_type = StructType( + [ + StructField("properties", StringType()), + StructField("keyRowAsUnsafe", BinaryType()), + StructField("object", BinaryType()), + StructField("oldTimeoutTimestamp", LongType()), + ] + ) + + self.result_state_pdf_arrow_type = to_arrow_type(self.result_state_df_type) + self.arrow_max_records_per_batch = arrow_max_records_per_batch + + def load_stream(self, stream): + """ + Read ArrowRecordBatches from stream, deserialize them to populate a list of pair + (data chunk, state), and convert the data into a list of pandas.Series. + + Please refer the doc of inner function `gen_data_and_state` for more details how + this function works in overall. + + In addition, this function further groups the return of `gen_data_and_state` by the state + instance (same semantic as grouping by grouping key) and produces an iterator of data + chunks for each group, so that the caller can lazily materialize the data chunk. + """ + + import pyarrow as pa + import json + from itertools import groupby + from pyspark.sql.streaming.state import GroupState + + def construct_state(state_info_col): + """ + Construct state instance from the value of state information column. 
+ """ + + state_info_col_properties = state_info_col["properties"] + state_info_col_key_row = state_info_col["keyRowAsUnsafe"] + state_info_col_object = state_info_col["object"] + + state_properties = json.loads(state_info_col_properties) + if state_info_col_object: + state_object = self.pickleSer.loads(state_info_col_object) + else: + state_object = None + state_properties["optionalValue"] = state_object + + return GroupState( + keyAsUnsafe=state_info_col_key_row, + valueSchema=self.state_object_schema, + **state_properties, + ) + + def gen_data_and_state(batches): + """ + Deserialize ArrowRecordBatches and return a generator of + `(a list of pandas.Series, state)`. + + The logic on deserialization is following: + + 1. Read the entire data part from Arrow RecordBatch. + 2. Read the entire state information part from Arrow RecordBatch. + 3. Loop through each state information: + 3.A. Extract the data out from entire data via the information of data range. + 3.B. Construct a new state instance if the state information is the first occurrence + for the current grouping key. + 3.C. Leverage the existing state instance if it is already available for the current + grouping key. (Meaning it's not the first occurrence.) + 3.D. Remove the cache of state instance if the state information denotes the data is + the last chunk for current grouping key. + + This deserialization logic assumes that Arrow RecordBatches contain the data with the + ordering that data chunks for same grouping key will appear sequentially. + + This function must avoid materializing multiple Arrow RecordBatches into memory at the + same time. And data chunks from the same grouping key should appear sequentially, to + further group them based on state instance (same state instance will be produced for + same grouping key). 
+ """ + + state_for_current_group = None + + for batch in batches: + batch_schema = batch.schema + data_schema = pa.schema([batch_schema[i] for i in range(0, len(batch_schema) - 1)]) + state_schema = pa.schema( + [ + batch_schema[-1], + ] + ) + + batch_columns = batch.columns + data_columns = batch_columns[0:-1] + state_column = batch_columns[-1] + + data_batch = pa.RecordBatch.from_arrays(data_columns, schema=data_schema) + state_batch = pa.RecordBatch.from_arrays( + [ + state_column, + ], + schema=state_schema, + ) + + state_arrow = pa.Table.from_batches([state_batch]).itercolumns() + state_pandas = [self.arrow_to_pandas(c) for c in state_arrow][0] + + for state_idx in range(0, len(state_pandas)): + state_info_col = state_pandas.iloc[state_idx] + + if not state_info_col: + # no more data with grouping key + state + break + + data_start_offset = state_info_col["startOffset"] + num_data_rows = state_info_col["numRows"] + is_last_chunk = state_info_col["isLastChunk"] + + if state_for_current_group: + # use the state, we already have state for same group and there should be + # some data in same group being processed earlier + state = state_for_current_group + else: + # there is no state being stored for same group, construct one + state = construct_state(state_info_col) + + if is_last_chunk: + # discard the state being cached for same group + state_for_current_group = None + elif not state_for_current_group: + # there's no cached state but expected to have additional data in same group + # cache the current state + state_for_current_group = state + + data_batch_for_group = data_batch.slice(data_start_offset, num_data_rows) + data_arrow = pa.Table.from_batches([data_batch_for_group]).itercolumns() + + data_pandas = [self.arrow_to_pandas(c) for c in data_arrow] + + # state info + yield ( + data_pandas, + state, + ) + + _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + + data_state_generator = gen_data_and_state(_batches) + + # state will be 
same object for same grouping key + for _state, _data in groupby(data_state_generator, key=lambda x: x[1]): + yield ( + _data, + _state, + ) + + def dump_stream(self, iterator, stream): + """ + Read through an iterator of (iterator of pandas DataFrame, state), serialize them to Arrow + RecordBatches, and write batches to stream. + """ + + import pandas as pd + import pyarrow as pa + + def construct_state_pdf(state): + """ + Construct a pandas DataFrame from the state instance. + """ + + state_properties = state.json().encode("utf-8") + state_key_row_as_binary = state._keyAsUnsafe + if state.exists: + state_object = self.pickleSer.dumps(state._value_schema.toInternal(state._value)) + else: + state_object = None + state_old_timeout_timestamp = state.oldTimeoutTimestamp + + state_dict = { + "properties": [ + state_properties, + ], + "keyRowAsUnsafe": [ + state_key_row_as_binary, + ], + "object": [ + state_object, + ], + "oldTimeoutTimestamp": [ + state_old_timeout_timestamp, + ], + } + + return pd.DataFrame.from_dict(state_dict) + + def construct_record_batch(pdfs, pdf_data_cnt, pdf_schema, state_pdfs, state_data_cnt): + """ + Construct a new Arrow RecordBatch based on output pandas DataFrames and states. Each + one matches to the single struct field for Arrow schema, hence the return value of + Arrow RecordBatch will have schema with two fields, in `data`, `state` order. + (Readers are expected to access the field via position rather than the name. We do + not guarantee the name of the field.) + + Note that Arrow RecordBatch requires all columns to have all same number of rows, + hence this function inserts empty data for state/data with less elements to compensate. 
+ """ + + max_data_cnt = max(pdf_data_cnt, state_data_cnt) + + empty_row_cnt_in_data = max_data_cnt - pdf_data_cnt + empty_row_cnt_in_state = max_data_cnt - state_data_cnt + + empty_rows_pdf = pd.DataFrame( + dict.fromkeys(pa.schema(pdf_schema).names), + index=[x for x in range(0, empty_row_cnt_in_data)], + ) + empty_rows_state = pd.DataFrame( + columns=["properties", "keyRowAsUnsafe", "object", "oldTimeoutTimestamp"], + index=[x for x in range(0, empty_row_cnt_in_state)], + ) + + pdfs.append(empty_rows_pdf) + state_pdfs.append(empty_rows_state) + + merged_pdf = pd.concat(pdfs, ignore_index=True) + merged_state_pdf = pd.concat(state_pdfs, ignore_index=True) + + return self._create_batch( + [(merged_pdf, pdf_schema), (merged_state_pdf, self.result_state_pdf_arrow_type)] + ) + + def serialize_batches(): + """ + Read through an iterator of (iterator of pandas DataFrame, state), and serialize them + to Arrow RecordBatches. + + This function does batching on constructing the Arrow RecordBatch; a batch will be + serialized to the Arrow RecordBatch when the total number of records exceeds the + configured threshold. + """ + # a set of variables for the state of current batch which will be converted to Arrow + # RecordBatch. + pdfs = [] + state_pdfs = [] + pdf_data_cnt = 0 + state_data_cnt = 0 + + return_schema = None + + for data in iterator: + # data represents the result of each call of user function + packaged_result = data[0] + + # There are two results from the call of user function: + # 1) iterator of pandas DataFrame (output) + # 2) updated state instance + pdf_iter = packaged_result[0][0] + state = packaged_result[0][1] + + # This is static and won't change across batches. + return_schema = packaged_result[1] + + for pdf in pdf_iter: + # We ignore empty pandas DataFrame. 
+ if len(pdf) > 0: + pdf_data_cnt += len(pdf) + pdfs.append(pdf) + + # If the total number of records in current batch exceeds the configured + # threshold, time to construct the Arrow RecordBatch from the batch. + if pdf_data_cnt > self.arrow_max_records_per_batch: + batch = construct_record_batch( + pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt + ) + + # Reset the variables to start with new batch for further data. + pdfs = [] + state_pdfs = [] + pdf_data_cnt = 0 + state_data_cnt = 0 + + yield batch + + # This has to be performed 'after' evaluating all elements in iterator, so that + # the user function has been completed and the state is guaranteed to be updated. + state_pdf = construct_state_pdf(state) + + state_pdfs.append(state_pdf) + state_data_cnt += 1 + + # processed all output, but current batch may not be flushed yet. + if pdf_data_cnt > 0 or state_data_cnt > 0: + batch = construct_record_batch( + pdfs, pdf_data_cnt, return_schema, state_pdfs, state_data_cnt + ) + + yield batch + + def init_stream_yield_batches(batches): + """ + This function helps to ensure the requirement for Pandas UDFs - Pandas UDFs require a + START_ARROW_STREAM before the Arrow stream is sent. + + START_ARROW_STREAM should be sent after creating the first record batch so in case of + an error, it can be sent back to the JVM before the Arrow stream starts. 
+ """ + should_write_start_length = True + + for batch in batches: + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + + yield batch + + batches_to_write = init_stream_yield_batches(serialize_batches()) + + return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream) diff --git a/python/pyspark/sql/streaming/state.py b/python/pyspark/sql/streaming/state.py index 842eff32233..66b225e1b10 100644 --- a/python/pyspark/sql/streaming/state.py +++ b/python/pyspark/sql/streaming/state.py @@ -20,16 +20,24 @@ from typing import Tuple, Optional from pyspark.sql.types import DateType, Row, StructType -__all__ = ["GroupStateImpl", "GroupStateTimeout"] +__all__ = ["GroupState", "GroupStateTimeout"] class GroupStateTimeout: + """ + Represents the type of timeouts possible for the Dataset operations applyInPandasWithState. + """ + NoTimeout: str = "NoTimeout" ProcessingTimeTimeout: str = "ProcessingTimeTimeout" EventTimeTimeout: str = "EventTimeTimeout" -class GroupStateImpl: +class GroupState: + """ + Wrapper class for interacting with per-group state data in `applyInPandasWithState`. + """ + NO_TIMESTAMP: int = -1 def __init__( @@ -76,10 +84,16 @@ class GroupStateImpl: @property def exists(self) -> bool: + """ + Whether state exists or not. + """ return self._defined @property def get(self) -> Tuple: + """ + Get the state value if it exists, or throw ValueError. + """ if self.exists: return tuple(self._value) else: @@ -87,6 +101,9 @@ class GroupStateImpl: @property def getOption(self) -> Optional[Tuple]: + """ + Get the state value if it exists, or return None. + """ if self.exists: return tuple(self._value) else: @@ -94,6 +111,10 @@ class GroupStateImpl: @property def hasTimedOut(self) -> bool: + """ + Whether the function has been called because the key has timed out. + This can return true only when timeouts are enabled. 
+ """ return self._has_timed_out # NOTE: this function is only available to PySpark implementation due to underlying @@ -103,6 +124,9 @@ class GroupStateImpl: return self._old_timeout_timestamp def update(self, newValue: Tuple) -> None: + """ + Update the value of the state. The value of the state cannot be null. + """ if newValue is None: raise ValueError("'None' is not a valid state value") @@ -112,11 +136,18 @@ class GroupStateImpl: self._removed = False def remove(self) -> None: + """ + Remove this state. + """ self._defined = False self._updated = False self._removed = True def setTimeoutDuration(self, durationMs: int) -> None: + """ + Set the timeout duration in ms for this key. + Processing time timeout must be enabled. + """ if isinstance(durationMs, str): # TODO(SPARK-40437): Support string representation of durationMs. raise ValueError("durationMs should be int but get :%s" % type(durationMs)) @@ -133,6 +164,11 @@ class GroupStateImpl: # TODO(SPARK-40438): Implement additionalDuration parameter. def setTimeoutTimestamp(self, timestampMs: int) -> None: + """ + Set the timeout timestamp for this key as milliseconds in epoch time. + This timestamp cannot be older than the current watermark. + Event time timeout must be enabled. + """ if self._timeout_conf != GroupStateTimeout.EventTimeTimeout: raise RuntimeError( "Cannot set timeout duration without enabling processing time timeout in " @@ -146,7 +182,7 @@ class GroupStateImpl: raise ValueError("Timeout timestamp must be positive") if ( - self._event_time_watermark_ms != GroupStateImpl.NO_TIMESTAMP + self._event_time_watermark_ms != GroupState.NO_TIMESTAMP and timestampMs < self._event_time_watermark_ms ): raise ValueError( @@ -157,6 +193,10 @@ class GroupStateImpl: self._timeout_timestamp = timestampMs def getCurrentWatermarkMs(self) -> int: + """ + Get the current event time watermark as milliseconds in epoch time. + In a streaming query, this can be called only when watermark is set. 
+ """ if not self._watermark_present: raise RuntimeError( "Cannot get event time watermark timestamp without setting watermark before " @@ -165,6 +205,11 @@ class GroupStateImpl: return self._event_time_watermark_ms def getCurrentProcessingTimeMs(self) -> int: + """ + Get the current processing time as milliseconds in epoch time. + In a streaming query, this will return a constant value throughout the duration of a + trigger, even if the trigger is re-executed. + """ return self._batch_processing_time_ms def __str__(self) -> str: @@ -174,6 +219,10 @@ class GroupStateImpl: return "GroupState(<undefined>)" def json(self) -> str: + """ + Convert the internal values of instance into JSON. This is used to send out the update + from Python worker to executor. + """ return json.dumps( { # Constructor diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 6a01e399d04..da9a245bb71 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -144,20 +144,23 @@ class UserDefinedFunction: "Invalid return type with scalar Pandas UDFs: %s is " "not supported" % str(self._returnType_placeholder) ) - elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + elif ( + self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF + or self.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE + ): if isinstance(self._returnType_placeholder, StructType): try: to_arrow_type(self._returnType_placeholder) except TypeError: raise NotImplementedError( "Invalid return type with grouped map Pandas UDFs or " - "at groupby.applyInPandas: %s is not supported" + "at groupby.applyInPandas(WithState): %s is not supported" % str(self._returnType_placeholder) ) else: raise TypeError( "Invalid return type for grouped map Pandas " - "UDFs or at groupby.applyInPandas: return type must be a " + "UDFs or at groupby.applyInPandas(WithState): return type must be a " "StructType." 
) elif ( diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index c486b7bed1d..c1c3669701f 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,6 +23,7 @@ import sys import time from inspect import currentframe, getframeinfo, getfullargspec import importlib +import json # 'resource' is a Unix specific module. has_resource_module = True @@ -57,6 +58,7 @@ from pyspark.sql.pandas.serializers import ( ArrowStreamPandasUDFSerializer, CogroupUDFSerializer, ArrowStreamUDFSerializer, + ApplyInPandasWithStateSerializer, ) from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import StructType @@ -207,6 +209,90 @@ def wrap_grouped_map_pandas_udf(f, return_type, argspec): return lambda k, v: [(wrapped(k, v), to_arrow_type(return_type))] +def wrap_grouped_map_pandas_udf_with_state(f, return_type): + """ + Provides a new lambda instance wrapping user function of applyInPandasWithState. + + The lambda instance receives (key series, iterator of value series, state) and performs + some conversion to be adapted with the signature of user function. + + See the function doc of inner function `wrapped` for more details on what adapter does. + See the function doc of `mapper` function for + `eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE` for more details on + the input parameters of lambda function. + + Along with the returned iterator, the lambda instance will also produce the return_type as + converted to the arrow schema. + """ + + def wrapped(key_series, value_series_gen, state): + """ + Provide an adapter of the user function performing below: + + - Extract the first value of all columns in key series and produce as a tuple. + - If the state has timed out, call the user function with empty pandas DataFrame. 
+ - If not, construct a new generator which converts each element of value series to + pandas DataFrame (lazy evaluation), and call the user function with the generator + - Verify each element of returned iterator to check the schema of pandas DataFrame. + """ + import pandas as pd + + key = tuple(s[0] for s in key_series) + + if state.hasTimedOut: + # Timeout processing pass empty iterator. Here we return an empty DataFrame instead. + values = [ + pd.DataFrame(columns=pd.concat(next(value_series_gen), axis=1).columns), + ] + else: + values = (pd.concat(x, axis=1) for x in value_series_gen) + + result_iter = f(key, values, state) + + def verify_element(result): + if not isinstance(result, pd.DataFrame): + raise TypeError( + "The type of element in return iterator of the user-defined function " + "should be pandas.DataFrame, but is {}".format(type(result)) + ) + # the number of columns of result have to match the return type + # but it is fine for result to have no columns at all if it is empty + if not ( + len(result.columns) == len(return_type) + or (len(result.columns) == 0 and result.empty) + ): + raise RuntimeError( + "Number of columns of the element (pandas.DataFrame) in return iterator " + "doesn't match specified schema. 
" + "Expected: {} Actual: {}".format(len(return_type), len(result.columns)) + ) + + return result + + if isinstance(result_iter, pd.DataFrame): + raise TypeError( + "Return type of the user-defined function should be " + "iterable of pandas.DataFrame, but is {}".format(type(result_iter)) + ) + + try: + iter(result_iter) + except TypeError: + raise TypeError( + "Return type of the user-defined function should be " + "iterable, but is {}".format(type(result_iter)) + ) + + result_iter_with_validation = (verify_element(x) for x in result_iter) + + return ( + result_iter_with_validation, + state, + ) + + return lambda k, v, s: [(wrapped(k, v, s), to_arrow_type(return_type))] + + def wrap_grouped_agg_pandas_udf(f, return_type): arrow_return_type = to_arrow_type(return_type) @@ -311,6 +397,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec) @@ -336,6 +424,7 @@ def read_udfs(pickleSer, infile, eval_type): PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, ): # Load conf used for pandas_udf evaluation @@ -345,6 +434,10 @@ def read_udfs(pickleSer, infile, eval_type): v = utf8_deserializer.loads(infile) runner_conf[k] = v + state_object_schema = None + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + state_object_schema 
= StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) + # NOTE: if timezone is set here, that implies respectSessionTimeZone is True timezone = runner_conf.get("spark.sql.session.timeZone", None) safecheck = ( @@ -361,6 +454,19 @@ def read_udfs(pickleSer, infile, eval_type): if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + arrow_max_records_per_batch = runner_conf.get( + "spark.sql.execution.arrow.maxRecordsPerBatch", 10000 + ) + arrow_max_records_per_batch = int(arrow_max_records_per_batch) + + ser = ApplyInPandasWithStateSerializer( + timezone, + safecheck, + assign_cols_by_name, + state_object_schema, + arrow_max_records_per_batch, + ) elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: ser = ArrowStreamUDFSerializer() else: @@ -486,6 +592,43 @@ def read_udfs(pickleSer, infile, eval_type): vals = [a[o] for o in parsed_offsets[0][1]] return f(keys, vals) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + # We assume there is only one UDF here because grouped map doesn't + # support combining multiple UDFs. + assert num_udfs == 1 + + # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to + # distinguish between grouping attributes and data attributes + arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + parsed_offsets = extract_key_value_indexes(arg_offsets) + + def mapper(a): + """ + The function receives (iterator of data, state) and performs extraction of key and + value from the data, with retaining lazy evaluation. + + See `load_stream` in `ApplyInPandasWithStateSerializer` for more details on the input + and see `wrap_grouped_map_pandas_udf_with_state` for more details on how output will + be used. 
+ """ + from itertools import tee + + state = a[1] + data_gen = (x[0] for x in a[0]) + + # We know there should be at least one item in the iterator/generator. + # We want to peek the first element to construct the key, hence applying + # tee to construct the key while we retain another iterator/generator + # for values. + keys_gen, values_gen = tee(data_gen) + keys_elem = next(keys_gen) + keys = [keys_elem[o] for o in parsed_offsets[0][0]] + + # This must be generator comprehension - do not materialize. + vals = ([x[o] for o in parsed_offsets[0][1]] for x in values_gen) + + return f(keys, vals, state) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: # We assume there is only one UDF here because cogrouped map doesn't # support combining multiple UDFs. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index c11ce7d3b90..84795203fd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -64,6 +64,7 @@ object UnsupportedOperationChecker extends Logging { case s: Aggregate if s.isStreaming => true case _ @ Join(left, right, _, _, _) if left.isStreaming && right.isStreaming => true case f: FlatMapGroupsWithState if f.isStreaming => true + case f: FlatMapGroupsInPandasWithState if f.isStreaming => true case d: Deduplicate if d.isStreaming => true case _ => false } @@ -142,6 +143,17 @@ object UnsupportedOperationChecker extends Logging { " or the output mode is not append on a streaming DataFrames/Datasets")(plan) } + val applyInPandasWithStates = plan.collect { + case f: FlatMapGroupsInPandasWithState if f.isStreaming => f + } + + // Disallow multiple `applyInPandasWithState`s. 
+ if (applyInPandasWithStates.size > 1) { + throwError( + "Multiple applyInPandasWithStates are not supported on a streaming " + + "DataFrames/Datasets")(plan) + } + // Disallow multiple streaming aggregations val aggregates = collectStreamingAggregates(plan) @@ -311,6 +323,56 @@ object UnsupportedOperationChecker extends Logging { } } + // applyInPandasWithState + case m: FlatMapGroupsInPandasWithState if m.isStreaming => + // Check compatibility with output modes and aggregations in query + val aggsInQuery = collectStreamingAggregates(plan) + + if (aggsInQuery.isEmpty) { + // applyInPandasWithState without aggregation: operation's output mode must + // match query output mode + m.outputMode match { + case InternalOutputModes.Update if outputMode != InternalOutputModes.Update => + throwError( + "applyInPandasWithState in update mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case InternalOutputModes.Append if outputMode != InternalOutputModes.Append => + throwError( + "applyInPandasWithState in append mode is not supported with " + + s"$outputMode output mode on a streaming DataFrame/Dataset") + + case _ => + } + } else { + // applyInPandasWithState with aggregation: update operation mode not allowed, and + // *groupsWithState after aggregation not allowed + if (m.outputMode == InternalOutputModes.Update) { + throwError( + "applyInPandasWithState in update mode is not supported with " + + "aggregation on a streaming DataFrame/Dataset") + } else if (collectStreamingAggregates(m).nonEmpty) { + throwError( + "applyInPandasWithState in append mode is not supported after " + + "aggregation on a streaming DataFrame/Dataset") + } + } + + // Check compatibility with timeout configs + if (m.timeout == EventTimeTimeout) { + // With event time timeout, watermark must be defined. 
+ val watermarkAttributes = m.child.output.collect { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => a + } + if (watermarkAttributes.isEmpty) { + throwError( + "Watermark must be specified in the query using " + + "'[Dataset/DataFrame].withWatermark()' for using event-time timeout in a " + + "applyInPandasWithState. Event-time timeout not supported without " + + "watermark.")(plan) + } + } + case d: Deduplicate if collectStreamingAggregates(d).nonEmpty => throwError("dropDuplicates is not supported after aggregation on a " + "streaming DataFrame/Dataset") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index c2f74b35083..e97ff7808f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types.StructType /** * FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame. @@ -98,6 +100,38 @@ case class FlatMapCoGroupsInPandas( copy(left = newLeft, right = newRight) } +/** + * Similar to [[FlatMapGroupsWithState]]. Applies func to each unique group + * in `child`, based on the evaluation of `groupingAttributes`, + * while using state data. + * `functionExpr` is invoked with a pandas DataFrame representation and the + * grouping key (tuple).
+ * + * @param functionExpr function called on each group + * @param groupingAttributes used to group the data + * @param outputAttrs used to define the output rows + * @param stateType used to serialize/deserialize state before calling `functionExpr` + * @param outputMode the output mode of `func` + * @param timeout used to timeout groups that have not received data in a while + * @param child logical plan of the underlying data + */ +case class FlatMapGroupsInPandasWithState( + functionExpr: Expression, + groupingAttributes: Seq[Attribute], + outputAttrs: Seq[Attribute], + stateType: StructType, + outputMode: OutputMode, + timeout: GroupStateTimeout, + child: LogicalPlan) extends UnaryNode { + + override def output: Seq[Attribute] = outputAttrs + + override def producedAttributes: AttributeSet = AttributeSet(outputAttrs) + + override protected def withNewChildInternal( + newChild: LogicalPlan): FlatMapGroupsInPandasWithState = copy(child = newChild) +} + trait BaseEvalPython extends UnaryNode { def udfs: Seq[PythonUDF] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 989ee325218..0429fd27a41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -30,9 +30,11 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.streaming.OutputMode import 
org.apache.spark.sql.types.{NumericType, StructType} /** @@ -620,6 +622,49 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + /** + * Applies a grouped vectorized python user-defined function to each group of data. + * The user-defined function defines a transformation: iterator of `pandas.DataFrame` -> + * iterator of `pandas.DataFrame`. + * For each group, all elements in the group are passed as an iterator of `pandas.DataFrame` + * along with corresponding state, and the results for all groups are combined into a new + * [[DataFrame]]. + * + * This function does not support partial aggregation, and requires shuffling all the data in + * the [[DataFrame]]. + * + * This function uses Apache Arrow as serialization format between Java executors and Python + * workers. + */ + private[sql] def applyInPandasWithState( + func: PythonUDF, + outputStructType: StructType, + stateStructType: StructType, + outputModeStr: String, + timeoutConfStr: String): DataFrame = { + val timeoutConf = org.apache.spark.sql.execution.streaming + .GroupStateImpl.groupStateTimeoutFromString(timeoutConfStr) + val outputMode = InternalOutputModes(outputModeStr) + if (outputMode != OutputMode.Append && outputMode != OutputMode.Update) { + throw new IllegalArgumentException("The output mode of function should be append or update") + } + val groupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + val groupingAttrs = groupingNamedExpressions.map(_.toAttribute) + val outputAttrs = outputStructType.toAttributes + val plan = FlatMapGroupsInPandasWithState( + func, + groupingAttrs, + outputAttrs, + stateStructType, + outputMode, + timeoutConf, + child = df.logicalPlan) + Dataset.ofRows(df.sparkSession, plan) + } + override def toString: String = { val builder = new StringBuilder builder.append("RelationalGroupedDataset: [grouping expressions: [") diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6104104c7be..c64a123e3a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -684,6 +684,25 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Strategy to convert [[FlatMapGroupsInPandasWithState]] logical operator to physical operator + * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. + */ + object FlatMapGroupsInPandasWithStateStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case FlatMapGroupsInPandasWithState( + func, groupAttr, outputAttr, stateType, outputMode, timeout, child) => + val stateVersion = conf.getConf(SQLConf.FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION) + val execPlan = python.FlatMapGroupsInPandasWithStateExec( + func, groupAttr, outputAttr, stateType, None, stateVersion, outputMode, timeout, + batchTimestampMs = None, eventTimeWatermark = None, planLater(child) + ) + execPlan :: Nil + case _ => + Nil + } + } + /** * Strategy to convert EvalPython logical operator to physical operator. */ @@ -793,6 +812,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { initialStateGroupAttrs, data, initialStateDataAttrs, output, timeout, hasInitialState, planLater(initialState), planLater(child) ) :: Nil + case _: FlatMapGroupsInPandasWithState => + // TODO(SPARK-40443): support applyInPandasWithState in batch query + throw new UnsupportedOperationException( + "applyInPandasWithState is unsupported in batch query. 
Use applyInPandas instead.") case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 7abca5f0e33..2988c0fb518 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -44,7 +44,7 @@ object ArrowWriter { new ArrowWriter(root, children.toArray) } - private def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { + private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = { val field = vector.getField() (ArrowUtils.fromArrowField(field), vector) match { case (BooleanType, vector: BitVector) => new BooleanWriter(vector) @@ -98,6 +98,16 @@ class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { count += 1 } + def sizeInBytes(): Int = { + var i = 0 + var bytes = 0 + while (i < fields.size) { + bytes += fields(i).getSizeInBytes() + i += 1 + } + bytes + } + def finish(): Unit = { root.setRowCount(count) fields.foreach(_.finish()) @@ -132,6 +142,10 @@ private[arrow] abstract class ArrowFieldWriter { count += 1 } + def getSizeInBytes(): Int = { + valueVector.getBufferSizeFor(count) + } + def finish(): Unit = { valueVector.setValueCount(count) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala new file mode 100644 index 00000000000..bd8c72029dc --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io._ + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamWriter +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.api.python._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER} +import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} + + +/** + * A variant implementation of [[ArrowPythonRunner]] to serve the operation + * applyInPandasWithState. 
+ * + * Unlike the normal ArrowPythonRunner, where both input and output (executor <-> python worker) + * are InternalRow, applyInPandasWithState carries side data (state information) in both input + * and output along with data, which requires a different structure on Arrow RecordBatch. + */ +class ApplyInPandasWithStatePythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + inputSchema: StructType, + override protected val timeZoneId: String, + initialWorkerConf: Map[String, String], + stateEncoder: ExpressionEncoder[Row], + keySchema: StructType, + outputSchema: StructType, + stateValueSchema: StructType) + extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets) + with PythonArrowInput[InType] + with PythonArrowOutput[OutType] { + + private val sqlConf = SQLConf.get + + override protected val schema: StructType = inputSchema.add("__state", STATE_METADATA_SCHEMA) + + override val simplifiedTraceback: Boolean = sqlConf.pysparkSimplifiedTraceback + + override val bufferSize: Int = { + val configuredSize = sqlConf.pandasUDFBufferSize + if (configuredSize < 4) { + logWarning("Pandas execution requires more than 4 bytes. Please configure bigger value " + + s"for the configuration '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'. " + + "Force using the value '4'.") + 4 + } else { + configuredSize + } + } + + private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch + + // applyInPandasWithState has its own mechanism to construct the Arrow RecordBatch instance. + // Configurations apply to both the executor and the Python worker, so we set them in the worker + // conf to let the Python worker read the config properly.
+ override protected val workerConf: Map[String, String] = initialWorkerConf + + (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> arrowMaxRecordsPerBatch.toString) + + private val stateRowDeserializer = stateEncoder.createDeserializer() + + /** + * This method sends out the additional metadata before sending out actual data. + * + * Specifically, this class overrides this method to also write the schema for state value. + */ + override protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { + super.handleMetadataBeforeExec(stream) + // Also write the schema for state value + PythonRDD.writeUTF(stateValueSchema.json, stream) + } + + /** + * Read the (key, state, values) from input iterator and construct Arrow RecordBatches, and + * write constructed RecordBatches to the writer. + * + * See [[ApplyInPandasWithStateWriter]] for more details. + */ + protected def writeIteratorToArrowStream( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[InType]): Unit = { + val w = new ApplyInPandasWithStateWriter(root, writer, arrowMaxRecordsPerBatch) + + while (inputIterator.hasNext) { + val (keyRow, groupState, dataIter) = inputIterator.next() + assert(dataIter.hasNext, "should have at least one data row!") + w.startNewGroup(keyRow, groupState) + + while (dataIter.hasNext) { + val dataRow = dataIter.next() + w.writeRow(dataRow) + } + + w.finalizeGroup() + } + + w.finalizeData() + } + + /** + * Deserialize ColumnarBatch received from the Python worker to produce the output. Schema info + * for given ColumnarBatch is also provided as well. + */ + protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OutType = { + // This should at least have one row for state. Also, we ensure that all columns across + // data and state metadata have same number of rows, which is required by Arrow record + // batch. 
+ assert(batch.numRows() > 0) + assert(schema.length == 2) + + def getColumnarBatchForStructTypeColumn( + batch: ColumnarBatch, + ordinal: Int, + expectedType: StructType): ColumnarBatch = { + // UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(ordinal).asInstanceOf[ArrowColumnVector] + val dataType = schema(ordinal).dataType.asInstanceOf[StructType] + assert(dataType.sameType(expectedType), + s"Schema equality check failure! type from Arrow: $dataType, expected type: $expectedType") + + val outputVectors = dataType.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + + flattenedBatch + } + + def constructIterForData(batch: ColumnarBatch): Iterator[InternalRow] = { + val dataBatch = getColumnarBatchForStructTypeColumn(batch, 0, outputSchema) + dataBatch.rowIterator.asScala.flatMap { row => + if (row.isNullAt(0)) { + // The entire row in record batch seems to be for state metadata. + None + } else { + Some(row) + } + } + } + + def constructIterForState(batch: ColumnarBatch): Iterator[OutTypeForState] = { + val stateMetadataBatch = getColumnarBatchForStructTypeColumn(batch, 1, + STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER) + + stateMetadataBatch.rowIterator().asScala.flatMap { row => + implicit val formats = org.json4s.DefaultFormats + + if (row.isNullAt(0)) { + // The entire row in record batch seems to be for data. + None + } else { + // NOTE: See ApplyInPandasWithStatePythonRunner.STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER + // for the schema. 
+ val propertiesAsJson = parse(row.getUTF8String(0).toString) + val keyRowAsUnsafeAsBinary = row.getBinary(1) + val keyRowAsUnsafe = new UnsafeRow(keySchema.fields.length) + keyRowAsUnsafe.pointTo(keyRowAsUnsafeAsBinary, keyRowAsUnsafeAsBinary.length) + val maybeObjectRow = if (row.isNullAt(2)) { + None + } else { + val pickledStateValue = row.getBinary(2) + Some(PythonSQLUtils.toJVMRow(pickledStateValue, stateValueSchema, + stateRowDeserializer)) + } + val oldTimeoutTimestamp = row.getLong(3) + + Some((keyRowAsUnsafe, GroupStateImpl.fromJson(maybeObjectRow, propertiesAsJson), + oldTimeoutTimestamp)) + } + } + } + + (constructIterForState(batch), constructIterForData(batch)) + } +} + +object ApplyInPandasWithStatePythonRunner { + type InType = (UnsafeRow, GroupStateImpl[Row], Iterator[InternalRow]) + type OutTypeForState = (UnsafeRow, GroupStateImpl[Row], Long) + type OutType = (Iterator[OutTypeForState], Iterator[InternalRow]) + + val STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER: StructType = StructType( + Array( + StructField("properties", StringType), + StructField("keyRowAsUnsafe", BinaryType), + StructField("object", BinaryType), + StructField("oldTimeoutTimestamp", LongType) + ) + ) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala new file mode 100644 index 00000000000..60a228ddd73 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStateWriter.scala @@ -0,0 +1,276 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot} +import org.apache.arrow.vector.ipc.ArrowStreamWriter + +import org.apache.spark.sql.Row +import org.apache.spark.sql.api.python.PythonSQLUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.execution.arrow.ArrowWriter.createFieldWriter +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.types.{BinaryType, BooleanType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * This class abstracts the complexity on constructing Arrow RecordBatches for data and state with + * bin-packing and chunking. The caller only need to call the proper public methods of this class + * `startNewGroup`, `writeRow`, `finalizeGroup`, `finalizeData` and this class will write the data + * and state into Arrow RecordBatches with performing bin-pack and chunk internally. 
+ * + * This class requires that the parameter `root` has been initialized with the Arrow schema like + * below: + * - data fields + * - state field + * - nested schema (Refer ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA) + * + * Please refer the code comment in the implementation to see how the writes of data and state + * against Arrow RecordBatch work with consideration of bin-packing and chunking. + */ +class ApplyInPandasWithStateWriter( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + arrowMaxRecordsPerBatch: Int) { + + import ApplyInPandasWithStateWriter._ + + // Unlike applyInPandas (and other PySpark operators), applyInPandasWithState requires to produce + // the additional data `state`, along with the input data. + // + // ArrowStreamWriter supports only single VectorSchemaRoot, which means all Arrow RecordBatches + // being sent out from ArrowStreamWriter should have same schema. That said, we have to construct + // "an" Arrow schema to contain both data and state, and also construct ArrowBatches to contain + // both data and state. + // + // To achieve this, we extend the schema for input data to have a column for state at the end. + // But also, we logically group the columns by family (data vs state) and initialize writer + // separately, since it's lot more easier and probably performant to write the row directly + // rather than projecting the row to match up with the overall schema. + // + // Although Arrow RecordBatch enables to write the data as columnar, we figure out it gives + // strange outputs if we don't ensure that all columns have the same number of values. Since + // there are at least one data for a grouping key (we ensure this for the case of handling timed + // out state as well) whereas there is only one state for a grouping key, we have to fill up the + // empty rows in state side to ensure both have the same number of rows. 
+ private val arrowWriterForData = createArrowWriter( + root.getFieldVectors.asScala.toSeq.dropRight(1)) + private val arrowWriterForState = createArrowWriter( + root.getFieldVectors.asScala.toSeq.takeRight(1)) + + // - Bin-packing + // + // We apply bin-packing the data from multiple groups into one Arrow RecordBatch to + // gain the performance. In many cases, the amount of data per grouping key is quite + // small, which does not seem to maximize the benefits of using Arrow. + // + // We have to split the record batch down to each group in Python worker to convert the + // data for group to Pandas, but hopefully, Arrow RecordBatch provides the way to split + // the range of data and give a view, say, "zero-copy". To help splitting the range for + // data, we provide the "start offset" and the "number of data" in the state metadata. + // + // We don't bin-pack all groups into a single record batch - we have a limit on the number + // of rows in the current Arrow RecordBatch to stop adding next group. + // + // - Chunking + // + // We also chunk the data from single group into multiple Arrow RecordBatch to ensure + // scalability. Note that we don't know the volume (number of rows, overall size) of data for + // specific group key before we read the entire data. The easiest approach to address both + // bin-pack and chunk is to check the number of rows in the current Arrow RecordBatch for each + // write of row. + // + // - Data and State + // + // Since we apply bin-packing and chunking, there should be the way to distinguish each chunk + // from the entire data part of Arrow RecordBatch. We leverage the state metadata to also + // contain the "metadata" of data part to distinguish the chunk from the entire data. + // As a result, state metadata has a 1-1 relationship with "chunk", instead of "grouping key". 
+ // + // - Consideration + // + // Since the number of rows in Arrow RecordBatch does not represent the actual size (bytes), + // the limit should be set very conservatively. Using a small number of limit does not introduce + // correctness issues. + + // variables for tracking current grouping key and state + private var currentGroupKeyRow: UnsafeRow = _ + private var currentGroupState: GroupStateImpl[Row] = _ + + // variables for tracking the status of current batch + private var totalNumRowsForBatch = 0 + private var totalNumStatesForBatch = 0 + + // variables for tracking the status of current chunk + private var startOffsetForCurrentChunk = 0 + private var numRowsForCurrentChunk = 0 + + + /** + * Indicates writer to start with new grouping key. + * + * @param keyRow The grouping key row for current group. + * @param groupState The instance of GroupStateImpl for current group. + */ + def startNewGroup(keyRow: UnsafeRow, groupState: GroupStateImpl[Row]): Unit = { + currentGroupKeyRow = keyRow + currentGroupState = groupState + } + + /** + * Indicates writer to write a row in the current group. + * + * @param dataRow The row to write in the current group. + */ + def writeRow(dataRow: InternalRow): Unit = { + // If it exceeds the condition of batch (number of records) and there is more data for the + // same group, finalize and construct a new batch. + + if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) { + finalizeCurrentChunk(isLastChunkForGroup = false) + finalizeCurrentArrowBatch() + } + + arrowWriterForData.write(dataRow) + + numRowsForCurrentChunk += 1 + totalNumRowsForBatch += 1 + } + + /** + * Indicates writer that current group has finalized and there will be no further row bound to + * the current group. + */ + def finalizeGroup(): Unit = { + finalizeCurrentChunk(isLastChunkForGroup = true) + + // If it exceeds the condition of batch (number of records) once the all data is received for + // same group, finalize and construct a new batch. 
+ if (totalNumRowsForBatch >= arrowMaxRecordsPerBatch) { + finalizeCurrentArrowBatch() + } + } + + /** + * Indicates writer that all groups have been processed. + */ + def finalizeData(): Unit = { + if (totalNumRowsForBatch > 0) { + // We still have some rows in the current record batch. Need to finalize them as well. + finalizeCurrentArrowBatch() + } + } + + private def createArrowWriter(fieldVectors: Seq[FieldVector]): ArrowWriter = { + val children = fieldVectors.map { vector => + vector.allocateNew() + createFieldWriter(vector) + } + + new ArrowWriter(root, children.toArray) + } + + private def buildStateInfoRow( + keyRow: UnsafeRow, + groupState: GroupStateImpl[Row], + startOffset: Int, + numRows: Int, + isLastChunk: Boolean): InternalRow = { + // NOTE: see ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA + val stateUnderlyingRow = new GenericInternalRow( + Array[Any]( + UTF8String.fromString(groupState.json()), + keyRow.getBytes, + groupState.getOption.map(PythonSQLUtils.toPyRow).orNull, + startOffset, + numRows, + isLastChunk + ) + ) + new GenericInternalRow(Array[Any](stateUnderlyingRow)) + } + + private def finalizeCurrentChunk(isLastChunkForGroup: Boolean): Unit = { + val stateInfoRow = buildStateInfoRow(currentGroupKeyRow, currentGroupState, + startOffsetForCurrentChunk, numRowsForCurrentChunk, isLastChunkForGroup) + arrowWriterForState.write(stateInfoRow) + totalNumStatesForBatch += 1 + + // The start offset for next chunk would be same as the total number of rows for batch, + // unless the next chunk starts with new batch. 
+ startOffsetForCurrentChunk = totalNumRowsForBatch + numRowsForCurrentChunk = 0 + } + + private def finalizeCurrentArrowBatch(): Unit = { + val remainingEmptyStateRows = totalNumRowsForBatch - totalNumStatesForBatch + (0 until remainingEmptyStateRows).foreach { _ => + arrowWriterForState.write(EMPTY_STATE_METADATA_ROW) + } + + arrowWriterForState.finish() + arrowWriterForData.finish() + writer.writeBatch() + arrowWriterForState.reset() + arrowWriterForData.reset() + + startOffsetForCurrentChunk = 0 + numRowsForCurrentChunk = 0 + totalNumRowsForBatch = 0 + totalNumStatesForBatch = 0 + } +} + +object ApplyInPandasWithStateWriter { + // This schema contains both state metadata and the metadata of the chunk. Refer the code comment + // of "Data and State" for more details. + val STATE_METADATA_SCHEMA: StructType = StructType( + Array( + /* + Metadata of the state + */ + + // properties of state instance (excluding state value) in json format + StructField("properties", StringType), + // key row as UnsafeRow, Python worker won't touch this value but send the value back to + // executor when sending an update of state + StructField("keyRowAsUnsafe", BinaryType), + // state value + StructField("object", BinaryType), + + /* + Metadata of the chunk + */ + + // start offset of the data chunk from entire data + StructField("startOffset", IntegerType), + // the number of rows for the data chunk + StructField("numRows", IntegerType), + // whether the current data chunk is the last one for current grouping key or not + StructField("isLastChunk", BooleanType) + ) + ) + + // To avoid initializing a new row for empty state metadata row. 
+ val EMPTY_STATE_METADATA_ROW = new GenericInternalRow( + Array[Any](null, null, null, null, null, null)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala index e830ea6b546..b39787b12a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -78,8 +78,8 @@ case class FlatMapCoGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { - val (leftDedup, leftArgOffsets) = resolveArgOffsets(left, leftGroup) - val (rightDedup, rightArgOffsets) = resolveArgOffsets(right, rightGroup) + val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup) + val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup) // Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty left.execute().zipPartitions(right.execute()) { (leftData, rightData) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 3a3a6022f99..f0e815e966e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -75,7 +75,7 @@ case class FlatMapGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val (dedupAttributes, argOffsets) = resolveArgOffsets(child, groupingAttributes) + val (dedupAttributes, argOffsets) = resolveArgOffsets(child.output, groupingAttributes) // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty 
inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala new file mode 100644 index 00000000000..159f805f734 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasWithStateExec.scala @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.python + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout} +import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.python.PandasGroupUtils.resolveArgOffsets +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper.StateData +import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.util.CompletionIterator + +/** + * Physical operator for executing + * [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandasWithState]] + * + * @param functionExpr function called on each group + * @param groupingAttributes used to group the data + * @param outAttributes used to define the output rows + * @param stateType used to serialize/deserialize state before calling `functionExpr` + * @param stateInfo `StatefulOperatorStateInfo` to identify the state store for a given operator. + * @param stateFormatVersion the version of state format. + * @param outputMode the output mode of `functionExpr` + * @param timeoutConf used to timeout groups that have not received data in a while + * @param batchTimestampMs processing timestamp of the current batch. 
+ * @param eventTimeWatermark event time watermark for the current batch + * @param child logical plan of the underlying data + */ +case class FlatMapGroupsInPandasWithStateExec( + functionExpr: Expression, + groupingAttributes: Seq[Attribute], + outAttributes: Seq[Attribute], + stateType: StructType, + stateInfo: Option[StatefulOperatorStateInfo], + stateFormatVersion: Int, + outputMode: OutputMode, + timeoutConf: GroupStateTimeout, + batchTimestampMs: Option[Long], + eventTimeWatermark: Option[Long], + child: SparkPlan) extends UnaryExecNode with FlatMapGroupsWithStateExecBase { + + // TODO(SPARK-40444): Add the support of initial state. + override protected val initialStateDeserializer: Expression = null + override protected val initialStateGroupAttrs: Seq[Attribute] = null + override protected val initialStateDataAttrs: Seq[Attribute] = null + override protected val initialState: SparkPlan = null + override protected val hasInitialState: Boolean = false + + override protected val stateEncoder: ExpressionEncoder[Any] = + RowEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]] + + override def output: Seq[Attribute] = outAttributes + + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + private val pythonFunction = functionExpr.asInstanceOf[PythonUDF].func + private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) + private lazy val (dedupAttributes, argOffsets) = resolveArgOffsets( + groupingAttributes ++ child.output, groupingAttributes) + private lazy val unsafeProj = UnsafeProjection.create(dedupAttributes, child.output) + + override def requiredChildDistribution: Seq[Distribution] = + StatefulOperatorPartitioning.getCompatibleDistribution( + groupingAttributes, getStateInfo, conf) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( + groupingAttributes.map(SortOrder(_, Ascending))) + + override def shortName: 
String = "applyInPandasWithState" + + override protected def withNewChildInternal( + newChild: SparkPlan): FlatMapGroupsInPandasWithStateExec = copy(child = newChild) + + override def createInputProcessor( + store: StateStore): InputProcessor = new InputProcessor(store: StateStore) { + + override def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) + val processIter = groupedIter.map { case (keyRow, valueRowIter) => + val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] + val stateData = stateManager.getState(store, keyUnsafeRow) + (keyUnsafeRow, stateData, valueRowIter.map(unsafeProj)) + } + + process(processIter, hasTimedOut = false) + } + + override def processNewDataWithInitialState( + childDataIter: Iterator[InternalRow], + initStateIter: Iterator[InternalRow]): Iterator[InternalRow] = { + throw new UnsupportedOperationException("Should not reach here!") + } + + override def processTimedOutState(): Iterator[InternalRow] = { + if (isTimeoutEnabled) { + val timeoutThreshold = timeoutConf match { + case ProcessingTimeTimeout => batchTimestampMs.get + case EventTimeTimeout => eventTimeWatermark.get + case _ => + throw new IllegalStateException( + s"Cannot filter timed out keys for $timeoutConf") + } + val timingOutPairs = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold + } + + val processIter = timingOutPairs.map { stateData => + val joinedKeyRow = unsafeProj( + new JoinedRow( + stateData.keyRow, + new GenericInternalRow(Array.fill(dedupAttributes.length)(null: Any)))) + + (stateData.keyRow, stateData, Iterator.single(joinedKeyRow)) + } + + process(processIter, hasTimedOut = true) + } else Iterator.empty + } + + private def process( + iter: Iterator[(UnsafeRow, StateData, Iterator[InternalRow])], + hasTimedOut: Boolean): Iterator[InternalRow] = { + val runner = new 
ApplyInPandasWithStatePythonRunner( + chainedFunc, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + Array(argOffsets), + StructType.fromAttributes(dedupAttributes), + sessionLocalTimeZone, + pythonRunnerConf, + stateEncoder.asInstanceOf[ExpressionEncoder[Row]], + groupingAttributes.toStructType, + outAttributes.toStructType, + stateType) + + val context = TaskContext.get() + + val processIter = iter.map { case (keyRow, stateData, valueIter) => + val groupedState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj).map { r => assert(r.isInstanceOf[Row]); r }, + batchTimestampMs.getOrElse(NO_TIMESTAMP), + eventTimeWatermark.getOrElse(NO_TIMESTAMP), + timeoutConf, + hasTimedOut = hasTimedOut, + watermarkPresent).asInstanceOf[GroupStateImpl[Row]] + (keyRow, groupedState, valueIter) + } + runner.compute(processIter, context.partitionId(), context).flatMap { + case (stateIter, outputIter) => + // When the iterator is consumed, then write changes to state. + // state does not affect each others, hence when to update does not affect to the result. 
+ def onIteratorCompletion: Unit = { + stateIter.foreach { case (keyRow, newGroupState, oldTimeoutTimestamp) => + if (newGroupState.isRemoved && !newGroupState.getTimeoutTimestampMs.isPresent()) { + stateManager.removeState(store, keyRow) + numRemovedStateRows += 1 + } else { + val currentTimeoutTimestamp = newGroupState.getTimeoutTimestampMs + .orElse(NO_TIMESTAMP) + val hasTimeoutChanged = currentTimeoutTimestamp != oldTimeoutTimestamp + val shouldWriteState = newGroupState.isUpdated || newGroupState.isRemoved || + hasTimeoutChanged + + if (shouldWriteState) { + val updatedStateObj = if (newGroupState.exists) newGroupState.get else null + stateManager.putState(store, keyRow, updatedStateObj, + currentTimeoutTimestamp) + numUpdatedStateRows += 1 + } + } + } + } + + CompletionIterator[InternalRow, Iterator[InternalRow]]( + outputIter, onIteratorCompletion).map { row => + numOutputRows += 1 + row + } + } + } + + override protected def callFunctionAndUpdateState( + stateData: StateData, + valueRowIter: Iterator[InternalRow], + hasTimedOut: Boolean): Iterator[InternalRow] = { + throw new UnsupportedOperationException("Should not reach here!") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala index 2da0000dad4..07887666406 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PandasGroupUtils.scala @@ -24,7 +24,7 @@ import org.apache.spark.TaskContext import org.apache.spark.api.python.BasePythonRunner import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} -import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan} +import org.apache.spark.sql.execution.GroupedIterator import org.apache.spark.sql.vectorized.{ArrowColumnVector, 
ColumnarBatch} /** @@ -88,9 +88,10 @@ private[python] object PandasGroupUtils { * argOffsets[argOffsets[0]+2 .. ] is the arg offsets for data attributes */ def resolveArgOffsets( - child: SparkPlan, groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = { + attributes: Seq[Attribute], + groupingAttributes: Seq[Attribute]): (Seq[Attribute], Array[Int]) = { - val dataAttributes = child.output.drop(groupingAttributes.length) + val dataAttributes = attributes.drop(groupingAttributes.length) val groupingIndicesInData = groupingAttributes.map { attribute => dataAttributes.indexWhere(attribute.semanticEquals) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala index 6168d0f867a..bf66791183e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala @@ -76,7 +76,6 @@ private[python] trait PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => val root = VectorSchemaRoot.create(arrowSchema, allocator) Utils.tryWithSafeFinally { - val arrowWriter = ArrowWriter.create(root) val writer = new ArrowStreamWriter(root, null, dataOut) writer.start() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 3f369ac5e97..f386282a0b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.execution.{LocalLimitExec, QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} import 
org.apache.spark.sql.execution.aggregate.{HashAggregateExec, MergingSessionsExec, ObjectHashAggregateExec, SortAggregateExec, UpdatingSessionsExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike +import org.apache.spark.sql.execution.python.FlatMapGroupsInPandasWithStateExec import org.apache.spark.sql.execution.streaming.sources.WriteToMicroBatchDataSourceV1 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode @@ -62,6 +63,7 @@ class IncrementalExecution( StreamingJoinStrategy :: StatefulAggregationStrategy :: FlatMapGroupsWithStateStrategy :: + FlatMapGroupsInPandasWithStateStrategy :: StreamingRelationStrategy :: StreamingDeduplicationStrategy :: StreamingGlobalLimitStrategy(outputMode) :: Nil @@ -210,6 +212,13 @@ class IncrementalExecution( hasInitialState = hasInitialState ) + case m: FlatMapGroupsInPandasWithStateExec => + m.copy( + stateInfo = Some(nextStatefulOperationStateInfo), + batchTimestampMs = Some(offsetSeqMetadata.batchTimestampMs), + eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs) + ) + case j: StreamingSymmetricHashJoinExec => j.copy( stateInfo = Some(nextStatefulOperationStateInfo), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 01ff72bac7b..022fd1239ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -49,7 +49,7 @@ package object state { } /** Map each partition of an RDD along with data in a [[StateStore]]. 
*/ - private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( + def mapPartitionsWithStateStore[U: ClassTag]( stateInfo: StatefulOperatorStateInfo, keySchema: StructType, valueSchema: StructType,
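The bin-packing and chunking bookkeeping described in the code comments of ApplyInPandasWithStateWriter can be illustrated with a small model. The sketch below is plain Python, not Spark or Arrow code: it uses lists in place of Arrow RecordBatches, and the names `ChunkMeta`, `ChunkingWriter`, `pack` and `read_chunks` are hypothetical, introduced only for illustration. It assumes, as the comments state, that every group has at least one row. It mirrors the counters in the diff (start offset, number of rows per chunk, padding of the state column with empty rows) and shows how a reader could slice each chunk back out using only the chunk metadata.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class ChunkMeta:
    # Models the "metadata of the chunk" part of STATE_METADATA_SCHEMA.
    key: str
    start_offset: int
    num_rows: int
    is_last_chunk: bool


@dataclass
class Batch:
    rows: List[int]
    # One entry per data row; None models EMPTY_STATE_METADATA_ROW padding.
    state: List[Optional[ChunkMeta]]


class ChunkingWriter:
    """Hypothetical, simplified model of the writer's bookkeeping."""

    def __init__(self, max_records_per_batch: int) -> None:
        self.max_records = max_records_per_batch
        self.batches: List[Batch] = []
        self._rows: List[int] = []
        self._metas: List[ChunkMeta] = []
        self._key = ""
        self._start_offset = 0
        self._num_rows_in_chunk = 0

    def start_new_group(self, key: str) -> None:
        self._key = key

    def write_row(self, row: int) -> None:
        # Batch is full and the group has more data: close the chunk and the batch.
        if len(self._rows) >= self.max_records:
            self._finalize_chunk(is_last=False)
            self._finalize_batch()
        self._rows.append(row)
        self._num_rows_in_chunk += 1

    def finalize_group(self) -> None:
        self._finalize_chunk(is_last=True)
        if len(self._rows) >= self.max_records:
            self._finalize_batch()

    def finalize_data(self) -> None:
        if self._rows:
            self._finalize_batch()

    def _finalize_chunk(self, is_last: bool) -> None:
        self._metas.append(ChunkMeta(
            self._key, self._start_offset, self._num_rows_in_chunk, is_last))
        # The next chunk starts where the current batch ends.
        self._start_offset = len(self._rows)
        self._num_rows_in_chunk = 0

    def _finalize_batch(self) -> None:
        # Pad the state column so both column families have equal length.
        padding = [None] * (len(self._rows) - len(self._metas))
        self.batches.append(Batch(self._rows, self._metas + padding))
        self._rows, self._metas = [], []
        self._start_offset = 0
        self._num_rows_in_chunk = 0


def pack(groups: List[Tuple[str, List[int]]], max_records: int) -> List[Batch]:
    """Drives the writer the way the comments describe the caller doing it."""
    writer = ChunkingWriter(max_records)
    for key, rows in groups:
        writer.start_new_group(key)
        for row in rows:
            writer.write_row(row)
        writer.finalize_group()
    writer.finalize_data()
    return writer.batches


def read_chunks(batches: List[Batch]) -> List[Tuple[str, List[int], bool]]:
    """Recovers each chunk from a batch using only the chunk metadata."""
    chunks = []
    for batch in batches:
        for meta in batch.state:
            if meta is not None:
                end = meta.start_offset + meta.num_rows
                chunks.append(
                    (meta.key, batch.rows[meta.start_offset:end], meta.is_last_chunk))
    return chunks
```

In this model, with a limit of 3 rows per batch and two groups, `a` with 2 rows and `b` with 4 rows, `pack` emits two batches of 3 rows each. Group `b` is split into a one-row chunk that shares the first batch with `a` and a three-row chunk in the second batch, and only the latter is marked as the last chunk.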