This is an automated email from the ASF dual-hosted git repository.

jrmccluskey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 84d070bae9b Add per-worker keys to stateful BatchElements, plumb through batching options to all model handlers (#29642)
84d070bae9b is described below

commit 84d070bae9b59160d4d88aa2fab9d98b4ed2589f
Author: Jack McCluskey <34928439+jrmcclus...@users.noreply.github.com>
AuthorDate: Thu Dec 7 12:40:26 2023 -0500

    Add per-worker keys to stateful BatchElements, plumb through batching options to all model handlers (#29642)
    
    * Key streaming BatchElements bundles per-worker
    
    * Plumb through max_batch_duration_secs support
    
    * Linting
    
    * Tag shared handle
---
 .../ml/inference/huggingface_inference.py          | 15 +++++++++++
 .../apache_beam/ml/inference/onnx_inference.py     | 18 ++++++++++++++
 .../apache_beam/ml/inference/pytorch_inference.py  | 10 ++++++++
 .../apache_beam/ml/inference/sklearn_inference.py  | 10 ++++++++
 .../ml/inference/tensorflow_inference.py           | 10 ++++++++
 .../apache_beam/ml/inference/tensorrt_inference.py |  7 +++++-
 .../ml/inference/vertex_ai_inference.py            | 22 +++++++++++++++-
 .../apache_beam/ml/inference/xgboost_inference.py  | 21 ++++++++++++++++
 sdks/python/apache_beam/transforms/util.py         | 29 +++++++++++++++++++++-
 9 files changed, 139 insertions(+), 3 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference.py 
b/sdks/python/apache_beam/ml/inference/huggingface_inference.py
index 878d7bfc9cf..1bc92c462c9 100644
--- a/sdks/python/apache_beam/ml/inference/huggingface_inference.py
+++ b/sdks/python/apache_beam/ml/inference/huggingface_inference.py
@@ -224,6 +224,7 @@ class 
HuggingFaceModelHandlerKeyedTensor(ModelHandler[Dict[str,
       inference_args: Optional[Dict[str, Any]] = None,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       **kwargs):
     """
@@ -255,6 +256,8 @@ class 
HuggingFaceModelHandlerKeyedTensor(ModelHandler[Dict[str,
         Defaults to None.
       min_batch_size: the minimum batch size to use when batching inputs.
       max_batch_size: the maximum batch size to use when batching inputs.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+        before emitting; used in streaming contexts.
       large_model: set to true if your model is large enough to run into
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
@@ -277,6 +280,8 @@ class 
HuggingFaceModelHandlerKeyedTensor(ModelHandler[Dict[str,
       self._batching_kwargs["min_batch_size"] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs["max_batch_size"] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = 
max_batch_duration_secs
     self._large_model = large_model
     self._framework = framework
 
@@ -405,6 +410,7 @@ class 
HuggingFaceModelHandlerTensor(ModelHandler[Union[tf.Tensor, torch.Tensor],
       inference_args: Optional[Dict[str, Any]] = None,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       **kwargs):
     """
@@ -436,6 +442,8 @@ class 
HuggingFaceModelHandlerTensor(ModelHandler[Union[tf.Tensor, torch.Tensor],
         Defaults to None.
       min_batch_size: the minimum batch size to use when batching inputs.
       max_batch_size: the maximum batch size to use when batching inputs.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+        before emitting; used in streaming contexts.
       large_model: set to true if your model is large enough to run into
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
@@ -458,6 +466,8 @@ class 
HuggingFaceModelHandlerTensor(ModelHandler[Union[tf.Tensor, torch.Tensor],
       self._batching_kwargs["min_batch_size"] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs["max_batch_size"] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = 
max_batch_duration_secs
     self._large_model = large_model
     self._framework = ""
 
@@ -579,6 +589,7 @@ class HuggingFacePipelineModelHandler(ModelHandler[str,
       inference_args: Optional[Dict[str, Any]] = None,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       **kwargs):
     """
@@ -617,6 +628,8 @@ class HuggingFacePipelineModelHandler(ModelHandler[str,
         Defaults to None.
       min_batch_size: the minimum batch size to use when batching inputs.
       max_batch_size: the maximum batch size to use when batching inputs.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+        before emitting; used in streaming contexts.
       large_model: set to true if your model is large enough to run into
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
@@ -639,6 +652,8 @@ class HuggingFacePipelineModelHandler(ModelHandler[str,
       self._batching_kwargs['min_batch_size'] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = 
max_batch_duration_secs
     self._large_model = large_model
 
     # Check if the device is specified twice. If true then the device parameter
diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py 
b/sdks/python/apache_beam/ml/inference/onnx_inference.py
index 18e115a6188..f7b6c0115af 100644
--- a/sdks/python/apache_beam/ml/inference/onnx_inference.py
+++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py
@@ -19,6 +19,7 @@ from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import Iterable
+from typing import Mapping
 from typing import Optional
 from typing import Sequence
 
@@ -63,6 +64,9 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
       *,
       inference_fn: NumpyInferenceFn = default_numpy_inference_fn,
       large_model: bool = False,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       **kwargs):
     """ Implementation of the ModelHandler interface for onnx
     using numpy arrays as input.
@@ -80,6 +84,10 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
         set this to True if N*W > M.
+      min_batch_size: the minimum batch size to use when batching inputs.
+      max_batch_size: the maximum batch size to use when batching inputs.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+        before emitting; used in streaming contexts.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
     """
@@ -90,6 +98,13 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
     self._model_inference_fn = inference_fn
     self._env_vars = kwargs.get('env_vars', {})
     self._large_model = large_model
+    self._batching_kwargs = {}
+    if min_batch_size is not None:
+      self._batching_kwargs["min_batch_size"] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs["max_batch_size"] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = 
max_batch_duration_secs
 
   def load_model(self) -> ort.InferenceSession:
     """Loads and initializes an onnx inference session for processing."""
@@ -143,3 +158,6 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
 
   def share_model_across_processes(self) -> bool:
     return self._large_model
+
+  def batch_elements_kwargs(self) -> Mapping[str, Any]:
+    return self._batching_kwargs
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py 
b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 26e593fdd7d..480dc538195 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -193,6 +193,7 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
       torch_script_model_path: Optional[str] = None,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       load_model_args: Optional[Dict[str, Any]] = None,
       **kwargs):
@@ -227,6 +228,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
         batch will be fed into the inference_fn as a Sequence of Tensors.
       max_batch_size: the maximum batch size to use when batching inputs. This
         batch will be fed into the inference_fn as a Sequence of Tensors.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+          before emitting; used in streaming contexts.
       large_model: set to true if your model is large enough to run into
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
@@ -254,6 +257,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
       self._batching_kwargs['min_batch_size'] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = 
max_batch_duration_secs
     self._torch_script_model_path = torch_script_model_path
     self._load_model_args = load_model_args if load_model_args else {}
     self._env_vars = kwargs.get('env_vars', {})
@@ -421,6 +426,7 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, 
torch.Tensor],
       torch_script_model_path: Optional[str] = None,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       load_model_args: Optional[Dict[str, Any]] = None,
       **kwargs):
@@ -460,6 +466,8 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, 
torch.Tensor],
         batch will be fed into the inference_fn as a Sequence of Keyed Tensors.
       max_batch_size: the maximum batch size to use when batching inputs. This
         batch will be fed into the inference_fn as a Sequence of Keyed Tensors.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+          before emitting; used in streaming contexts.
       large_model: set to true if your model is large enough to run into
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
@@ -487,6 +495,8 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, 
torch.Tensor],
       self._batching_kwargs['min_batch_size'] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = 
max_batch_duration_secs
     self._torch_script_model_path = torch_script_model_path
     self._load_model_args = load_model_args if load_model_args else {}
     self._env_vars = kwargs.get('env_vars', {})
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py 
b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index c2bd2cee66e..befeca7f33b 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -90,6 +90,7 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
       inference_fn: NumpyInferenceFn = _default_numpy_inference_fn,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       **kwargs):
     """ Implementation of the ModelHandler interface for scikit-learn
@@ -111,6 +112,8 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
       max_batch_size: the maximum batch size to use when batching inputs. This
         batch will be fed into the inference_fn as a Sequence of Numpy
         ndarrays.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+          before emitting; used in streaming contexts.
       large_model: set to true if your model is large enough to run into
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
@@ -126,6 +129,8 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
       self._batching_kwargs['min_batch_size'] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = 
max_batch_duration_secs
     self._env_vars = kwargs.get('env_vars', {})
     self._large_model = large_model
 
@@ -212,6 +217,7 @@ class 
SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
       inference_fn: PandasInferenceFn = _default_pandas_inference_fn,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       **kwargs):
     """Implementation of the ModelHandler interface for scikit-learn that
@@ -236,6 +242,8 @@ class 
SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
       max_batch_size: the maximum batch size to use when batching inputs. This
         batch will be fed into the inference_fn as a Sequence of Pandas
         Dataframes.
+      max_batch_duration_secs: the maximum amount of time to buffer a batch
+          before emitting; used in streaming contexts.
       large_model: set to true if your model is large enough to run into
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
@@ -251,6 +259,8 @@ class 
SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
       self._batching_kwargs['min_batch_size'] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = 
max_batch_duration_secs
     self._env_vars = kwargs.get('env_vars', {})
     self._large_model = large_model
 
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py 
b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
index 991ae971d9e..0802868a1dd 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
@@ -110,6 +110,7 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
       inference_fn: TensorInferenceFn = default_numpy_inference_fn,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       **kwargs):
     """Implementation of the ModelHandler interface for Tensorflow.
@@ -154,6 +155,8 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
       self._batching_kwargs['min_batch_size'] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = 
max_batch_duration_secs
     self._large_model = large_model
 
   def load_model(self) -> tf.Module:
@@ -235,6 +238,7 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, 
PredictionResult,
       inference_fn: TensorInferenceFn = default_tensor_inference_fn,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
       **kwargs):
     """Implementation of the ModelHandler interface for Tensorflow.
@@ -258,6 +262,10 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, 
PredictionResult,
           once the model is loaded.
         inference_fn: inference function to use during RunInference.
           Defaults to default_numpy_inference_fn.
+        min_batch_size: the minimum batch size to use when batching inputs.
+        max_batch_size: the maximum batch size to use when batching inputs.
+        max_batch_duration_secs: the maximum amount of time to buffer a batch
+          before emitting; used in streaming contexts.
         large_model: set to true if your model is large enough to run into
           memory pressure if you load multiple copies. Given a model that
           consumes N memory and a machine with W cores and M memory, you should
@@ -280,6 +288,8 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, 
PredictionResult,
       self._batching_kwargs['min_batch_size'] = min_batch_size
     if max_batch_size is not None:
       self._batching_kwargs['max_batch_size'] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = 
max_batch_duration_secs
     self._large_model = large_model
 
   def load_model(self) -> tf.Module:
diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py 
b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
index ff9bb78d579..53b81c0c36c 100644
--- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
@@ -230,6 +230,7 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
       *,
       inference_fn: TensorRTInferenceFn = _default_tensorRT_inference_fn,
       large_model: bool = False,
+      max_batch_duration_secs: Optional[int] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for TensorRT.
 
@@ -253,6 +254,8 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
         set this to True if N*W > M.
+      max_batch_duration_secs: the maximum amount of time to buffer 
+        a batch before emitting; used in streaming contexts.
       kwargs: Additional arguments like 'engine_path' and 'onnx_path' are
         currently supported. 'env_vars' can be used to set environment 
variables
         before loading the model.
@@ -262,6 +265,7 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
     """
     self.min_batch_size = min_batch_size
     self.max_batch_size = max_batch_size
+    self.max_batch_duration_secs = max_batch_duration_secs
     self.inference_fn = inference_fn
     if 'engine_path' in kwargs:
       self.engine_path = kwargs.get('engine_path')
@@ -274,7 +278,8 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
     """Sets min_batch_size and max_batch_size of a TensorRT engine."""
     return {
         'min_batch_size': self.min_batch_size,
-        'max_batch_size': self.max_batch_size
+        'max_batch_size': self.max_batch_size,
+        'max_batch_duration_secs': self.max_batch_duration_secs
     }
 
   def load_model(self) -> TensorRTEngine:
diff --git a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py 
b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
index 8c902421f60..95660441a84 100644
--- a/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
+++ b/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
@@ -20,6 +20,7 @@ import time
 from typing import Any
 from typing import Dict
 from typing import Iterable
+from typing import Mapping
 from typing import Optional
 from typing import Sequence
 
@@ -69,6 +70,10 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
       experiment: Optional[str] = None,
       network: Optional[str] = None,
       private: bool = False,
+      *,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for Vertex AI.
     **NOTE:** This API and its implementation are under development and
@@ -97,9 +102,21 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
       private: optional. if the deployed Vertex AI endpoint is
         private, set to true. Requires a network to be provided
         as well.
+      min_batch_size: optional. the minimum batch size to use when batching
+        inputs.
+      max_batch_size: optional. the maximum batch size to use when batching
+        inputs.
+      max_batch_duration_secs: optional. the maximum amount of time to buffer 
+        a batch before emitting; used in streaming contexts.
     """
-
+    self._batching_kwargs = {}
     self._env_vars = kwargs.get('env_vars', {})
+    if min_batch_size is not None:
+      self._batching_kwargs["min_batch_size"] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs["max_batch_size"] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = 
max_batch_duration_secs
 
     if private and network is None:
       raise ValueError(
@@ -231,3 +248,6 @@ class VertexAIModelHandlerJSON(ModelHandler[Any,
 
   def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
     pass
+
+  def batch_elements_kwargs(self) -> Mapping[str, Any]:
+    return self._batching_kwargs
diff --git a/sdks/python/apache_beam/ml/inference/xgboost_inference.py 
b/sdks/python/apache_beam/ml/inference/xgboost_inference.py
index 374980c56ec..ff6f098b415 100644
--- a/sdks/python/apache_beam/ml/inference/xgboost_inference.py
+++ b/sdks/python/apache_beam/ml/inference/xgboost_inference.py
@@ -21,6 +21,7 @@ from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import Iterable
+from typing import Mapping
 from typing import Optional
 from typing import Sequence
 from typing import Union
@@ -75,6 +76,10 @@ class XGBoostModelHandler(ModelHandler[ExampleT, 
PredictionT, ModelT], ABC):
                          Callable[..., xgboost.XGBModel]],
       model_state: str,
       inference_fn: XGBoostInferenceFn = default_xgboost_inference_fn,
+      *,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for XGBoost.
 
@@ -95,6 +100,12 @@ class XGBoostModelHandler(ModelHandler[ExampleT, 
PredictionT, ModelT], ABC):
         configuration.
       inference_fn: the inference function to use during RunInference.
         default=default_xgboost_inference_fn
+      min_batch_size: optional. the minimum batch size to use when batching
+        inputs.
+      max_batch_size: optional. the maximum batch size to use when batching
+        inputs.
+      max_batch_duration_secs: optional. the maximum amount of time to buffer 
+        a batch before emitting; used in streaming contexts.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
 
@@ -115,6 +126,13 @@ class XGBoostModelHandler(ModelHandler[ExampleT, 
PredictionT, ModelT], ABC):
     self._model_state = model_state
     self._inference_fn = inference_fn
     self._env_vars = kwargs.get('env_vars', {})
+    self._batching_kwargs = {}
+    if min_batch_size is not None:
+      self._batching_kwargs["min_batch_size"] = min_batch_size
+    if max_batch_size is not None:
+      self._batching_kwargs["max_batch_size"] = max_batch_size
+    if max_batch_duration_secs is not None:
+      self._batching_kwargs["max_batch_duration_secs"] = 
max_batch_duration_secs
 
   def load_model(self) -> Union[xgboost.Booster, xgboost.XGBModel]:
     model = self._model_class()
@@ -129,6 +147,9 @@ class XGBoostModelHandler(ModelHandler[ExampleT, 
PredictionT, ModelT], ABC):
   def get_metrics_namespace(self) -> str:
     return 'BeamML_XGBoost'
 
+  def batch_elements_kwargs(self) -> Mapping[str, Any]:
+    return self._batching_kwargs
+
 
 class XGBoostModelHandlerNumpy(XGBoostModelHandler[numpy.ndarray,
                                                    PredictionResult,
diff --git a/sdks/python/apache_beam/transforms/util.py 
b/sdks/python/apache_beam/transforms/util.py
index cacfdb37d7b..c554bef6c36 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -70,6 +70,7 @@ from apache_beam.transforms.window import TimestampedValue
 from apache_beam.typehints import trivial_inference
 from apache_beam.typehints.decorators import get_signature
 from apache_beam.typehints.sharded_key_type import ShardedKeyType
+from apache_beam.utils import shared
 from apache_beam.utils import windowed_value
 from apache_beam.utils.annotations import deprecated
 from apache_beam.utils.sharded_key import ShardedKey
@@ -748,6 +749,32 @@ def _pardo_stateful_batch_elements(
   return _StatefulBatchElementsDoFn()
 
 
+class SharedKey():
+  """A class that holds a per-process UUID used to key elements for streaming
+  BatchElements. 
+  """
+  def __init__(self):
+    self.key = uuid.uuid4().hex
+
+
+def load_shared_key():
+  return SharedKey()
+
+
+class WithSharedKey(DoFn):
+  """A DoFn that keys elements with a per-process UUID. Used in streaming
+  BatchElements.  
+  """
+  def __init__(self):
+    self.shared_handle = shared.Shared()
+
+  def setup(self):
+    self.key = self.shared_handle.acquire(load_shared_key, "WithSharedKey").key
+
+  def process(self, element):
+    yield (self.key, element)
+
+
 @typehints.with_input_types(T)
 @typehints.with_output_types(List[T])
 class BatchElements(PTransform):
@@ -826,7 +853,7 @@ class BatchElements(PTransform):
       raise NotImplementedError("Requires stateful processing (BEAM-2687)")
     elif self._max_batch_dur is not None:
       coder = coders.registry.get_coder(pcoll)
-      return pcoll | WithKeys(0) | ParDo(
+      return pcoll | ParDo(WithSharedKey()) | ParDo(
           _pardo_stateful_batch_elements(
               coder,
               self._batch_size_estimator,

Reply via email to