This is an automated email from the ASF dual-hosted git repository.

cvandermerwe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 66134f7107f Fix infer_result_type for pcollection with tags. (#37608)
66134f7107f is described below

commit 66134f7107f75ed40b925633c7509c3c2be609c4
Author: claudevdm <[email protected]>
AuthorDate: Tue Feb 17 14:13:43 2026 -0500

    Fix infer_result_type for pcollection with tags. (#37608)
    
    * Fix infer_result_type for pcollection with tags.
    
    * lint
    
    ---------
    
    Co-authored-by: Claude <[email protected]>
---
 sdks/python/apache_beam/pipeline.py                |  6 +-
 .../typehints/tagged_output_typehints_test.py      | 84 ++++++++++++++++++++++
 2 files changed, 89 insertions(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py
index a6080f2f3e7..3cce2c5bb77 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -990,7 +990,11 @@ class Pipeline(HasDisplayData):
           input_element_types_tuple[0] if len(input_element_types_tuple) == 1
           else typehints.Union[input_element_types_tuple])
       type_hints = transform.get_type_hints()
-      declared_output_type = type_hints.simple_output_type(transform.label)
+      if not result_pcollection.tag:
+        declared_output_type = type_hints.simple_output_type(transform.label)
+      else:
+        declared_output_type = type_hints.tagged_output_types().get(
+            result_pcollection.tag, typehints.Any)
 
       if declared_output_type:
         input_types = type_hints.input_types
diff --git a/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py
index c06f68fb88a..c5ec26ba0a6 100644
--- a/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py
+++ b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py
@@ -41,6 +41,7 @@ from typing import Union
 
 import apache_beam as beam
 from apache_beam.pvalue import TaggedOutput
+from apache_beam.typehints import Any
 from apache_beam.typehints import with_output_types
 from apache_beam.typehints.decorators import IOTypeHints
 
@@ -352,5 +353,88 @@ class AnnotationStyleTaggedOutputTest(unittest.TestCase):
       self.assertEqual(results.errors.element_type, str)
 
 
+class CompositeTaggedOutputInferenceTest(unittest.TestCase):
+  """Tests for _infer_result_type when a composite PTransform returns
+  tagged outputs as a dict of fresh PCollections.
+  """
+  def test_composite_returning_tagged_dict_preserves_existing_types(self):
+    """A composite that returns a dict of PCollections already typed by
+    DoOutputsTuple.__getitem__ should preserve those types through
+    _infer_result_type (the guard skips inference when element_type is set)."""
+    class MyComposite(beam.PTransform):
+      def expand(self, pcoll):
+        results = (
+            pcoll
+            | beam.ParDo(self._MyDoFn()).with_outputs('errors', main='main'))
+        # Return a dict of the DoOutputsTuple's PCollections.
+        # These already have types set via __getitem__.
+        return {'main': results.main, 'errors': results.errors}
+
+      @beam.typehints.with_output_types(int, errors=str)
+      class _MyDoFn(beam.DoFn):
+        def process(self, element):
+          if element < 0:
+            yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+          else:
+            yield element
+
+    with beam.Pipeline() as p:
+      result = (p | beam.Create([-1, 0, 1, 2]) | MyComposite())
+
+      self.assertEqual(result['main'].element_type, int)
+      self.assertEqual(result['errors'].element_type, str)
+
+  def test_composite_returning_tagged_dict_without_dofn_hints_is_any(self):
+    class MyComposite(beam.PTransform):
+      def expand(self, pcoll):
+        results = (
+            pcoll
+            | beam.ParDo(self._MyDoFn()).with_outputs('errors', main='main'))
+        return {'main': results.main, 'errors': results.errors}
+
+      class _MyDoFn(beam.DoFn):
+        def process(self, element):
+          if element < 0:
+            yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+          else:
+            yield element
+
+    with beam.Pipeline() as p:
+      result = (p | beam.Create([-1, 0, 1, 2]) | MyComposite())
+
+      self.assertEqual(result['errors'].element_type, Any)
+
+  def test_composite_pcollections_uses_tagged_type_hints(self):
+    """A composite that creates new PCollections (element_type=None) and
+    returns them as a dict should still get correct tagged types from
+    the type hints."""
+    @beam.typehints.with_output_types(int, errors=str)
+    class MyComposite(beam.PTransform):
+      def expand(self, pcoll):
+        results = (
+            pcoll
+            | beam.ParDo(self._MyDoFn()).with_outputs('errors', main='main'))
+        return {
+            p.tag if p.tag else results._main_tag: beam.pvalue.PCollection(
+                pcoll.pipeline, tag=p.tag)
+            for p in results
+        }
+
+      class _MyDoFn(beam.DoFn):
+        def process(
+            self, element
+        ) -> Iterable[int | beam.TaggedOutput[Literal['errors'], str]]:
+          if element < 0:
+            yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+          else:
+            yield element
+
+    with beam.Pipeline() as p:
+      result = (p | beam.Create([-1, 0, 1, 2]) | MyComposite())
+
+      self.assertEqual(result['main'].element_type, int)
+      self.assertEqual(result['errors'].element_type, str)
+
+
 if __name__ == '__main__':
   unittest.main()

Reply via email to