gemini-code-assist[bot] commented on code in PR #37608:
URL: https://github.com/apache/beam/pull/37608#discussion_r2809312984


##########
sdks/python/apache_beam/typehints/tagged_output_typehints_test.py:
##########
@@ -352,5 +353,88 @@ def process(
       self.assertEqual(results.errors.element_type, str)
 
 
+class CompositeTaggedOutputInferenceTest(unittest.TestCase):
+  """Tests for _infer_result_type when a composite PTransform returns
+  tagged outputs as a dict of fresh PCollections.
+  """
+  def test_composite_returning_tagged_dict_preserves_existing_types(self):
+    """A composite that returns a dict of PCollections already typed by
+    DoOutputsTuple.__getitem__ should preserve those types through
+    _infer_result_type (the guard skips inference when element_type is set)."""
+    class MyComposite(beam.PTransform):
+      def expand(self, pcoll):
+        results = (
+            pcoll
+            | beam.ParDo(self._MyDoFn()).with_outputs('errors', main='main'))
+        # Return a dict of the DoOutputsTuple's PCollections.
+        # These already have types set via __getitem__.
+        return {'main': results.main, 'errors': results.errors}
+
+      @beam.typehints.with_output_types(int, errors=str)
+      class _MyDoFn(beam.DoFn):
+        def process(self, element):
+          if element < 0:
+            yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+          else:
+            yield element
+
+    with beam.Pipeline() as p:
+      result = (p | beam.Create([-1, 0, 1, 2]) | MyComposite())
+
+      self.assertEqual(result['main'].element_type, int)
+      self.assertEqual(result['errors'].element_type, str)
+
+  def test_composite_returning_tagged_dict_without_dofn_hints_is_any(self):
+    class MyComposite(beam.PTransform):
+      def expand(self, pcoll):
+        results = (
+            pcoll
+            | beam.ParDo(self._MyDoFn()).with_outputs('errors', main='main'))
+        return {'main': results.main, 'errors': results.errors}
+
+      class _MyDoFn(beam.DoFn):
+        def process(self, element):
+          if element < 0:
+            yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}')
+          else:
+            yield element
+
+    with beam.Pipeline() as p:
+      result = (p | beam.Create([-1, 0, 1, 2]) | MyComposite())
+
+      self.assertEqual(result['errors'].element_type, Any)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This test correctly asserts that the tagged output `errors` defaults to 
`Any` when no type hints are provided. For completeness, consider also 
asserting the type of the `main` output, to confirm that it behaves as 
expected in this scenario too; it should likewise default to `Any`.
   
   ```suggestion
         self.assertEqual(result['main'].element_type, Any)
         self.assertEqual(result['errors'].element_type, Any)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to