This is an automated email from the ASF dual-hosted git repository.

robertwb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new e3fee5156b3 [YAML] Add Partition transform. (#30368)
e3fee5156b3 is described below

commit e3fee5156b3515f96dc5ba90ea2fbc6f6be2bd34
Author: Robert Bradshaw <rober...@gmail.com>
AuthorDate: Thu Mar 28 17:25:00 2024 -0700

    [YAML] Add Partition transform. (#30368)
---
 .../apache_beam/yaml/programming_guide_test.py     |  19 ++
 sdks/python/apache_beam/yaml/readme_test.py        |  30 +--
 sdks/python/apache_beam/yaml/yaml_mapping.py       |  88 +++++++-
 sdks/python/apache_beam/yaml/yaml_mapping_test.py  | 239 +++++++++++++++++++++
 sdks/python/apache_beam/yaml/yaml_transform.py     |   6 +-
 .../content/en/documentation/programming-guide.md  |  16 ++
 .../site/content/en/documentation/sdks/yaml-udf.md |  68 ++++++
 7 files changed, 451 insertions(+), 15 deletions(-)

diff --git a/sdks/python/apache_beam/yaml/programming_guide_test.py 
b/sdks/python/apache_beam/yaml/programming_guide_test.py
index cd7bf6a8814..fe5e242f7f5 100644
--- a/sdks/python/apache_beam/yaml/programming_guide_test.py
+++ b/sdks/python/apache_beam/yaml/programming_guide_test.py
@@ -404,6 +404,25 @@ class ProgrammingGuideTest(unittest.TestCase):
           # [END setting_timestamp]
           ''')
 
+  def test_partition(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create([
+          beam.Row(percentile=1),
+          beam.Row(percentile=20),
+          beam.Row(percentile=90),
+      ])
+      _ = elements | YamlTransform(
+          '''
+          # [START model_multiple_pcollections_partition]
+          type: Partition
+          config:
+            by: str(percentile // 10)
+            language: python
+            outputs: ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"]
+          # [END model_multiple_pcollections_partition]
+          ''')
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/sdks/python/apache_beam/yaml/readme_test.py 
b/sdks/python/apache_beam/yaml/readme_test.py
index 4ca60e6176b..ea7a015dab5 100644
--- a/sdks/python/apache_beam/yaml/readme_test.py
+++ b/sdks/python/apache_beam/yaml/readme_test.py
@@ -128,12 +128,25 @@ class FakeAggregation(beam.PTransform):
         lambda _: 1, sum, 'count')
 
 
+class _Fakes:
+  fn = str
+
+  class SomeTransform(beam.PTransform):
+    def __init__(*args, **kwargs):
+      pass
+
+    def expand(self, pcoll):
+      return pcoll
+
+
 RENDER_DIR = None
 TEST_TRANSFORMS = {
     'Sql': FakeSql,
     'ReadFromPubSub': FakeReadFromPubSub,
     'WriteToPubSub': FakeWriteToPubSub,
     'SomeGroupingTransform': FakeAggregation,
+    'SomeTransform': _Fakes.SomeTransform,
+    'AnotherTransform': _Fakes.SomeTransform,
 }
 
 
@@ -155,7 +168,7 @@ class TestEnvironment:
     return path
 
   def input_csv(self):
-    return self.input_file('input.csv', 'col1,col2,col3\nabc,1,2.5\n')
+    return self.input_file('input.csv', 'col1,col2,col3\na,1,2.5\n')
 
   def input_tsv(self):
     return self.input_file('input.tsv', 'col1\tcol2\tcol3\nabc\t1\t2.5\n')
@@ -250,13 +263,15 @@ def parse_test_methods(markdown_lines):
       else:
         if code_lines:
           if code_lines[0].startswith('- type:'):
+            is_chain = not any('input:' in line for line in code_lines)
             # Treat this as a fragment of a larger pipeline.
             # pylint: disable=not-an-iterable
             code_lines = [
                 'pipeline:',
-                '  type: chain',
+                '  type: chain' if is_chain else '',
                 '  transforms:',
                 '    - type: ReadFromCsv',
+                '      name: input',
                 '      config:',
                 '        path: whatever',
             ] + ['    ' + line for line in code_lines]
@@ -278,17 +293,6 @@ def createTestSuite(name, path):
     return type(name, (unittest.TestCase, ), dict(parse_test_methods(readme)))
 
 
-class _Fakes:
-  fn = str
-
-  class SomeTransform(beam.PTransform):
-    def __init__(*args, **kwargs):
-      pass
-
-    def expand(self, pcoll):
-      return pcoll
-
-
 # These are copied from $ROOT/website/www/site/content/en/documentation/sdks
 # at build time.
 YAML_DOCS_DIR = os.path.join(os.path.join(os.path.dirname(__file__), 'docs'))
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping.py 
b/sdks/python/apache_beam/yaml/yaml_mapping.py
index 954e32cdf7b..4839728dd88 100644
--- a/sdks/python/apache_beam/yaml/yaml_mapping.py
+++ b/sdks/python/apache_beam/yaml/yaml_mapping.py
@@ -24,6 +24,7 @@ from typing import Any
 from typing import Callable
 from typing import Collection
 from typing import Dict
+from typing import List
 from typing import Mapping
 from typing import NamedTuple
 from typing import Optional
@@ -42,6 +43,7 @@ from apache_beam.transforms.window import TimestampedValue
 from apache_beam.typehints import row_type
 from apache_beam.typehints import schemas
 from apache_beam.typehints import trivial_inference
+from apache_beam.typehints import typehints
 from apache_beam.typehints.row_type import RowTypeConstraint
 from apache_beam.typehints.schemas import named_fields_from_element_type
 from apache_beam.utils import python_callable
@@ -569,6 +571,86 @@ def _SqlMapToFieldsTransform(pcoll, 
sql_transform_constructor, **mapping_args):
   return pcoll | sql_transform_constructor(query)
 
 
+@beam.ptransform.ptransform_fn
+def _Partition(
+    pcoll,
+    by: Union[str, Dict[str, str]],
+    outputs: List[str],
+    unknown_output: Optional[str] = None,
+    error_handling: Optional[Mapping[str, Any]] = None,
+    language: Optional[str] = 'generic'):
+  """Splits an input into several distinct outputs.
+
+  Each input element will go to a distinct output based on the field or
+  function given in the `by` configuration parameter.
+
+  Args:
+      by: A field, callable, or expression giving the destination output for
+        this element.  Should return a string that is a member of the `outputs`
+        parameter. If `unknown_output` is also set, other return values are
+        accepted as well; otherwise an error will be raised.
+      outputs: The set of outputs into which this input is being partitioned.
+      unknown_output: (Optional) If set, indicates a destination output for any
+        elements that are not assigned an output listed in the `outputs`
+        parameter.
+      error_handling: (Optional) Whether and how to handle errors during
+        partitioning.
+      language: (Optional) The language of the `by` expression.
+  """
+  split_fn = _as_callable_for_pcoll(pcoll, by, 'by', language)
+  try:
+    split_fn_output_type = trivial_inference.infer_return_type(
+        split_fn, [pcoll.element_type])
+  except (TypeError, ValueError):
+    pass
+  else:
+    if not typehints.is_consistent_with(split_fn_output_type,
+                                        typehints.Optional[str]):
+      raise ValueError(
+          f'Partition function "{by}" must return a string type '
+          f'not {split_fn_output_type}')
+  error_output = error_handling['output'] if error_handling else None
+  if error_output in outputs:
+    raise ValueError(
+        f'Error handling output "{error_output}" '
+        f'cannot be among the listed outputs {outputs}')
+  T = TypeVar('T')
+
+  def split(element):
+    tag = split_fn(element)
+    if tag is None:
+      tag = unknown_output
+    if not isinstance(tag, str):
+      raise ValueError(
+          f'Returned output name "{tag}" of type {type(tag)} '
+          f'from "{by}" must be a string.')
+    if tag not in outputs:
+      if unknown_output:
+        tag = unknown_output
+      else:
+        raise ValueError(f'Unknown output name "{tag}" from {by}')
+    return beam.pvalue.TaggedOutput(tag, element)
+
+  output_set = set(outputs)
+  if unknown_output:
+    output_set.add(unknown_output)
+  if error_output:
+    output_set.add(error_output)
+  mapping_transform = beam.Map(split)
+  if error_output:
+    mapping_transform = mapping_transform.with_exception_handling(
+        **exception_handling_args(error_handling))
+  else:
+    mapping_transform = mapping_transform.with_outputs(*output_set)
+  splits = pcoll | mapping_transform.with_input_types(T).with_output_types(T)
+  result = {out: getattr(splits, out) for out in output_set}
+  if error_output:
+    result[
+        error_output] = result[error_output] | _map_errors_to_standard_format(
+            pcoll.element_type)
+  return result
+
+
 @beam.ptransform.ptransform_fn
 @maybe_with_exception_handling_transform_fn
 def _AssignTimestamps(
@@ -588,7 +670,8 @@ def _AssignTimestamps(
   Args:
       timestamp: A field, callable, or expression giving the new timestamp.
       language: The language of the timestamp expression.
-      error_handling: Whether and how to handle errors during iteration.
+      error_handling: Whether and how to handle errors during timestamp
+        evaluation.
   """
   timestamp_fn = _as_callable_for_pcoll(pcoll, timestamp, 'timestamp', 
language)
   T = TypeVar('T')
@@ -611,6 +694,9 @@ def create_mapping_providers():
           'MapToFields-python': _PyJsMapToFields,
           'MapToFields-javascript': _PyJsMapToFields,
           'MapToFields-generic': _PyJsMapToFields,
+          'Partition-python': _Partition,
+          'Partition-javascript': _Partition,
+          'Partition-generic': _Partition,
       }),
       yaml_provider.SqlBackedProvider({
           'Filter-sql': _SqlFilterTransform,
diff --git a/sdks/python/apache_beam/yaml/yaml_mapping_test.py 
b/sdks/python/apache_beam/yaml/yaml_mapping_test.py
index 9dca107dca5..d5aa4038ef7 100644
--- a/sdks/python/apache_beam/yaml/yaml_mapping_test.py
+++ b/sdks/python/apache_beam/yaml/yaml_mapping_test.py
@@ -151,6 +151,245 @@ class YamlMappingTest(unittest.TestCase):
             ''')
         self.assertEqual(result.element_type._fields[0][1], str)
 
+  def test_partition(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create([
+          beam.Row(element='apple'),
+          beam.Row(element='banana'),
+          beam.Row(element='orange'),
+      ])
+      result = elements | YamlTransform(
+          '''
+          type: Partition
+          input: input
+          config:
+            by: "'even' if len(element) % 2 == 0 else 'odd'"
+            language: python
+            outputs: [even, odd]
+          ''')
+      assert_that(
+          result['even'] | beam.Map(lambda x: x.element),
+          equal_to(['banana', 'orange']),
+          label='Even')
+      assert_that(
+          result['odd'] | beam.Map(lambda x: x.element),
+          equal_to(['apple']),
+          label='Odd')
+
+  def test_partition_callable(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create([
+          beam.Row(element='apple'),
+          beam.Row(element='banana'),
+          beam.Row(element='orange'),
+      ])
+      result = elements | YamlTransform(
+          '''
+          type: Partition
+          input: input
+          config:
+            by:
+              callable:
+                "lambda row: 'even' if len(row.element) % 2 == 0 else 'odd'"
+            language: python
+            outputs: [even, odd]
+          ''')
+      assert_that(
+          result['even'] | beam.Map(lambda x: x.element),
+          equal_to(['banana', 'orange']),
+          label='Even')
+      assert_that(
+          result['odd'] | beam.Map(lambda x: x.element),
+          equal_to(['apple']),
+          label='Odd')
+
+  def test_partition_with_unknown(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create([
+          beam.Row(element='apple'),
+          beam.Row(element='banana'),
+          beam.Row(element='orange'),
+      ])
+      result = elements | YamlTransform(
+          '''
+          type: Partition
+          input: input
+          config:
+            by: "element.lower()[0]"
+            language: python
+            outputs: [a, b, c]
+            unknown_output: other
+          ''')
+      assert_that(
+          result['a'] | beam.Map(lambda x: x.element),
+          equal_to(['apple']),
+          label='A')
+      assert_that(
+          result['b'] | beam.Map(lambda x: x.element),
+          equal_to(['banana']),
+          label='B')
+      assert_that(
+          result['c'] | beam.Map(lambda x: x.element), equal_to([]), label='C')
+      assert_that(
+          result['other'] | beam.Map(lambda x: x.element),
+          equal_to(['orange']),
+          label='Other')
+
+  def test_partition_without_unknown(self):
+    with self.assertRaisesRegex(ValueError, r'.*Unknown output name.*"o".*'):
+      with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+          pickle_library='cloudpickle')) as p:
+        elements = p | beam.Create([
+            beam.Row(element='apple'),
+            beam.Row(element='banana'),
+            beam.Row(element='orange'),
+        ])
+        _ = elements | YamlTransform(
+            '''
+            type: Partition
+            input: input
+            config:
+              by: "element.lower()[0]"
+              language: python
+              outputs: [a, b, c]
+            ''')
+
+  def test_partition_without_unknown_with_error(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create([
+          beam.Row(element='apple'),
+          beam.Row(element='banana'),
+          beam.Row(element='orange'),
+      ])
+      result = elements | YamlTransform(
+          '''
+          type: Partition
+          input: input
+          config:
+            by: "element.lower()[0]"
+            language: python
+            outputs: [a, b, c]
+            error_handling:
+              output: unknown
+          ''')
+      assert_that(
+          result['a'] | beam.Map(lambda x: x.element),
+          equal_to(['apple']),
+          label='A')
+      assert_that(
+          result['b'] | beam.Map(lambda x: x.element),
+          equal_to(['banana']),
+          label='B')
+      assert_that(
+          result['c'] | beam.Map(lambda x: x.element), equal_to([]), label='C')
+      assert_that(
+          result['unknown'] | beam.Map(lambda x: x.element.element),
+          equal_to(['orange']),
+          label='Errors')
+
+  def test_partition_with_actual_error(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create([
+          beam.Row(element='apple'),
+          beam.Row(element='banana'),
+          beam.Row(element='orange'),
+      ])
+      result = elements | YamlTransform(
+          '''
+          type: Partition
+          input: input
+          config:
+            by: "element.lower()[5]"
+            language: python
+            outputs: [a, b, c]
+            unknown_output: other
+            error_handling:
+              output: errors
+          ''')
+      assert_that(
+          result['a'] | beam.Map(lambda x: x.element),
+          equal_to(['banana']),
+          label='B')
+      assert_that(
+          result['other'] | beam.Map(lambda x: x.element),
+          equal_to(['orange']),
+          label='Other')
+      # Apple only has 5 letters, resulting in an index error.
+      assert_that(
+          result['errors'] | beam.Map(lambda x: x.element.element),
+          equal_to(['apple']),
+          label='Errors')
+
+  def test_partition_no_language(self):
+    with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+        pickle_library='cloudpickle')) as p:
+      elements = p | beam.Create([
+          beam.Row(element='apple', texture='smooth'),
+          beam.Row(element='banana', texture='smooth'),
+          beam.Row(element='orange', texture='bumpy'),
+      ])
+      result = elements | YamlTransform(
+          '''
+          type: Partition
+          input: input
+          config:
+            by: texture
+            outputs: [bumpy, smooth]
+          ''')
+      assert_that(
+          result['bumpy'] | beam.Map(lambda x: x.element),
+          equal_to(['orange']),
+          label='Bumpy')
+      assert_that(
+          result['smooth'] | beam.Map(lambda x: x.element),
+          equal_to(['apple', 'banana']),
+          label='Smooth')
+
+  def test_partition_bad_static_type(self):
+    with self.assertRaisesRegex(
+        ValueError, r'.*Partition function .*must return a string.*'):
+      with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+          pickle_library='cloudpickle')) as p:
+        elements = p | beam.Create([
+            beam.Row(element='apple', texture='smooth'),
+            beam.Row(element='banana', texture='smooth'),
+            beam.Row(element='orange', texture='bumpy'),
+        ])
+        _ = elements | YamlTransform(
+            '''
+            type: Partition
+            input: input
+            config:
+              by: len(texture)
+              outputs: [bumpy, smooth]
+              language: python
+            ''')
+
+  def test_partition_bad_runtime_type(self):
+    with self.assertRaisesRegex(ValueError,
+                                r'.*Returned output name.*must be a string.*'):
+      with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
+          pickle_library='cloudpickle')) as p:
+        elements = p | beam.Create([
+            beam.Row(element='apple', texture='smooth'),
+            beam.Row(element='banana', texture='smooth'),
+            beam.Row(element='orange', texture='bumpy'),
+        ])
+        _ = elements | YamlTransform(
+            '''
+            type: Partition
+            input: input
+            config:
+              by: print(texture)
+              outputs: [bumpy, smooth]
+              language: python
+            ''')
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
diff --git a/sdks/python/apache_beam/yaml/yaml_transform.py 
b/sdks/python/apache_beam/yaml/yaml_transform.py
index 03574b5f98f..df2fdbf6aaa 100644
--- a/sdks/python/apache_beam/yaml/yaml_transform.py
+++ b/sdks/python/apache_beam/yaml/yaml_transform.py
@@ -918,7 +918,11 @@ def preprocess(spec, verbose=False, known_transforms=None):
     return spec
 
   def preprocess_langauges(spec):
