chamikaramj commented on code in PR #24667:
URL: https://github.com/apache/beam/pull/24667#discussion_r1084498967


##########
sdks/python/apache_beam/yaml/yaml_provider.py:
##########
@@ -0,0 +1,420 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This module defines Providers usable from yaml, which is a specification
+for where to find and how to invoke services that vend implementations of
+various PTransforms."""
+
+import collections
+import hashlib
+import json
+import os
+import subprocess
+import sys
+import uuid
+import yaml
+from yaml.loader import SafeLoader
+
+import apache_beam as beam
+import apache_beam.dataframe.io
+import apache_beam.io
+import apache_beam.transforms.util
+from apache_beam.portability.api import schema_pb2
+from apache_beam.transforms import external
+from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform
+from apache_beam.typehints import schemas
+from apache_beam.typehints import trivial_inference
+from apache_beam.utils import python_callable
+from apache_beam.utils import subprocess_server
+from apache_beam.version import __version__ as beam_version
+
+
+class Provider:
+  """Maps transform types to concrete PTransform instances."""
+  def available(self):
+    raise NotImplementedError(type(self))
+
+  def provided_transforms(self):
+    raise NotImplementedError(type(self))
+
+  def create_transform(self, typ, args):
+    raise NotImplementedError(type(self))
+
+
+class ExternalProvider(Provider):
+  """A Provider implemented via the cross language transform service."""
+  def __init__(self, urns, service):
+    self._urns = urns
+    self._service = service
+    self._schema_transforms = None
+
+  def provided_transforms(self):
+    return self._urns.keys()
+
+  def create_transform(self, type, args):
+    if callable(self._service):
+      self._service = self._service()
+    if self._schema_transforms is None:
+      try:
+        self._schema_transforms = [
+            config.identifier
+            for config in external.SchemaAwareExternalTransform.discover(
+                self._service)
+        ]
+      except Exception:
+        self._schema_transforms = []
+    urn = self._urns[type]
+    if urn in self._schema_transforms:
+      return external.SchemaAwareExternalTransform(urn, self._service, **args)
+    else:
+      return type >> self.create_external_transform(urn, args)
+
+  def create_external_transform(self, urn, args):
+    return external.ExternalTransform(
+        urn,
+        external.ImplicitSchemaPayloadBuilder(args).payload(),
+        self._service)
+
+  @staticmethod
+  def provider_from_spec(spec):
+    urns = spec['transforms']
+    type = spec['type']
+    if spec.get('version', None) == 'BEAM_VERSION':
+      spec['version'] = beam_version
+    if type == 'jar':
+      return ExternalJavaProvider(urns, lambda: spec['jar'])
+    elif type == 'mavenJar':
+      return ExternalJavaProvider(
+          urns,
+          lambda: subprocess_server.JavaJarServer.path_to_maven_jar(
+              **{
+                  key: value
+                  for (key, value) in spec.items() if key in [
+                      'artifact_id',
+                      'group_id',
+                      'version',
+                      'repository',
+                      'classifier',
+                      'appendix'
+                  ]
+              }))
+    elif type == 'beamJar':
+      return ExternalJavaProvider(
+          urns,
+          lambda: subprocess_server.JavaJarServer.path_to_beam_jar(
+              **{
+                  key: value
+                  for (key, value) in spec.items() if key in
+                  ['gradle_target', 'version', 'appendix', 'artifact_id']
+              }))
+    elif type == 'pypi':
+      return ExternalPythonProvider(urns, spec['packages'])
+    elif type == 'remote':
+      return RemoteProvider(spec['address'])
+    elif type == 'docker':
+      raise NotImplementedError()
+    else:
+      raise NotImplementedError(f'Unknown provider type: {type}')
+
+
+class RemoteProvider(ExternalProvider):
+  _is_available = None
+
+  def available(self):
+    if self._is_available is None:
+      try:
+        with external.ExternalTransform.service(self._service) as service:
+          service.ready(1)
+          self._is_available = True
+      except Exception:
+        self._is_available = False
+    return self._is_available
+
+
+class ExternalJavaProvider(ExternalProvider):
+  def __init__(self, urns, jar_provider):
+    super().__init__(
+        urns, lambda: external.JavaJarExpansionService(jar_provider()))
+
+  def available(self):
+    # pylint: disable=subprocess-run-check
+    return subprocess.run(['which', 'java'],
+                          capture_output=True).returncode == 0
+
+
+class ExternalPythonProvider(ExternalProvider):
+  def __init__(self, urns, packages):
+    super().__init__(urns, PypiExpansionService(packages))
+
+  def available(self):
+    return True  # If we're running this script, we have Python installed.
+
+  def create_external_transform(self, urn, args):
+    # Python transforms are "registered" by fully qualified name.
+    return external.ExternalTransform(
+        "beam:transforms:python:fully_qualified_named",
+        external.ImplicitSchemaPayloadBuilder({
+            'constructor': urn,
+            'kwargs': args,
+        }).payload(),
+        self._service)
+
+
+# This is needed because type inference can't handle *args, **kwargs fowarding.
+# TODO: Fix Beam itself.
+def fix_pycallable():
+  from apache_beam.transforms.ptransform import label_from_callable
+
+  def default_label(self):
+    src = self._source.strip()
+    last_line = src.split('\n')[-1]
+    if last_line[0] != ' ' and len(last_line) < 72:
+      return last_line
+    return label_from_callable(self._callable)
+
+  def _argspec_fn(self):
+    return self._callable
+
+  python_callable.PythonCallableWithSource.default_label = default_label
+  python_callable.PythonCallableWithSource._argspec_fn = property(_argspec_fn)
+
+  original_infer_return_type = trivial_inference.infer_return_type
+
+  def infer_return_type(fn, *args, **kwargs):
+    if isinstance(fn, python_callable.PythonCallableWithSource):
+      fn = fn._callable
+    return original_infer_return_type(fn, *args, **kwargs)
+
+  trivial_inference.infer_return_type = infer_return_type
+
+  original_fn_takes_side_inputs = (
+      apache_beam.transforms.util.fn_takes_side_inputs)
+
+  def fn_takes_side_inputs(fn):
+    if isinstance(fn, python_callable.PythonCallableWithSource):
+      fn = fn._callable
+    return original_fn_takes_side_inputs(fn)
+
+  apache_beam.transforms.util.fn_takes_side_inputs = fn_takes_side_inputs
+
+
+class InlineProvider(Provider):
+  def __init__(self, transform_factories):
+    self._transform_factories = transform_factories
+
+  def available(self):
+    return True
+
+  def provided_transforms(self):
+    return self._transform_factories.keys()
+
+  def create_transform(self, type, args):
+    return self._transform_factories[type](**args)
+
+
+PRIMITIVE_NAMES_TO_ATOMIC_TYPE = {
+    py_type.__name__: schema_type
+    for (py_type, schema_type) in schemas.PRIMITIVE_TO_ATOMIC_TYPE.items()
+    if py_type.__module__ != 'typing'
+}
+
+
+def create_builtin_provider():
+  def with_schema(**args):
+    # TODO: This is preliminary.
+    def parse_type(spec):
+      if spec in PRIMITIVE_NAMES_TO_ATOMIC_TYPE:
+        return schema_pb2.FieldType(
+            atomic_type=PRIMITIVE_NAMES_TO_ATOMIC_TYPE[spec])
+      elif isinstance(spec, list):
+        if len(spec) != 1:
+          raise ValueError("Use single-element lists to denote list types.")
+        else:
+          return schema_pb2.FieldType(
+              iterable_type=schema_pb2.IterableType(
+                  element_type=parse_type(spec[0])))
+      elif isinstance(spec, dict):
+        return schema_pb2.FieldType(
+            iterable_type=schema_pb2.RowType(schema=parse_schema(spec[0])))
+      else:
+        raise ValueError("Unknown schema type: {spec}")
+
+    def parse_schema(spec):
+      return schema_pb2.Schema(
+          fields=[
+              schema_pb2.Field(name=key, type=parse_type(value), id=ix)
+              for (ix, (key, value)) in enumerate(spec.items())
+          ],
+          id=str(uuid.uuid4()))
+
+    named_tuple = schemas.named_tuple_from_schema(parse_schema(args))
+    names = list(args.keys())
+
+    def extract_field(x, name):
+      if isinstance(x, dict):
+        return x[name]
+      else:
+        return getattr(x, name)
+
+    return 'WithSchema(%s)' % ', '.join(names) >> beam.Map(
+        lambda x: named_tuple(*[extract_field(x, name) for name in names])
+    ).with_output_types(named_tuple)
+
+  # Or should this be posargs, args?
+  # pylint: disable=dangerous-default-value
+  def fully_qualified_named_transform(constructor, args=(), kwargs={}):
+    with FullyQualifiedNamedTransform.with_filter('*'):
+      return constructor >> FullyQualifiedNamedTransform(
+          constructor, args, kwargs)
+
+  class Flatten(beam.PTransform):

Review Comment:
   I see. Let's add a comment here explaining why this extra Flatten is needed: to support one or zero input PCollections.



##########
sdks/python/apache_beam/yaml/yaml_provider.py:
##########
@@ -0,0 +1,420 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This module defines Providers usable from yaml, which is a specification
+for where to find and how to invoke services that vend implementations of
+various PTransforms."""
+
+import collections
+import hashlib
+import json
+import os
+import subprocess
+import sys
+import uuid
+import yaml
+from yaml.loader import SafeLoader
+
+import apache_beam as beam
+import apache_beam.dataframe.io
+import apache_beam.io
+import apache_beam.transforms.util
+from apache_beam.portability.api import schema_pb2
+from apache_beam.transforms import external
+from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform
+from apache_beam.typehints import schemas
+from apache_beam.typehints import trivial_inference
+from apache_beam.utils import python_callable
+from apache_beam.utils import subprocess_server
+from apache_beam.version import __version__ as beam_version
+
+
+class Provider:
+  """Maps transform types to concrete PTransform instances."""
+  def available(self):
+    raise NotImplementedError(type(self))
+
+  def provided_transforms(self):
+    raise NotImplementedError(type(self))
+
+  def create_transform(self, typ, args):
+    raise NotImplementedError(type(self))
+
+
+class ExternalProvider(Provider):
+  """A Provider implemented via the cross language transform service."""
+  def __init__(self, urns, service):
+    self._urns = urns
+    self._service = service
+    self._schema_transforms = None
+
+  def provided_transforms(self):
+    return self._urns.keys()
+
+  def create_transform(self, type, args):
+    if callable(self._service):
+      self._service = self._service()
+    if self._schema_transforms is None:
+      try:
+        self._schema_transforms = [
+            config.identifier
+            for config in external.SchemaAwareExternalTransform.discover(
+                self._service)
+        ]
+      except Exception:
+        self._schema_transforms = []
+    urn = self._urns[type]
+    if urn in self._schema_transforms:
+      return external.SchemaAwareExternalTransform(urn, self._service, **args)

Review Comment:
   I don't have a better place to leave this note, but this should go into the YAML runner documentation once we have it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to