This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new f2d6127758b Add an affinity concept to yaml providers. (#27105)
f2d6127758b is described below

commit f2d6127758b8809f3015235c444a70996a2de506
Author: Robert Bradshaw <[email protected]>
AuthorDate: Thu Jun 22 17:56:13 2023 -0700

    Add an affinity concept to yaml providers. (#27105)
    
    This will allow for selection of more-likely-to-fuse implementations
    for adjacent operations when more than one provider services the same
    operation.
---
 sdks/python/apache_beam/yaml/readme_test.py        |   5 +-
 sdks/python/apache_beam/yaml/yaml_mapping.py       |   4 +-
 sdks/python/apache_beam/yaml/yaml_provider.py      |  28 ++++-
 sdks/python/apache_beam/yaml/yaml_transform.py     |  52 ++++++++-
 .../python/apache_beam/yaml/yaml_transform_test.py | 128 +++++++++++++++++++++
 5 files changed, 205 insertions(+), 12 deletions(-)

diff --git a/sdks/python/apache_beam/yaml/readme_test.py 
b/sdks/python/apache_beam/yaml/readme_test.py
index 3d632682014..cd564cd3fe4 100644
--- a/sdks/python/apache_beam/yaml/readme_test.py
+++ b/sdks/python/apache_beam/yaml/readme_test.py
@@ -219,15 +219,14 @@ def parse_test_methods(markdown_lines):
         if code_lines:
           if code_lines[0].startswith('- type:'):
             # Treat this as a fragment of a larger pipeline.
+            # pylint: disable=not-an-iterable
             code_lines = [
                 'pipeline:',
                 '  type: chain',
                 '  transforms:',
                 '    - type: ReadFromCsv',
                 '      path: whatever',
-            ] + [
-                '    ' + line for line in code_lines  # pylint: 
disable=not-an-iterable
-            ]
+            ] + ['    ' + line for line in code_lines]
           if code_lines[0] == 'pipeline:':
             yaml_pipeline = '\n'.join(code_lines)
             if 'providers:' in yaml_pipeline:
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py 
b/sdks/python/apache_beam/yaml/yaml_mapping.py
index 7f959773320..3b25eb78a39 100644
--- a/sdks/python/apache_beam/yaml/yaml_mapping.py
+++ b/sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -199,7 +199,7 @@ def MapToFields(
 
     result = pcoll | yaml_create_transform({
         'type': 'Sql', 'query': query, **language_keywords
-    })
+    }, [pcoll])
     if explode:
       # TODO(yaml): Implement via unnest.
       result = result | _Explode(explode, cross_product)
@@ -217,7 +217,7 @@ def MapToFields(
             'cross_product': cross_product,
         },
         **language_keywords
-    })
+    }, [pcoll])
 
   else:
     # TODO(yaml): Support javascript expressions and UDFs.
diff --git a/sdks/python/apache_beam/yaml/yaml_provider.py 
b/sdks/python/apache_beam/yaml/yaml_provider.py
index 209e2178b8e..d58cb37a881 100644
--- a/sdks/python/apache_beam/yaml/yaml_provider.py
+++ b/sdks/python/apache_beam/yaml/yaml_provider.py
@@ -63,12 +63,32 @@ class Provider:
       self,
       typ: str,
       args: Mapping[str, Any],
-      yaml_create_transform: Callable[[Mapping[str, Any]], beam.PTransform]
+      yaml_create_transform: Callable[
+          [Mapping[str, Any], Iterable[beam.PCollection]], beam.PTransform]
   ) -> beam.PTransform:
     """Creates a PTransform instance for the given transform type and 
arguments.
     """
     raise NotImplementedError(type(self))
 
+  def affinity(self, other: "Provider"):
+    """Returns a value approximating how good it would be for this provider
+    to be used immediately following a transform from the other provider
+    (e.g. to encourage fusion).
+    """
+    # TODO(yaml): This is a very rough heuristic. Consider doing better.
+    # E.g. we could look at the expected environments themselves.
+    # Possibly, we could provide multiple expansions and have the runner itself
+    # choose the actual implementation based on fusion (and other) criteria.
+    return self._affinity(other) + other._affinity(self)
+
+  def _affinity(self, other: "Provider"):
+    if self is other or self == other:
+      return 100
+    elif type(self) == type(other):
+      return 10
+    else:
+      return 0
+
 
 def as_provider(name, provider_or_constructor):
   if isinstance(provider_or_constructor, Provider):
