This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new f91bb68b2e9 Add some simple annotations to Python transforms. (#28191)
f91bb68b2e9 is described below

commit f91bb68b2e9df3788914898c03fdf42445f912d5
Author: Robert Bradshaw <rober...@gmail.com>
AuthorDate: Wed Aug 30 17:23:33 2023 -0700

    Add some simple annotations to Python transforms. (#28191)
---
 sdks/python/apache_beam/ml/inference/base.py              |  9 +++++++++
 sdks/python/apache_beam/transforms/ptransform.py          |  5 ++++-
 sdks/python/apache_beam/yaml/yaml_transform_scope_test.py | 10 +++++++---
 3 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index b5aa4f352fa..0964fc46a95 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -952,6 +952,15 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
     # allow us to effectively disambiguate in multi-model settings.
     self._model_tag = uuid.uuid4().hex
 
+  def annotations(self):
+    return {
+        'model_handler': str(self._model_handler),
+        'model_handler_type': (
+            f'{self._model_handler.__class__.__module__}'
+            f'.{self._model_handler.__class__.__qualname__}'),
+        **super().annotations()
+    }
+
   def _get_model_metadata_pcoll(self, pipeline):
     # avoid circular imports.
     # pylint: disable=wrong-import-position
diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py
index fd86ff1f934..c7eaa152ae0 100644
--- a/sdks/python/apache_beam/transforms/ptransform.py
+++ b/sdks/python/apache_beam/transforms/ptransform.py
@@ -371,7 +371,10 @@ class PTransform(WithTypeHints, HasDisplayData, Generic[InputT, OutputT]):
     return self.__class__.__name__
 
   def annotations(self) -> Dict[str, Union[bytes, str, message.Message]]:
-    return {}
+    return {
+        'python_type':  #
+        f'{self.__class__.__module__}.{self.__class__.__qualname__}'
+    }
 
   def default_type_hints(self):
     fn_type_hints = IOTypeHints.from_callable(self.expand)
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py
index ead5d5d66d2..a22e4f851a1 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform_scope_test.py
@@ -61,7 +61,7 @@ class ScopeTest(unittest.TestCase):
           - type: PyMap
             name: Square
             input: Create
-            config: 
+            config:
               fn: "lambda x: x*x"
         '''
 
@@ -123,7 +123,11 @@ class ScopeTest(unittest.TestCase):
       self.assertIsInstance(result, beam.transforms.ParDo)
       self.assertEqual(result.label, 'Map(lambda x: x*x)')
 
-      result_annotations = {**result.annotations()}
+      result_annotations = {
+          key: value
+          for (key, value) in result.annotations().items()
+          if key.startswith('yaml')
+      }
       target_annotations = {
           'yaml_type': 'PyMap',
           'yaml_args': '{"fn": "lambda x: x*x"}',
@@ -146,7 +150,7 @@ class LightweightScopeTest(unittest.TestCase):
             fn: "lambda x: x * x * x"
           - type: Filter
             name: FilterOutBigNumbers
-            input: PyMap 
+            input: PyMap
             keep: "lambda x: x<100"
           '''
     return yaml.load(pipeline_yaml, Loader=SafeLineLoader)

Reply via email to