This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 210ede6  [BEAM-12769] Java-emulating external transform. (#15546)
210ede6 is described below

commit 210ede679845a68039e7792b6bd19f8b488efb24
Author: Robert Bradshaw <rober...@google.com>
AuthorDate: Fri Sep 24 17:38:30 2021 -0700

    [BEAM-12769] Java-emulating external transform. (#15546)
---
 sdks/python/apache_beam/transforms/external.py     | 46 ++++++++++++++++++++--
 .../python/apache_beam/transforms/external_test.py | 37 +++++++++++++++++
 2 files changed, 79 insertions(+), 4 deletions(-)

diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py
index 16840c4..f8e6ddc 100644
--- a/sdks/python/apache_beam/transforms/external.py
+++ b/sdks/python/apache_beam/transforms/external.py
@@ -236,8 +236,7 @@ class JavaClassLookupPayloadBuilder(PayloadBuilder):
     :param args: parameter values of the constructor.
     :param kwargs: parameter names and values of the constructor.
     """
-    if (self._constructor_method or self._constructor_param_args or
-        self._constructor_param_kwargs):
+    if self._has_constructor():
       raise ValueError(
           'Constructor or constructor method can only be specified once')
 
@@ -254,8 +253,7 @@ class JavaClassLookupPayloadBuilder(PayloadBuilder):
     :param args: parameter values of the constructor method.
     :param kwargs: parameter names and values of the constructor method.
     """
-    if (self._constructor_method or self._constructor_param_args or
-        self._constructor_param_kwargs):
+    if self._has_constructor():
       raise ValueError(
           'Constructor or constructor method can only be specified once')
 
@@ -276,6 +274,73 @@ class JavaClassLookupPayloadBuilder(PayloadBuilder):
     """
     self._builder_methods_and_params[method_name] = (args, kwargs)
 
+  def _has_constructor(self):
+    """Whether a constructor or constructor method has already been set."""
+    return bool(
+        self._constructor_method or self._constructor_param_args or
+        self._constructor_param_kwargs)
+
+
+class JavaExternalTransform(ptransform.PTransform):
+  """A proxy for Java-implemented external transforms.
+
+  One builds these transforms just as one would in Java, e.g.::
+
+      JavaExternalTransform('fully.qualified.ClassName')(
+          constructorArg).builderMethod(...)
+
+  or::
+
+      JavaExternalTransform('fully.qualified.ClassName').staticConstructor(
+          ...).builderMethod1(...).builderMethod2(...)
+
+  Calling the instance records constructor arguments. Any other public
+  attribute is a callable that records a constructor method if no
+  constructor has been specified yet and a builder method otherwise; it
+  returns this transform so that calls chain. For method names that are
+  Python keywords or clash with PTransform attributes, use indexing
+  instead, e.g. ``transform['from'](...)``.
+
+  :param class_name: fully qualified name of the Java class.
+  :param expansion_service: (Optional) expansion service passed to
+      ExternalTransform when this transform is expanded.
+  """
+  def __init__(self, class_name, expansion_service=None):
+    super().__init__()
+    self._payload_builder = JavaClassLookupPayloadBuilder(class_name)
+    self._expansion_service = expansion_service
+
+  def __call__(self, *args, **kwargs):
+    self._payload_builder.with_constructor(*args, **kwargs)
+    return self
+
+  def __getattr__(self, name):
+    # Only invoked when normal lookup fails. Special and private names are
+    # never emulated as Java methods, so protocol lookups (copy, pickle, ...)
+    # and probes for unset private attributes get a plain AttributeError
+    # instead of silently recording a bogus builder call.
+    if name.startswith('_'):
+      raise AttributeError(name)
+    return self[name]
+
+  def __getitem__(self, name):
+    """Returns a callable that records a call to the Java method ``name``."""
+    def construct(*args, **kwargs):
+      if self._payload_builder._has_constructor():
+        builder_method = self._payload_builder.add_builder_method
+      else:
+        builder_method = self._payload_builder.with_constructor_method
+      builder_method(name, *args, **kwargs)
+      return self
+
+    return construct
+
+  def expand(self, pcolls):
+    return pcolls | ExternalTransform(
+        common_urns.java_class_lookup,
+        self._payload_builder.build(),
+        self._expansion_service)
+
 
 class AnnotationBasedPayloadBuilder(SchemaBasedPayloadBuilder):
   """
diff --git a/sdks/python/apache_beam/transforms/external_test.py b/sdks/python/apache_beam/transforms/external_test.py
index 8673ed5..a528384 100644
--- a/sdks/python/apache_beam/transforms/external_test.py
+++ b/sdks/python/apache_beam/transforms/external_test.py
@@ -39,6 +39,7 @@ from apache_beam.testing.util import equal_to
 from apache_beam.transforms.external import AnnotationBasedPayloadBuilder
 from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
 from apache_beam.transforms.external import JavaClassLookupPayloadBuilder
+from apache_beam.transforms.external import JavaExternalTransform
 from apache_beam.transforms.external import NamedTupleBasedPayloadBuilder
 from apache_beam.typehints import typehints
 from apache_beam.typehints.native_type_compatibility import convert_to_beam_type
@@ -509,6 +510,47 @@ class JavaClassLookupPayloadBuilderTest(unittest.TestCase):
     with self.assertRaises(ValueError):
       payload_builder.with_constructor('def')
 
+  def test_implicit_builder_with_constructor(self):
+    # Python spelling of the Java new MyTransform("abc").withIntProperty(5).
+    constructor_transform = (
+        JavaExternalTransform('org.pkg.MyTransform')('abc').withIntProperty(5))
+
+    payload_bytes = constructor_transform._payload_builder.payload()
+    payload_from_bytes = proto_utils.parse_Bytes(
+        payload_bytes, JavaClassLookupPayload)
+    self.assertEqual('org.pkg.MyTransform', payload_from_bytes.class_name)
+    self._verify_row(
+        payload_from_bytes.constructor_schema,
+        payload_from_bytes.constructor_payload, {'ignore0': 'abc'})
+    self.assertEqual(1, len(payload_from_bytes.builder_methods))
+    builder_method = payload_from_bytes.builder_methods[0]
+    self.assertEqual('withIntProperty', builder_method.name)
+    self._verify_row(
+        builder_method.schema, builder_method.payload, {'ignore0': 5})
+
+  def test_implicit_builder_with_constructor_method(self):
+    # With no explicit constructor call, the first method (of) is recorded as
+    # the constructor method and later ones as builder methods, in order.
+    constructor_transform = JavaExternalTransform('org.pkg.MyTransform').of(
+        str_field='abc').withProperty(int_field=1234).build()
+
+    payload_bytes = constructor_transform._payload_builder.payload()
+    payload_from_bytes = proto_utils.parse_Bytes(
+        payload_bytes, JavaClassLookupPayload)
+    self.assertEqual('of', payload_from_bytes.constructor_method)
+    self._verify_row(
+        payload_from_bytes.constructor_schema,
+        payload_from_bytes.constructor_payload, {'str_field': 'abc'})
+    self.assertEqual(2, len(payload_from_bytes.builder_methods))
+    with_property_method = payload_from_bytes.builder_methods[0]
+    self.assertEqual('withProperty', with_property_method.name)
+    self._verify_row(
+        with_property_method.schema,
+        with_property_method.payload, {'int_field': 1234})
+    build_method = payload_from_bytes.builder_methods[1]
+    self.assertEqual('build', build_method.name)
+    self._verify_row(build_method.schema, build_method.payload, {})
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

Reply via email to