This is an automated email from the ASF dual-hosted git repository. pabloem pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new c9bad5d Changes to SDF API to use DoFn Params (#8430) c9bad5d is described below commit c9bad5d023ac4755e47e4b73f5d9a92e402b152c Author: Pablo <pabl...@users.noreply.github.com> AuthorDate: Wed May 8 14:05:33 2019 -0700 Changes to SDF API to use DoFn Params (#8430) * Changes to SDF API to use DoFn Params * Fix docs tests * Fix docs again * Fix test --- sdks/python/apache_beam/runners/common.py | 8 ++++---- .../runners/direct/sdf_direct_runner_test.py | 7 +++++-- .../runners/portability/fn_api_runner_test.py | 17 +++++++++++++--- .../apache_beam/testing/synthetic_pipeline.py | 3 ++- sdks/python/apache_beam/transforms/core.py | 23 ++++++++++++++++++++-- sdks/python/scripts/generate_pydoc.sh | 1 + 6 files changed, 47 insertions(+), 12 deletions(-) diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py index 84ac116..3bbfd90 100644 --- a/sdks/python/apache_beam/runners/common.py +++ b/sdks/python/apache_beam/runners/common.py @@ -239,8 +239,8 @@ class DoFnSignature(object): def get_restriction_provider(self): result = _find_param_with_default(self.process_method, - default_as_type=RestrictionProvider) - return result[1] if result else None + default_as_type=DoFn.RestrictionParam) + return result[1].restriction_provider if result else None def _validate(self): self._validate_process() @@ -271,7 +271,7 @@ class DoFnSignature(object): userstate.validate_stateful_dofn(self.do_fn) def is_splittable_dofn(self): - return any([isinstance(default, RestrictionProvider) for default in + return any([isinstance(default, DoFn.RestrictionParam) for default in self.process_method.defaults]) def is_stateful_dofn(self): @@ -538,7 +538,7 @@ class PerWindowInvoker(DoFnInvoker): 'SDFs in multiply-windowed values with windowed arguments.') restriction_tracker_param = _find_param_with_default( self.signature.process_method, - default_as_type=core.RestrictionProvider)[0] + default_as_type=DoFn.RestrictionParam)[0] if not restriction_tracker_param: raise ValueError( 'A RestrictionTracker %r was provided but DoFn does not have a ' diff --git a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py index eae38bc..3e1e344 100644 --- a/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py +++ b/sdks/python/apache_beam/runners/direct/sdf_direct_runner_test.py @@ -57,7 +57,10 @@ class ReadFiles(DoFn): self._resume_count = resume_count def process( - self, element, restriction_tracker=ReadFilesProvider(), *args, **kwargs): + self, + element, + restriction_tracker=DoFn.RestrictionParam(ReadFilesProvider()), + *args, **kwargs): file_name = element assert isinstance(restriction_tracker, OffsetRestrictionTracker) @@ -107,7 +110,7 @@ class ExpandStrings(DoFn): def process( self, element, side1, side2, side3, window=beam.DoFn.WindowParam, - restriction_tracker=ExpandStringsProvider(), + restriction_tracker=DoFn.RestrictionParam(ExpandStringsProvider()), *args, **kwargs): side = [] side.extend(side1) diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py index c584ef1..a807cfa 100644 --- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py +++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py @@ -422,7 +422,11 @@ class FnApiRunnerTest(unittest.TestCase): def test_sdf(self): class ExpandingStringsDoFn(beam.DoFn): - def process(self, element, restriction_tracker=ExpandStringsProvider()): + def process( + self, + element, + restriction_tracker=beam.DoFn.RestrictionParam( + ExpandStringsProvider())): assert isinstance( restriction_tracker, restriction_trackers.OffsetRestrictionTracker), restriction_tracker @@ -442,7 +446,11 @@ class FnApiRunnerTest(unittest.TestCase): counter = beam.metrics.Metrics.counter('ns', 'my_counter') class ExpandStringsDoFn(beam.DoFn): - def process(self, element, restriction_tracker=ExpandStringsProvider()): + def process( + self, + element, + restriction_tracker=beam.DoFn.RestrictionParam( + ExpandStringsProvider())): assert isinstance( restriction_tracker, restriction_trackers.OffsetRestrictionTracker), restriction_tracker @@ -1271,7 +1279,10 @@ class FnApiRunnerSplitTest(unittest.TestCase): return restriction[1] - restriction[0] class EnumerateSdf(beam.DoFn): - def process(self, element, restriction_tracker=EnumerateProvider()): + def process( + self, + element, + restriction_tracker=beam.DoFn.RestrictionParam(EnumerateProvider())): to_emit = [] for k in range(*restriction_tracker.current_restriction()): if restriction_tracker.try_claim(k): diff --git a/sdks/python/apache_beam/testing/synthetic_pipeline.py b/sdks/python/apache_beam/testing/synthetic_pipeline.py index eb1ec8d..2cace10 100644 --- a/sdks/python/apache_beam/testing/synthetic_pipeline.py +++ b/sdks/python/apache_beam/testing/synthetic_pipeline.py @@ -358,7 +358,8 @@ class SyntheticSDFAsSource(beam.DoFn): def process( self, element, - restriction_tracker=SyntheticSDFSourceRestrictionProvider()): + restriction_tracker=beam.DoFn.RestrictionParam( + SyntheticSDFSourceRestrictionProvider())): for k in range(*restriction_tracker.current_restriction()): if not restriction_tracker.try_claim(k): return diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 5eed185..fb93c00 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -345,6 +345,18 @@ class _DoFnParam(object): return self.param_id +class _RestrictionDoFnParam(_DoFnParam): + """Restriction Provider DoFn parameter.""" + + def __init__(self, restriction_provider): + if not isinstance(restriction_provider, RestrictionProvider): + raise ValueError( + 'DoFn.RestrictionParam expected RestrictionProvider object.') + self.restriction_provider = restriction_provider + self.param_id = ('RestrictionParam(%s)' + % restriction_provider.__class__.__name__) + + class _StateDoFnParam(_DoFnParam): """State DoFn parameter.""" @@ -421,6 +433,8 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn): StateParam = _StateDoFnParam TimerParam = _TimerDoFnParam + RestrictionParam = _RestrictionDoFnParam + @staticmethod def from_callable(fn): return CallableWrapperDoFn(fn) @@ -441,8 +455,13 @@ class DoFn(WithTypeHints, HasDisplayData, urns.RunnerApiFn): ``DoFn.SideInputParam``: a side input that may be used when processing. ``DoFn.TimestampParam``: timestamp of the input element. ``DoFn.WindowParam``: ``Window`` the input element belongs to. - A ``RestrictionProvider`` instance: an ``iobase.RestrictionTracker`` will be - provided here to allow treatment as a Splittable `DoFn``. + ``DoFn.TimerParam``: a ``userstate.RuntimeTimer`` object defined by the spec + of the parameter. + ``DoFn.StateParam``: a ``userstate.RuntimeState`` object defined by the spec + of the parameter. + ``DoFn.RestrictionParam``: an ``iobase.RestrictionTracker`` will be + provided here to allow treatment as a Splittable ``DoFn``. The restriction + tracker will be derived from the restriction provider in the parameter. ``DoFn.WatermarkReporterParam``: a function that can be used to report output watermark of Splittable ``DoFn`` implementations. diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh index dc9d74b..7564b49 100755 --- a/sdks/python/scripts/generate_pydoc.sh +++ b/sdks/python/scripts/generate_pydoc.sh @@ -178,6 +178,7 @@ ignore_identifiers = [ '_StateDoFnParam', '_TimerDoFnParam', '_BundleFinalizerParam', + '_RestrictionDoFnParam', # Sphinx cannot find this py:class reference target 'typing.Generic',