This is an automated email from the ASF dual-hosted git repository.
cvandermerwe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new 66134f7107f Fix infer_result_type for pcollection with tags. (#37608)
66134f7107f is described below
commit 66134f7107f75ed40b925633c7509c3c2be609c4
Author: claudevdm <[email protected]>
AuthorDate: Tue Feb 17 14:13:43 2026 -0500
Fix infer_result_type for pcollection with tags. (#37608)
* Fix infer_result_type for pcollection with tags.
* lint
---------
Co-authored-by: Claude <[email protected]>
---
sdks/python/apache_beam/pipeline.py | 6 +-
.../typehints/tagged_output_typehints_test.py | 84 ++++++++++++++++++++++
2 files changed, 89 insertions(+), 1 deletion(-)
diff --git a/sdks/python/apache_beam/pipeline.py
b/sdks/python/apache_beam/pipeline.py
index a6080f2f3e7..3cce2c5bb77 100644
--- a/sdks/python/apache_beam/pipeline.py
+++ b/sdks/python/apache_beam/pipeline.py
@@ -990,7 +990,11 @@ class Pipeline(HasDisplayData):
input_element_types_tuple[0] if len(input_element_types_tuple) == 1
else typehints.Union[input_element_types_tuple])
type_hints = transform.get_type_hints()
- declared_output_type = type_hints.simple_output_type(transform.label)
+ if not result_pcollection.tag:
+ declared_output_type = type_hints.simple_output_type(transform.label)
+ else:
+ declared_output_type = type_hints.tagged_output_types().get(
+ result_pcollection.tag, typehints.Any)
if declared_output_type:
input_types = type_hints.input_types
diff --git a/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py
b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py
index c06f68fb88a..c5ec26ba0a6 100644
--- a/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py
+++ b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py
@@ -41,6 +41,7 @@ from typing import Union
import apache_beam as beam
from apache_beam.pvalue import TaggedOutput
+from apache_beam.typehints import Any
from apache_beam.typehints import with_output_types
from apache_beam.typehints.decorators import IOTypeHints
@@ -352,5 +353,88 @@ class AnnotationStyleTaggedOutputTest(unittest.TestCase):
self.assertEqual(results.errors.element_type, str)
+class CompositeTaggedOutputInferenceTest(unittest.TestCase):
+ """Tests for _infer_result_type when a composite PTransform returns
+ tagged outputs as a dict of PCollections: either the ParDo's own
+ (already typed) outputs, or fresh PCollections with no element_type.
+ """
+ def test_composite_returning_tagged_dict_preserves_existing_types(self):
+ """A composite that returns a dict of PCollections already typed by
+ DoOutputsTuple.__getitem__ should preserve those types through
+ _infer_result_type (the guard skips inference when element_type is set)."""
+ class MyComposite(beam.PTransform):
+ def expand(self, pcoll):
+ results = (
+ pcoll
+ | beam.ParDo(self._MyDoFn()).with_outputs('errors', main='main'))
+ # Return a dict of the DoOutputsTuple's PCollections.
+ # These already have types set via __getitem__.
+ return {'main': results.main, 'errors': results.errors}
+
+ # Main output declared as int; the 'errors' tagged output as str.
+ @beam.typehints.with_output_types(int, errors=str)
+ class _MyDoFn(beam.DoFn):
+ def process(self, element):
+ if element < 0:
+ yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+ else:
+ yield element
+
+ with beam.Pipeline() as p:
+ result = (p | beam.Create([-1, 0, 1, 2]) | MyComposite())
+
+ self.assertEqual(result['main'].element_type, int)
+ self.assertEqual(result['errors'].element_type, str)
+
+ def test_composite_returning_tagged_dict_without_dofn_hints_is_any(self):
+ """A composite wrapping a DoFn with no output type hints should expose
+ its 'errors' output with element_type Any."""
+ class MyComposite(beam.PTransform):
+ def expand(self, pcoll):
+ results = (
+ pcoll
+ | beam.ParDo(self._MyDoFn()).with_outputs('errors', main='main'))
+ return {'main': results.main, 'errors': results.errors}
+
+ # No output type hints: no decorator and no process() return annotation.
+ class _MyDoFn(beam.DoFn):
+ def process(self, element):
+ if element < 0:
+ yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+ else:
+ yield element
+
+ with beam.Pipeline() as p:
+ result = (p | beam.Create([-1, 0, 1, 2]) | MyComposite())
+
+ self.assertEqual(result['errors'].element_type, Any)
+
+ def test_composite_pcollections_uses_tagged_type_hints(self):
+ """A composite that creates new PCollections (element_type=None) and
+ returns them as a dict should still get correct tagged types from
+ the type hints."""
+ # Output types declared on the composite itself: int for the main
+ # output, str for the 'errors' tagged output.
+ @beam.typehints.with_output_types(int, errors=str)
+ class MyComposite(beam.PTransform):
+ def expand(self, pcoll):
+ results = (
+ pcoll
+ | beam.ParDo(self._MyDoFn()).with_outputs('errors', main='main'))
+ # Rebuild each output as a brand-new PCollection carrying only its
+ # tag (no element_type); a PCollection with no tag is keyed by the
+ # DoOutputsTuple's main tag instead.
+ return {
+ p.tag if p.tag else results._main_tag: beam.pvalue.PCollection(
+ pcoll.pipeline, tag=p.tag)
+ for p in results
+ }
+
+ # This DoFn declares its outputs through the process() return
+ # annotation rather than a decorator.
+ class _MyDoFn(beam.DoFn):
+ def process(
+ self, element
+ ) -> Iterable[int | beam.TaggedOutput[Literal['errors'], str]]:
+ if element < 0:
+ yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+ else:
+ yield element
+
+ with beam.Pipeline() as p:
+ result = (p | beam.Create([-1, 0, 1, 2]) | MyComposite())
+
+ self.assertEqual(result['main'].element_type, int)
+ self.assertEqual(result['errors'].element_type, str)
+
+
if __name__ == '__main__':
unittest.main()