@@ -201,6 +221,12 @@ class ExternalPythonProvider(ExternalProvider):
         }).payload(),
         self._service)
 
+  def _affinity(self, other: "Provider"):
+    if isinstance(other, InlineProvider):
+      return 50
+    else:
+      return super()._affinity(other)
+
 
 # This is needed because type inference can't handle *args, **kwargs 
forwarding.
 # TODO(BEAM-24755): Add support for type inference of through kwargs calls.
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py 
b/sdks/python/apache_beam/yaml/yaml_transform.py
index 925aa0d85b9..ebc9eb6c066 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -162,12 +162,13 @@ class LightweightScope(object):
 
 class Scope(LightweightScope):
   """To look up PCollections (typically outputs of prior transforms) by 
name."""
-  def __init__(self, root, inputs, transforms, providers):
+  def __init__(self, root, inputs, transforms, providers, input_providers):
     super().__init__(transforms)
     self.root = root
     self._inputs = inputs
     self.providers = providers
     self._seen_names = set()
+    self.input_providers = input_providers
 
   def compute_all(self):
     for transform_id in self._transforms_by_uuid.keys():
@@ -203,7 +204,7 @@ class Scope(LightweightScope):
     return expand_transform(self._transforms_by_uuid[transform_id], self)
 
   # A method on scope as providers may be scoped...
-  def create_ptransform(self, spec):
+  def create_ptransform(self, spec, input_pcolls):
     if 'type' not in spec:
       raise ValueError(f'Missing transform type: {identify_object(spec)}')
 
@@ -212,7 +213,20 @@ class Scope(LightweightScope):
           'Unknown transform type %r at %s' %
           (spec['type'], identify_object(spec)))
 
-    for provider in self.providers.get(spec['type']):
+    # TODO(yaml): Perhaps we can do better than a greedy choice here.
+    # TODO(yaml): Figure out why this is needed.
+    providers_by_input = {k: v for k, v in self.input_providers.items()}
+    input_providers = [
+        providers_by_input[pcoll] for pcoll in input_pcolls
+        if pcoll in providers_by_input
+    ]
+
+    def provider_score(p):
+      return sum(p.affinity(o) for o in input_providers)
+
+    for provider in sorted(self.providers.get(spec['type']),
+                           key=provider_score,
+                           reverse=True):
       if provider.available():
         break
     else:
@@ -245,6 +259,26 @@ class Scope(LightweightScope):
           yaml_provider=json.dumps(provider.to_json()),
           **ptransform.annotations())
       ptransform.annotations = lambda: annotations
+      original_expand = ptransform.expand
+
+      def recording_expand(pvalue):
+        result = original_expand(pvalue)
+
+        def record_providers(pvalueish):
+          if isinstance(pvalueish, (tuple, list)):
+            for p in pvalueish:
+              record_providers(p)
+          elif isinstance(pvalueish, dict):
+            for p in pvalueish.values():
+              record_providers(p)
+          elif isinstance(pvalueish, beam.PCollection):
+            if pvalueish not in self.input_providers:
+              self.input_providers[pvalueish] = provider
+
+        record_providers(result)
+        return result
+
+      ptransform.expand = recording_expand
       return ptransform
     except Exception as exn:
       if isinstance(exn, TypeError):
@@ -303,7 +337,7 @@ def expand_leaf_transform(spec, scope):
     else:
       inputs = inputs_dict
   _LOGGER.info("Expanding %s ", identify_object(spec))
-  ptransform = scope.create_ptransform(spec)
+  ptransform = scope.create_ptransform(spec, inputs_dict.values())
   try:
     # TODO: Move validation to construction?
     with FullyQualifiedNamedTransform.with_filter('*'):
@@ -336,7 +370,8 @@ def expand_composite_transform(spec, scope):
       spec['transforms'],
       yaml_provider.merge_providers(
           yaml_provider.parse_providers(spec.get('providers', [])),
-          scope.providers))
+          scope.providers),
+      scope.input_providers)
 
   class CompositePTransform(beam.PTransform):
     @staticmethod
@@ -667,7 +702,12 @@ class YamlTransform(beam.PTransform):
       root = next(iter(pcolls.values())).pipeline
     result = expand_transform(
         self._spec,
-        Scope(root, pcolls, transforms=[], providers=self._providers))
+        Scope(
+            root,
+            pcolls,
+            transforms=[],
+            providers=self._providers,
+            input_providers={}))
     if len(result) == 1:
       return only_element(result.values())
     else:
