This is an automated email from the ASF dual-hosted git repository.

lcwik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 31b183d  Add a new sdf E2E test without defer_remainder
     new 596bf10  [BEAM-2939] Add a new sdf E2E test without defer_remainder
31b183d is described below

commit 31b183de500eef9ba97524532c380141021ca160
Author: Boyuan Zhang <boyu...@google.com>
AuthorDate: Mon Apr 15 16:18:09 2019 -0700

    Add a new sdf E2E test without defer_remainder
---
 .../runners/portability/flink_runner_test.py       |  3 ++
 .../runners/portability/fn_api_runner_test.py      | 48 +++++++++++++++-------
 2 files changed, 37 insertions(+), 14 deletions(-)

diff --git a/sdks/python/apache_beam/runners/portability/flink_runner_test.py b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
index d67b5fb..ea4ec0a 100644
--- a/sdks/python/apache_beam/runners/portability/flink_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/flink_runner_test.py
@@ -220,6 +220,9 @@ if __name__ == '__main__':
     def test_sdf(self):
       raise unittest.SkipTest("BEAM-2939")
 
+    def test_sdf_with_sdf_initiated_checkpointing(self):
+      raise unittest.SkipTest("BEAM-2939")
+
     def test_callbacks_with_exception(self):
       raise unittest.SkipTest("BEAM-6868")
 
diff --git a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
index 7bbc16d..138648d 100644
--- a/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
+++ b/sdks/python/apache_beam/runners/portability/fn_api_runner_test.py
@@ -364,23 +364,25 @@ class FnApiRunnerTest(unittest.TestCase):
 
   def test_sdf(self):
 
-    counter = beam.metrics.Metrics.counter('ns', 'my_counter')
-
-    class ExpandStringsProvider(beam.transforms.core.RestrictionProvider):
-      def initial_restriction(self, element):
-        return (0, len(element))
+    class ExpandingStringsDoFn(beam.DoFn):
+      def process(self, element, restriction_tracker=ExpandStringsProvider()):
+        assert isinstance(
+            restriction_tracker,
+            restriction_trackers.OffsetRestrictionTracker), restriction_tracker
+        for k in range(*restriction_tracker.current_restriction()):
+          yield element[k]
 
-      def create_tracker(self, restriction):
-        return restriction_trackers.OffsetRestrictionTracker(
-            restriction[0], restriction[1])
+    with self.create_pipeline() as p:
+      data = ['abc', 'defghijklmno', 'pqrstuv', 'wxyz']
+      actual = (
+          p
+          | beam.Create(data)
+          | beam.ParDo(ExpandingStringsDoFn()))
+      assert_that(actual, equal_to(list(''.join(data))))
 
-      def split(self, element, restriction):
-        start, end = restriction
-        middle = (end - start) // 2
-        return [(start, middle), (middle, end)]
+  def test_sdf_with_sdf_initiated_checkpointing(self):
 
-      def restriction_size(self, element, restriction):
-        return restriction[1] - restriction[0]
+    counter = beam.metrics.Metrics.counter('ns', 'my_counter')
 
     class ExpandStringsDoFn(beam.DoFn):
       def process(self, element, restriction_tracker=ExpandStringsProvider()):
@@ -1167,6 +1169,24 @@ class EventRecorder(object):
     shutil.rmtree(self.tmp_dir)
 
 
+class ExpandStringsProvider(beam.transforms.core.RestrictionProvider):
+  """A RestrictionProvider that is used for SDF-related tests."""
+  def initial_restriction(self, element):
+    return (0, len(element))
+
+  def create_tracker(self, restriction):
+    return restriction_trackers.OffsetRestrictionTracker(
+        restriction[0], restriction[1])
+
+  def split(self, element, restriction):
+    start, end = restriction
+    middle = (end - start) // 2
+    return [(start, middle), (middle, end)]
+
+  def restriction_size(self, element, restriction):
+    return restriction[1] - restriction[0]
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()

Reply via email to