This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 71b1e831fd8 fix provider testing issue (#37183)
71b1e831fd8 is described below

commit 71b1e831fd87f2d5d0b5368f965f5d1b20acb4a9
Author: Derrick Williams <[email protected]>
AuthorDate: Fri Dec 26 07:56:45 2025 -0500

    fix provider testing issue (#37183)
---
 sdks/python/apache_beam/yaml/yaml_provider.py     |  2 +-
 sdks/python/apache_beam/yaml/yaml_testing.py      | 23 ++++++++++-----
 sdks/python/apache_beam/yaml/yaml_testing_test.py | 34 +++++++++++++++++++++++
 3 files changed, 51 insertions(+), 8 deletions(-)

diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py b/sdks/python/apache_beam/yaml/yaml_provider.py
index b8c7f4f7a87..e9882602d10 100755
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -490,7 +490,7 @@ class YamlProvider(Provider):
     return dict(
         type='object',
         additionalProperties=False,
-        **self._transforms[type]['config_schema'])
+        **self._transforms[type].get('config_schema', {}))
 
   def description(self, type):
     return self._transforms[type].get('description')
diff --git a/sdks/python/apache_beam/yaml/yaml_testing.py b/sdks/python/apache_beam/yaml/yaml_testing.py
index e7fbc1d43b6..ad31afa927e 100644
--- a/sdks/python/apache_beam/yaml/yaml_testing.py
+++ b/sdks/python/apache_beam/yaml/yaml_testing.py
@@ -73,12 +73,15 @@ class YamlTestCase(unittest.TestCase):
 
 def run_test(pipeline_spec, test_spec, options=None, fix_failures=False):
   if isinstance(pipeline_spec, str):
-    pipeline_spec = yaml.load(pipeline_spec, Loader=yaml_utils.SafeLineLoader)
+    pipeline_spec_dict = yaml.load(
+        pipeline_spec, Loader=yaml_utils.SafeLineLoader)
+  else:
+    pipeline_spec_dict = pipeline_spec
 
-  pipeline_spec = _preprocess_for_testing(pipeline_spec)
+  processed_pipeline_spec = _preprocess_for_testing(pipeline_spec_dict)
 
   transform_spec, recording_ids = inject_test_tranforms(
-      pipeline_spec,
+      processed_pipeline_spec,
       test_spec,
       fix_failures)
 
@@ -96,12 +99,18 @@ def run_test(pipeline_spec, test_spec, options=None, fix_failures=False):
     options = beam.options.pipeline_options.PipelineOptions(
         pickle_library='cloudpickle',
         **yaml_transform.SafeLineLoader.strip_metadata(
-            pipeline_spec.get('options', {})))
+            pipeline_spec_dict.get('options', {})))
+
+  providers = yaml_provider.merge_providers(
+      yaml_provider.parse_providers(
+          '', pipeline_spec_dict.get('providers', [])),
+      {
+          'AssertEqualAndRecord': yaml_provider.as_provider_list(
+              'AssertEqualAndRecord', AssertEqualAndRecord)
+      })
 
   with beam.Pipeline(options=options) as p:
-    _ = p | yaml_transform.YamlTransform(
-        transform_spec,
-        providers={'AssertEqualAndRecord': AssertEqualAndRecord})
+    _ = p | yaml_transform.YamlTransform(transform_spec, providers=providers)
 
   if fix_failures:
     fixes = {}
diff --git a/sdks/python/apache_beam/yaml/yaml_testing_test.py b/sdks/python/apache_beam/yaml/yaml_testing_test.py
index 9fcdafd2ab3..9bb0e64b6db 100644
--- a/sdks/python/apache_beam/yaml/yaml_testing_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_testing_test.py
@@ -322,6 +322,40 @@ class YamlTestingTest(unittest.TestCase):
             }]
         })
 
+  def test_toplevel_providers(self):
+    yaml_testing.run_test(
+        '''
+        pipeline:
+          type: chain
+          transforms:
+            - type: Create
+              config:
+                elements: [1, 2, 3]
+            - type: MyDoubler
+        providers:
+          - type: yaml
+            transforms:
+              MyDoubler:
+                body:
+                  type: MapToFields
+                  config:
+                    language: python
+                    fields:
+                      doubled: element * 2
+        ''',
+        {
+            'expected_outputs': [{
+                'name': 'MyDoubler',
+                'elements': [{
+                    'doubled': 2
+                }, {
+                    'doubled': 4
+                }, {
+                    'doubled': 6
+                }]
+            }]
+        })
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

Reply via email to