This is an automated email from the ASF dual-hosted git repository.

jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
     new 84d070bae9b  Add per-worker keys to stateful BatchElements, plumb through batching options to all model handlers (#29642)
84d070bae9b is described below

commit 84d070bae9b59160d4d88aa2fab9d98b4ed2589f
Author: Jack McCluskey <34928439+jrmcclus...@users.noreply.github.com>
AuthorDate: Thu Dec 7 12:40:26 2023 -0500

    Add per-worker keys to stateful BatchElements, plumb through batching options to all model handlers (#29642)

    * Key streaming BatchElements bundles per-worker

    * Plumb through max_batch_duration_secs support

    * Linting

    * Tag shared handle
---
 .../ml/inference/huggingface_inference.py           | 15 +++++++++++
 .../apache_beam/ml/inference/onnx_inference.py      | 18 ++++++++++++++
 .../apache_beam/ml/inference/pytorch_inference.py   | 10 ++++++++
 .../apache_beam/ml/inference/sklearn_inference.py   | 10 ++++++++
 .../ml/inference/tensorflow_inference.py            | 10 ++++++++
 .../apache_beam/ml/inference/tensorrt_inference.py  |  7 +++++-
 .../ml/inference/vertex_ai_inference.py             | 22 +++++++++++++++-
 .../apache_beam/ml/inference/xgboost_inference.py   | 21 ++++++++++++++++
 sdks/python/apache_beam/transforms/util.py          | 29 +++++++++++++++++++++++++++-
 9 files changed, 139 insertions(+), 3 deletions(-)
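As a rough usage sketch (not part of this patch), here is how the newly plumbed batching options could be passed to a handler and picked up by RunInference; the model path, example data, and batch values below are placeholders.

import apache_beam as beam
import numpy
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

# The handler forwards these kwargs to RunInference's internal BatchElements
# via batch_elements_kwargs(), which this commit wires up for all handlers.
handler = SklearnModelHandlerNumpy(
    model_uri='gs://my-bucket/model.pkl',  # placeholder model location
    min_batch_size=8,
    max_batch_size=64,
    max_batch_duration_secs=5)  # flush partial batches after ~5s in streaming

examples = [numpy.array([0.0, 1.0]), numpy.array([2.0, 3.0])]  # toy inputs

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(examples)
      | RunInference(handler))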
diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference.py b/sdks/python/apache_beam/ml/inference/huggingface_inference.py
index 878d7bfc9cf..1bc92c462c9 100644
--- a/sdks/python/apache_beam/ml/inference/huggingface_inference.py
+++ b/sdks/python/apache_beam/ml/inference/huggingface_inference.py
@@ -224,6 +224,7 @@ class HuggingFaceModelHandlerKeyedTensor(ModelHandler[Dict[str,
       inference_args: Optional[Dict[str, Any]] = None,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       **kwargs):
     """
@@ -255,6 +256,8 @@ class HuggingFaceModelHandlerKeyedTensor(ModelHandler[Dict[str,
         Defaults to None.
       min_batch_size: the minimum batch size to use when batching inputs.
       max_batch_size: the maximum batch size to use when batching inputs.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+        before emitting; used in streaming contexts.
       large_model: set to true if your model is large enough to run into
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
@@ -277,6 +280,8 @@ class HuggingFaceModelHandlerKeyedTensor(ModelHandler[Dict[str,
       self._batching_kwargs["min_batch_size"] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs["max_batch_size"] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
     self._large_model = large_model
     self._framework = framework
 
@@ -405,6 +410,7 @@ class HuggingFaceModelHandlerTensor(ModelHandler[Union[tf.Tensor, torch.Tensor],
       inference_args: Optional[Dict[str, Any]] = None,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       **kwargs):
     """
@@ -436,6 +442,8 @@ class HuggingFaceModelHandlerTensor(ModelHandler[Union[tf.Tensor, torch.Tensor],
         Defaults to None.
       min_batch_size: the minimum batch size to use when batching inputs.
       max_batch_size: the maximum batch size to use when batching inputs.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+        before emitting; used in streaming contexts.
       large_model: set to true if your model is large enough to run into
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
@@ -458,6 +466,8 @@ class HuggingFaceModelHandlerTensor(ModelHandler[Union[tf.Tensor, torch.Tensor],
       self._batching_kwargs["min_batch_size"] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs["max_batch_size"] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
     self._large_model = large_model
     self._framework = ""
 
@@ -579,6 +589,7 @@ class HuggingFacePipelineModelHandler(ModelHandler[str,
       inference_args: Optional[Dict[str, Any]] = None,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       **kwargs):
     """
@@ -617,6 +628,8 @@ class HuggingFacePipelineModelHandler(ModelHandler[str,
         Defaults to None.
       min_batch_size: the minimum batch size to use when batching inputs.
       max_batch_size: the maximum batch size to use when batching inputs.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+        before emitting; used in streaming contexts.
       large_model: set to true if your model is large enough to run into
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
@@ -639,6 +652,8 @@ class HuggingFacePipelineModelHandler(ModelHandler[str,
       self._batching_kwargs['min_batch_size'] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
     self._large_model = large_model
 
     # Check if the device is specified twice. If true then the device parameter
diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py
index 18e115a6188..f7b6c0115af 100644
--- a/sdks/python/apache_beam/ml/inference/onnx_inference.py
+++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py
@@ -19,6 +19,7 @@ from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import Iterable
+from typing import Mapping
 from typing import Optional
 from typing import Sequence
 
@@ -63,6 +64,9 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
       *,
       inference_fn: NumpyInferenceFn = default_numpy_inference_fn,
       large_model: bool = False,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       **kwargs):
     """ Implementation of the ModelHandler interface for onnx
     using numpy arrays as input.
@@ -80,6 +84,10 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
         set this to True if N*W > M.
+      min_batch_size: the minimum batch size to use when batching inputs.
+      max_batch_size: the maximum batch size to use when batching inputs.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+        before emitting; used in streaming contexts.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
     """
@@ -90,6 +98,13 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
     self._model_inference_fn = inference_fn
     self._env_vars = kwargs.get('env_vars', {})
     self._large_model = large_model
+    self._batching_kwargs = {}
+    if min_batch_size is not None:
+      self._batching_kwargs["min_batch_size"] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs["max_batch_size"] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
 
   def load_model(self) -> ort.InferenceSession:
     """Loads and initializes an onnx inference session for processing."""
@@ -143,3 +158,6 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
 
   def share_model_across_processes(self) -> bool:
     return self._large_model
+
+  def batch_elements_kwargs(self) -> Mapping[str, Any]:
+    return self._batching_kwargs
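Illustration only (not part of the patch): like the other handlers touched here, the ONNX handler now surfaces its batching configuration through batch_elements_kwargs(), which RunInference uses to configure its internal BatchElements transform. The model path and batch values are placeholders.

from apache_beam.ml.inference.onnx_inference import OnnxModelHandlerNumpy

handler = OnnxModelHandlerNumpy(
    model_uri='gs://my-bucket/model.onnx',  # placeholder model location
    min_batch_size=4,
    max_batch_size=32,
    max_batch_duration_secs=10)

# Only the options that were actually set appear in the mapping.
print(handler.batch_elements_kwargs())
# expected: {'min_batch_size': 4, 'max_batch_size': 32, 'max_batch_duration_secs': 10}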
""" @@ -90,6 +98,13 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray, self._model_inference_fn = inference_fn self._env_vars = kwargs.get('env_vars', {}) self._large_model = large_model + self._batching_kwargs = {} + if min_batch_size is not None: + self._batching_kwargs["min_batch_size"] = min_batch_size + if max_batch_size is not None: + self._batching_kwargs["max_batch_size"] = max_batch_size + if max_batch_duration_secs is not None: + self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs def load_model(self) -> ort.InferenceSession: """Loads and initializes an onnx inference session for processing.""" @@ -143,3 +158,6 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray, def share_model_across_processes(self) -> bool: return self._large_model + + def batch_elements_kwargs(self) -> Mapping[str, Any]: + return self._batching_kwargs diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py index 26e593fdd7d..480dc538195 100644 --- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py +++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py @@ -193,6 +193,7 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor, torch_script_model_path: Optional[str] = None, min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, + max_batch_duration_secs: Optional[int] = None, large_model: bool = False, load_model_args: Optional[Dict[str, Any]] = None, **kwargs): @@ -227,6 +228,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor, batch will be fed into the inference_fn as a Sequence of Tensors. max_batch_size: the maximum batch size to use when batching inputs. This batch will be fed into the inference_fn as a Sequence of Tensors. + max_batch_duration_secs: the maximum amount of time to buffer a batch + before emitting; used in streaming contexts. large_model: set to true if your model is large enough to run into memory pressure if you load multiple copies. Given a model that consumes N memory and a machine with W cores and M memory, you should @@ -254,6 +257,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor, self._batching_kwargs['min_batch_size'] = min_batch_size if max_batch_size is not None: self._batching_kwargs['max_batch_size'] = max_batch_size + if max_batch_duration_secs is not None: + self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs self._torch_script_model_path = torch_script_model_path self._load_model_args = load_model_args if load_model_args else {} self._env_vars = kwargs.get('env_vars', {}) @@ -421,6 +426,7 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor], torch_script_model_path: Optional[str] = None, min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, + max_batch_duration_secs: Optional[int] = None, large_model: bool = False, load_model_args: Optional[Dict[str, Any]] = None, **kwargs): @@ -460,6 +466,8 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor], batch will be fed into the inference_fn as a Sequence of Keyed Tensors. max_batch_size: the maximum batch size to use when batching inputs. This batch will be fed into the inference_fn as a Sequence of Keyed Tensors. + max_batch_duration_secs: the maximum amount of time to buffer a batch + before emitting; used in streaming contexts. large_model: set to true if your model is large enough to run into memory pressure if you load multiple copies. 
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index c2bd2cee66e..befeca7f33b 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -90,6 +90,7 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
       inference_fn: NumpyInferenceFn = _default_numpy_inference_fn,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       **kwargs):
     """ Implementation of the ModelHandler interface for scikit-learn
@@ -111,6 +112,8 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
       max_batch_size: the maximum batch size to use when batching inputs. This
         batch will be fed into the inference_fn as a Sequence of Numpy
         ndarrays.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+        before emitting; used in streaming contexts.
       large_model: set to true if your model is large enough to run into
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
@@ -126,6 +129,8 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
       self._batching_kwargs['min_batch_size'] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
     self._env_vars = kwargs.get('env_vars', {})
     self._large_model = large_model
 
@@ -212,6 +217,7 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
       inference_fn: PandasInferenceFn = _default_pandas_inference_fn,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       **kwargs):
     """Implementation of the ModelHandler interface for scikit-learn that
@@ -236,6 +242,8 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
       max_batch_size: the maximum batch size to use when batching inputs. This
         batch will be fed into the inference_fn as a Sequence of Pandas
         Dataframes.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+        before emitting; used in streaming contexts.
       large_model: set to true if your model is large enough to run into
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
@@ -251,6 +259,8 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
       self._batching_kwargs['min_batch_size'] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
     self._env_vars = kwargs.get('env_vars', {})
     self._large_model = large_model
 
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
index 991ae971d9e..0802868a1dd 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
@@ -110,6 +110,7 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
       inference_fn: TensorInferenceFn = default_numpy_inference_fn,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       **kwargs):
     """Implementation of the ModelHandler interface for Tensorflow.
@@ -154,6 +155,8 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
       self._batching_kwargs['min_batch_size'] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
     self._large_model = large_model
 
   def load_model(self) -> tf.Module:
@@ -235,6 +238,7 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
       inference_fn: TensorInferenceFn = default_tensor_inference_fn,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       **kwargs):
     """Implementation of the ModelHandler interface for Tensorflow.
@@ -258,6 +262,10 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
         once the model is loaded.
       inference_fn: inference function to use during RunInference.
         Defaults to default_numpy_inference_fn.
+      min_batch_size: the minimum batch size to use when batching inputs.
+      max_batch_size: the maximum batch size to use when batching inputs.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+        before emitting; used in streaming contexts.
       large_model: set to true if your model is large enough to run into
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
@@ -280,6 +288,8 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
       self._batching_kwargs['min_batch_size'] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
     self._large_model = large_model
 
   def load_model(self) -> tf.Module:
diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
index ff9bb78d579..53b81c0c36c 100644
--- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
@@ -230,6 +230,7 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
       *,
       inference_fn: TensorRTInferenceFn = _default_tensorRT_inference_fn,
       large_model: bool = False,
+      max_batch_duration_secs: Optional[int] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for TensorRT.
 
@@ -253,6 +254,8 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
         set this to True if N*W > M.
+      max_batch_duration_secs: the maximum amount of time to buffer
+        a batch before emitting; used in streaming contexts.
       kwargs: Additional arguments like 'engine_path' and 'onnx_path' are
         currently supported. 'env_vars' can be used to set environment
         variables before loading the model.
@@ -262,6 +265,7 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
     """
     self.min_batch_size = min_batch_size
     self.max_batch_size = max_batch_size
+    self.max_batch_duration_secs = max_batch_duration_secs
     self.inference_fn = inference_fn
     if 'engine_path' in kwargs:
       self.engine_path = kwargs.get('engine_path')
@@ -274,7 +278,8 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
     """Sets min_batch_size and max_batch_size of a TensorRT engine."""
     return {
         'min_batch_size': self.min_batch_size,
-        'max_batch_size': self.max_batch_size
+        'max_batch_size': self.max_batch_size,
+        'max_batch_duration_secs': self.max_batch_duration_secs
     }
 
   def load_model(self) -> TensorRTEngine:
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
index 8c902421f60..95660441a84 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
@@ -20,6 +20,7 @@ import time
 from typing import Any
 from typing import Dict
 from typing import Iterable
+from typing import Mapping
 from typing import Optional
 from typing import Sequence
 
@@ -69,6 +70,10 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
       experiment: Optional[str] = None,
       network: Optional[str] = None,
       private: bool = False,
+      *,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for Vertex AI.
     **NOTE:** This API and its implementation are under development and
@@ -97,9 +102,21 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
       private: optional. if the deployed Vertex AI endpoint is
         private, set to true. Requires a network to be provided
        as well.
+      min_batch_size: optional. the minimum batch size to use when batching
+        inputs.
+      max_batch_size: optional. the maximum batch size to use when batching
+        inputs.
+      max_batch_duration_secs: optional. the maximum amount of time to buffer
+        a batch before emitting; used in streaming contexts.
     """
-
+    self._batching_kwargs = {}
     self._env_vars = kwargs.get('env_vars', {})
+    if min_batch_size is not None:
+      self._batching_kwargs["min_batch_size"] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs["max_batch_size"] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
 
     if private and network is None:
       raise ValueError(
@@ -231,3 +248,6 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
 
   def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
     pass
+
+  def batch_elements_kwargs(self) -> Mapping[str, Any]:
+    return self._batching_kwargs
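A construction sketch for the Vertex AI handler, not part of the patch; the endpoint, project, and location values are placeholders and assume an existing deployed endpoint plus credentials. Note that the new batching arguments are keyword-only here (they follow the bare '*' in __init__), as they are for the XGBoost handler below.

from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON

handler = VertexAIModelHandlerJSON(
    endpoint_id='1234567890',      # placeholder endpoint ID
    project='my-gcp-project',      # placeholder project
    location='us-central1',
    min_batch_size=1,
    max_batch_size=16,
    max_batch_duration_secs=30)    # bound request latency in streaming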
diff --git a/sdks/python/apache_beam/ml/inference/xgboost_inference.py b/sdks/python/apache_beam/ml/inference/xgboost_inference.py
index 374980c56ec..ff6f098b415 100644
--- a/sdks/python/apache_beam/ml/inference/xgboost_inference.py
+++ b/sdks/python/apache_beam/ml/inference/xgboost_inference.py
@@ -21,6 +21,7 @@ from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import Iterable
+from typing import Mapping
 from typing import Optional
 from typing import Sequence
 from typing import Union
@@ -75,6 +76,10 @@ class XGBoostModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
                    Callable[..., xgboost.XGBModel]],
       model_state: str,
       inference_fn: XGBoostInferenceFn = default_xgboost_inference_fn,
+      *,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for XGBoost.
 
@@ -95,6 +100,12 @@ class XGBoostModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
         configuration.
       inference_fn: the inference function to use during RunInference.
         default=default_xgboost_inference_fn
+      min_batch_size: optional. the minimum batch size to use when batching
+        inputs.
+      max_batch_size: optional. the maximum batch size to use when batching
+        inputs.
+      max_batch_duration_secs: optional. the maximum amount of time to buffer
+        a batch before emitting; used in streaming contexts.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
 
@@ -115,6 +126,13 @@ class XGBoostModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
     self._model_state = model_state
     self._inference_fn = inference_fn
     self._env_vars = kwargs.get('env_vars', {})
+    self._batching_kwargs = {}
+    if min_batch_size is not None:
+      self._batching_kwargs["min_batch_size"] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs["max_batch_size"] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
 
   def load_model(self) -> Union[xgboost.Booster, xgboost.XGBModel]:
     model = self._model_class()
@@ -129,6 +147,9 @@ class XGBoostModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
   def get_metrics_namespace(self) -> str:
     return 'BeamML_XGBoost'
 
+  def batch_elements_kwargs(self) -> Mapping[str, Any]:
+    return self._batching_kwargs
+
 
 class XGBoostModelHandlerNumpy(XGBoostModelHandler[numpy.ndarray,
                                                    PredictionResult,
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index cacfdb37d7b..c554bef6c36 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -70,6 +70,7 @@ from apache_beam.transforms.window import TimestampedValue
 from apache_beam.typehints import trivial_inference
 from apache_beam.typehints.decorators import get_signature
 from apache_beam.typehints.sharded_key_type import ShardedKeyType
+from apache_beam.utils import shared
 from apache_beam.utils import windowed_value
 from apache_beam.utils.annotations import deprecated
 from apache_beam.utils.sharded_key import ShardedKey
@@ -748,6 +749,32 @@ def _pardo_stateful_batch_elements(
   return _StatefulBatchElementsDoFn()
 
 
+class SharedKey():
+  """A class that holds a per-process UUID used to key elements for streaming
+  BatchElements.
+  """
+  def __init__(self):
+    self.key = uuid.uuid4().hex
+
+
+def load_shared_key():
+  return SharedKey()
+
+
+class WithSharedKey(DoFn):
+  """A DoFn that keys elements with a per-process UUID. Used in streaming
+  BatchElements.
+  """
+  def __init__(self):
+    self.shared_handle = shared.Shared()
+
+  def setup(self):
+    self.key = self.shared_handle.acquire(load_shared_key, "WithSharedKey").key
+
+  def process(self, element):
+    yield (self.key, element)
+
+
 @typehints.with_input_types(T)
 @typehints.with_output_types(List[T])
 class BatchElements(PTransform):
@@ -826,7 +853,7 @@ class BatchElements(PTransform):
       raise NotImplementedError("Requires stateful processing (BEAM-2687)")
     elif self._max_batch_dur is not None:
       coder = coders.registry.get_coder(pcoll)
-      return pcoll | WithKeys(0) | ParDo(
+      return pcoll | ParDo(WithSharedKey()) | ParDo(
           _pardo_stateful_batch_elements(
               coder,
               self._batch_size_estimator,
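Finally, a rough sketch of the user-facing effect of the util.py change, not part of the patch: when max_batch_duration_secs is set, BatchElements now routes elements through WithSharedKey (a per-process UUID key acquired via a shared handle) before the stateful batching DoFn, so each worker process accumulates its own batches instead of all elements funneling through the single key 0. The batch sizes and bounded source below are placeholders; the keying primarily matters for streaming pipelines.

import apache_beam as beam

with beam.Pipeline() as p:
  batches = (
      p
      | beam.Create(range(100))        # stand-in for a streaming source
      | beam.BatchElements(
          min_batch_size=10,
          max_batch_size=50,
          max_batch_duration_secs=5)   # triggers the stateful, keyed code path
      | beam.Map(len))                 # e.g. inspect resulting batch sizes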