diff --git a/sdks/python/apache_beam/yaml/yaml_transform_test.py 
b/sdks/python/apache_beam/yaml/yaml_transform_test.py
index 167e3261d41..b036ea0a037 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform_test.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 #
 
+import collections
 import glob
 import logging
 import os
@@ -24,6 +25,7 @@ import unittest
 import apache_beam as beam
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
+from apache_beam.yaml import yaml_provider
 from apache_beam.yaml.yaml_transform import YamlTransform
 
 
@@ -323,6 +325,132 @@ class YamlWindowingTest(unittest.TestCase):
       assert_that(result, equal_to([6, 9]))
 
 
+class AnnotatingProvider(yaml_provider.InlineProvider):
+  """A provider that vends transforms that do nothing but record that this
+  provider (as identified by name) was used, along with any prior history
+  of the given element.
+  """
+  def __init__(self, name, transform_names):
+    super().__init__({
+        transform_name: lambda: beam.Map(lambda x: (x or ()) + (name, ))
+        for transform_name in transform_names.strip().split()
+    })
+    self._name = name
+
+  def __repr__(self):
+    return 'AnnotatingProvider(%r)' % self._name
+
+
+class AnotherAnnProvider(AnnotatingProvider):
+  """A Provider that behaves exactly as AnnotatingProvider, but is not
+  of the same type and so is considered "more distant" for matching purposes.
+  """
+  pass
+
+
+class ProviderAffinityTest(unittest.TestCase):
+  """These tests check that for a sequence of transforms, the "closest"
+  providers are chosen among multiple possible implementations.
+  """
+  provider1 = AnnotatingProvider("provider1", "P1 A B C  ")
+  provider2 = AnnotatingProvider("provider2", "P2 A   C D")
+  provider3 = AnotherAnnProvider("provider3", "P3 A B    ")
+  provider4 = AnotherAnnProvider("provider4", "P4 A B   D")
+
+  providers_dict = collections.defaultdict(list)
+  for provider in [provider1, provider2, provider3, provider4]:
+    for transform_type in provider.provided_transforms():
+      providers_dict[transform_type].append(provider)
+
+  def test_prefers_same_provider(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      result1 = p | 'Yaml1' >> YamlTransform(
+          '''
+          type: chain
+          transforms:
+            - type: Create
+              elements: [0]
+            - type: P1
+            - type: A
+            - type: C
+          ''',
+          providers=self.providers_dict)
+      assert_that(
+          result1,
+          equal_to([(
+              # provider1 was chosen, as it is the only one vending P1
+              'provider1',
+              # All of the providers vend A, but since the input was produced
+              # by provider1, we prefer to use that again.
+              'provider1',
+              # Similarly for C.
+              'provider1')]),
+          label='StartWith1')
+
+      result2 = p | 'Yaml2' >> YamlTransform(
+          '''
+          type: chain
+          transforms:
+            - type: Create
+              elements: [0]
+            - type: P2
+            - type: A
+            - type: C
+          ''',
+          providers=self.providers_dict)
+      assert_that(
+          result2,
+          equal_to([(
+              # provider2 was necessarily chosen for P2
+              'provider2',
+              # Unlike above, we choose provider2 to implement A.
+              'provider2',
+              # Likewise for C.
+              'provider2')]),
+          label='StartWith2')
+
+  def test_prefers_same_provider_class(self):
+    # Like test_prefers_same_provider, but as we cannot choose the same
+    # exact provider, we go with the next closest (which is of the same type)
+    # over an implementation from a Provider of a different type.
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      result1 = p | 'Yaml1' >> YamlTransform(
+          '''
+          type: chain
+          transforms:
+            - type: Create
+              elements: [0]
+            - type: P1
+            - type: A
+            - type: D
+            - type: A
+          ''',
+          providers=self.providers_dict)
+      assert_that(
+          result1,
+          equal_to([('provider1', 'provider1', 'provider2', 'provider2')]),
+          label='StartWith1')
+
+      result3 = p | 'Yaml2' >> YamlTransform(
+          '''
+          type: chain
+          transforms:
+            - type: Create
+              elements: [0]
+            - type: P3
+            - type: A
+            - type: D
+            - type: A
+          ''',
+          providers=self.providers_dict)
+      assert_that(
+          result3,
+          equal_to([('provider3', 'provider3', 'provider4', 'provider4')]),
+          label='StartWith3')
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()

Reply via email to