This is an automated email from the ASF dual-hosted git repository.

chamikara pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new bd8950176db Use external config schema to construct Python SchemaTransform payload (#26100)
bd8950176db is described below

commit bd8950176db0116221a1b739a3916da26d822f2f
Author: Ahmed Abualsaud <65791736+ahmedab...@users.noreply.github.com>
AuthorDate: Wed Apr 19 19:29:44 2023 +0300

    Use external config schema to construct Python SchemaTransform payload (#26100)
    
    * use external config schema to construct schematransform payload
    
    * add documentation
    
    * check for extra kwargs fields; rearrange kwargs and use existing proto method
    
    * add rearrange kwargs flag to bq storage write
    
    * use ordered dict
---
 sdks/python/apache_beam/io/gcp/bigquery.py         |  2 +-
 sdks/python/apache_beam/transforms/external.py     | 50 ++++++++++++++++++++--
 .../python/apache_beam/transforms/external_test.py | 26 ++++++++++-
 3 files changed, 72 insertions(+), 6 deletions(-)

diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py
index bbeedba0021..ee708115695 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery.py
@@ -447,7 +447,6 @@ __all__ = [
     'BigQueryQueryPriority',
     'WriteToBigQuery',
     'WriteResult',
-    'StorageWriteToBigQuery',
     'ReadFromBigQuery',
     'ReadFromBigQueryRequest',
     'ReadAllFromBigQuery',
@@ -2467,6 +2466,7 @@ class StorageWriteToBigQuery(PTransform):
     external_storage_write = SchemaAwareExternalTransform(
         identifier=self.schematransform_config.identifier,
         expansion_service=self._expansion_service,
+        rearrange_based_on_discovery=True,
         autoSharding=self._with_auto_sharding,
         createDisposition=self._create_disposition,
         table=self._table,
diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py
index 7884584306c..543cbcca5b2 100644
--- a/sdks/python/apache_beam/transforms/external.py
+++ b/sdks/python/apache_beam/transforms/external.py
@@ -327,17 +327,61 @@ class SchemaAwareExternalTransform(ptransform.PTransform):
   :param expansion_service: an expansion service to use. This should already be
       available and the Schema-aware transforms to be used must already be
       deployed.
+  :param rearrange_based_on_discovery: if this flag is set, the input kwargs
+      will be rearranged to match the order of fields in the external
+      SchemaTransform configuration. A discovery call will be made to fetch
+      the configuration.
   :param classpath: (Optional) A list paths to additional jars to place on the
       expansion service classpath.
   :kwargs: field name to value mapping for configuring the schema transform.
       keys map to the field names of the schema of the SchemaTransform
       (in-order).
   """
-  def __init__(self, identifier, expansion_service, classpath=None, **kwargs):
+  def __init__(
+      self,
+      identifier,
+      expansion_service,
+      rearrange_based_on_discovery=False,
+      classpath=None,
+      **kwargs):
     self._expansion_service = expansion_service
-    self._payload_builder = SchemaTransformPayloadBuilder(identifier, **kwargs)
+    self._kwargs = kwargs
     self._classpath = classpath
 
+    _kwargs = kwargs
+    if rearrange_based_on_discovery:
+      _kwargs = self._rearrange_kwargs(identifier)
+
+    self._payload_builder = SchemaTransformPayloadBuilder(identifier, **_kwargs)
+
+  def _rearrange_kwargs(self, identifier):
+    # discover and fetch the external SchemaTransform configuration then
+    # use it to build an appropriate payload
+    schematransform_config = SchemaAwareExternalTransform.discover_config(
+        self._expansion_service, identifier)
+
+    external_config_fields = schematransform_config.configuration_schema._fields
+    ordered_kwargs = OrderedDict()
+    missing_fields = []
+
+    for field in external_config_fields:
+      if field not in self._kwargs:
+        missing_fields.append(field)
+      else:
+        ordered_kwargs[field] = self._kwargs[field]
+
+    extra_fields = list(set(self._kwargs.keys()) - set(external_config_fields))
+    if missing_fields:
+      raise ValueError(
+          'Input parameters are missing the following SchemaTransform config '
+          'fields: %s' % missing_fields)
+    elif extra_fields:
+      raise ValueError(
+          'Input parameters include the following extra fields that are not '
+          'found in the SchemaTransform config schema: %s' % extra_fields)
+
+    return ordered_kwargs
+
   def expand(self, pcolls):
     # Expand the transform using the expansion service.
     return pcolls | ExternalTransform(
@@ -371,7 +415,7 @@ class SchemaAwareExternalTransform(ptransform.PTransform):
   def discover_config(expansion_service, name):
     """Discover one SchemaTransform by name in the given expansion service.
 
-    :return: one SchemaTransformConfig that represents the discovered
+    :return: one SchemaTransformsConfig that represents the discovered
         SchemaTransform
 
     :raises:
diff --git a/sdks/python/apache_beam/transforms/external_test.py b/sdks/python/apache_beam/transforms/external_test.py
index 650e292bfb9..e7e4de46eb2 100644
--- a/sdks/python/apache_beam/transforms/external_test.py
+++ b/sdks/python/apache_beam/transforms/external_test.py
@@ -491,8 +491,11 @@ class SchemaAwareExternalTransformTest(unittest.TestCase):
           config_schema=schema_pb2.Schema(
               fields=[
                   schema_pb2.Field(
-                      name="test_field",
-                      type=schema_pb2.FieldType(atomic_type="STRING"))
+                      name="str_field",
+                      type=schema_pb2.FieldType(atomic_type="STRING")),
+                  schema_pb2.Field(
+                      name="int_field",
+                      type=schema_pb2.FieldType(atomic_type="INT64"))
               ],
               id="test-id"),
           input_pcollection_names=["input"],
@@ -517,6 +520,25 @@ class SchemaAwareExternalTransformTest(unittest.TestCase):
       beam.SchemaAwareExternalTransform.discover_config(
           "test_service", name="non_existent")
 
+  @mock.patch("apache_beam.transforms.external.ExternalTransform.service")
+  def test_rearrange_kwargs_based_on_discovery(self, mock_service):
+    mock_service.return_value = self.MockDiscoveryService()
+
+    identifier = "test_schematransform"
+    expansion_service = "test_service"
+    kwargs = {"int_field": 0, "str_field": "str"}
+
+    transform = beam.SchemaAwareExternalTransform(
+        identifier=identifier, expansion_service=expansion_service, **kwargs)
+    ordered_kwargs = transform._rearrange_kwargs(identifier)
+
+    schematransform_config = beam.SchemaAwareExternalTransform.discover_config(
+        expansion_service, identifier)
+    external_config_fields = schematransform_config.configuration_schema._fields
+
+    self.assertNotEqual(tuple(kwargs.keys()), external_config_fields)
+    self.assertEqual(tuple(ordered_kwargs.keys()), external_config_fields)
+
 
 class JavaClassLookupPayloadBuilderTest(unittest.TestCase):
   def _verify_row(self, schema, row_payload, expected_values):

Reply via email to