-    if spec['type'] in ('Filter', 'MapToFields', 'Combine', 
'AssignTimestamps'):
+    if spec['type'] in ('AssignTimestamps',
+                        'Combine',
+                        'Filter',
+                        'MapToFields',
+                        'Partition'):
       language = spec.get('config', {}).get('language', 'generic')
       new_type = spec['type'] + '-' + language
       if known_transforms and new_type not in known_transforms:
diff --git a/website/www/site/content/en/documentation/programming-guide.md 
b/website/www/site/content/en/documentation/programming-guide.md
index 4c51d99ce65..b228dac1909 100644
--- a/website/www/site/content/en/documentation/programming-guide.md
+++ b/website/www/site/content/en/documentation/programming-guide.md
@@ -2153,6 +2153,14 @@ students = ...
 {{< code_sample "sdks/typescript/test/docs/programming_guide.ts" 
model_multiple_pcollections_partition >}}
 {{< /highlight >}}
 
+{{< highlight yaml >}}
+{{< code_sample "sdks/python/apache_beam/yaml/programming_guide_test.py" 
model_multiple_pcollections_partition >}}
+{{< /highlight >}}
+
+{{< paragraph class="language-yaml">}}
+Note that in Beam YAML, `PCollections` are partitioned via string rather than 
integer values.
+{{< /paragraph >}}
+
 ### 4.3. Requirements for writing user code for Beam transforms 
{#requirements-for-writing-user-code-for-beam-transforms}
 
 When you build user code for a Beam transform, you should keep in mind the
@@ -2415,6 +2423,14 @@ properties in your `ParDo` operation and follow this 
operation with a `Split`
 to break it into multiple `PCollection`s.
 {{< /paragraph >}}
 
+{{< paragraph class="language-yaml">}}
+In Beam YAML, one obtains multiple outputs by emitting all outputs to a single
+`PCollection`, possibly with an extra field, and then using `Partition` to
+split this single `PCollection` into multiple distinct `PCollection`
+outputs.
+{{< /paragraph >}}
+
+
 #### 4.5.1. Tags for multiple outputs {#output-tags}
 
 {{< paragraph class="language-typescript">}}
diff --git a/website/www/site/content/en/documentation/sdks/yaml-udf.md 
b/website/www/site/content/en/documentation/sdks/yaml-udf.md
index c2ab3eb6460..5a51f1af1a1 100644
--- a/website/www/site/content/en/documentation/sdks/yaml-udf.md
+++ b/website/www/site/content/en/documentation/sdks/yaml-udf.md
@@ -207,6 +207,74 @@ criteria. This can be accomplished with a `Filter` 
transform, e.g.
     keep: "col2 > 0"
 ```
 
+## Partitioning
+
+It can also be useful to send different elements to different places
+(similar to what is done with side outputs in other SDKs).
+While this can be done with a set of `Filter` operations, if every
+element has a single destination it can be more natural to use a `Partition`
+transform instead, which sends every element to exactly one output.
+For example, this will send all elements where `col1` is equal to `"a"` to the
+output `Partition.a`.
+
+```
+- type: Partition
+  input: input
+  config:
+    by: col1
+    outputs: ['a', 'b', 'c']
+
+- type: SomeTransform
+  input: Partition.a
+  config:
+    param: ...
+
+- type: AnotherTransform
+  input: Partition.b
+  config:
+    param: ...
+```
+
+One can also specify the destination as a function, e.g.
+
+```
+- type: Partition
+  input: input
+  config:
+    by: "'even' if col2 % 2 == 0 else 'odd'"
+    language: python
+    outputs: ['even', 'odd']
+```
+
+One can optionally provide a catch-all output which will capture all elements
+that are not in the named outputs (which would otherwise be an error):
+
+```
+- type: Partition
+  input: input
+  config:
+    by: col1
+    outputs: ['a', 'b', 'c']
+    unknown_output: 'other'
+```
+
+Sometimes one wants to split a PCollection into multiple PCollections
+that aren't necessarily disjoint.  To send elements to multiple (or no) 
outputs,
+one could use an iterable column and precede the `Partition` with an `Explode`.
+
+```
+- type: Explode
+  input: input
+  config:
+    fields: col1
+
+- type: Partition
+  input: Explode
+  config:
+    by: col1
+    outputs: ['a', 'b', 'c']
+```
+
 ## Types
 
 Beam will try to infer the types involved in the mappings, but sometimes this

Reply via email to