This is an automated email from the ASF dual-hosted git repository. robertwb pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new f91bb68b2e9 Add some simple annotations to Python transforms. (#28191) f91bb68b2e9 is described below commit f91bb68b2e9df3788914898c03fdf42445f912d5 Author: Robert Bradshaw <rober...@gmail.com> AuthorDate: Wed Aug 30 17:23:33 2023 -0700 Add some simple annotations to Python transforms. (#28191) --- sdks/python/apache_beam/ml/inference/base.py | 9 +++++++++ sdks/python/apache_beam/transforms/ptransform.py | 5 ++++- sdks/python/apache_beam/yaml/yaml_transform_scope_test.py | 10 +++++++--- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py index b5aa4f352fa..0964fc46a95 100644 --- a/sdks/python/apache_beam/ml/inference/base.py +++ b/sdks/python/apache_beam/ml/inference/base.py @@ -952,6 +952,15 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT], # allow us to effectively disambiguate in multi-model settings. self._model_tag = uuid.uuid4().hex + def annotations(self): + return { + 'model_handler': str(self._model_handler), + 'model_handler_type': ( + f'{self._model_handler.__class__.__module__}' + f'.{self._model_handler.__class__.__qualname__}'), + **super().annotations() + } + def _get_model_metadata_pcoll(self, pipeline): # avoid circular imports. # pylint: disable=wrong-import-position diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index fd86ff1f934..c7eaa152ae0 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -371,7 +371,10 @@ class PTransform(WithTypeHints, HasDisplayData, Generic[InputT, OutputT]): return self.__class__.__name__ def annotations(self) -> Dict[str, Union[bytes, str, message.Message]]: - return {} + return { + 'python_type': # + f'{self.__class__.__module__}.{self.__class__.__qualname__}' + } def default_type_hints(self): fn_type_hints = IOTypeHints.from_callable(self.expand) diff --git a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py index ead5d5d66d2..a22e4f851a1 100644 --- a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py +++ b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py @@ -61,7 +61,7 @@ class ScopeTest(unittest.TestCase): - type: PyMap name: Square input: Create - config: + config: fn: "lambda x: x*x" ''' @@ -123,7 +123,11 @@ class ScopeTest(unittest.TestCase): self.assertIsInstance(result, beam.transforms.ParDo) self.assertEqual(result.label, 'Map(lambda x: x*x)') - result_annotations = {**result.annotations()} + result_annotations = { + key: value + for (key, value) in result.annotations().items() + if key.startswith('yaml') + } target_annotations = { 'yaml_type': 'PyMap', 'yaml_args': '{"fn": "lambda x: x*x"}', @@ -146,7 +150,7 @@ class LightweightScopeTest(unittest.TestCase): fn: "lambda x: x * x * x" - type: Filter name: FilterOutBigNumbers - input: PyMap + input: PyMap keep: "lambda x: x<100" ''' return yaml.load(pipeline_yaml, Loader=SafeLineLoader)