[BEAM-1768] Fix assert_that for empty inputs

Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/c39b02d9
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/c39b02d9
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/c39b02d9

Branch: refs/heads/master
Commit: c39b02d9ea2d343f348beca1228776204605fb93
Parents: 030528f
Author: Robert Bradshaw <rober...@gmail.com>
Authored: Mon Mar 20 18:20:01 2017 -0700
Committer: Robert Bradshaw <rober...@google.com>
Committed: Tue Mar 21 10:38:37 2017 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/transforms/util.py      | 18 ++++---
 sdks/python/apache_beam/transforms/util_test.py | 50 ++++++++++++++++++++
 2 files changed, 62 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/c39b02d9/sdks/python/apache_beam/transforms/util.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py
index e3f5b85..ac7eb3c 100644
--- a/sdks/python/apache_beam/transforms/util.py
+++ b/sdks/python/apache_beam/transforms/util.py
@@ -22,6 +22,7 @@ from __future__ import absolute_import
 
 from apache_beam.transforms import window
 from apache_beam.transforms.core import CombinePerKey
+from apache_beam.transforms.core import Create
 from apache_beam.transforms.core import Flatten
 from apache_beam.transforms.core import GroupByKey
 from apache_beam.transforms.core import Map
@@ -222,12 +223,17 @@ def assert_that(actual, matcher, label='assert_that'):
   class AssertThat(PTransform):
 
     def expand(self, pcoll):
-      return (pcoll
-              | WindowInto(window.GlobalWindows())
-              | "ToVoidKey" >> Map(lambda v: (None, v))
-              | "Group" >> GroupByKey()
-              | "UnKey" >> Map(lambda (k, v): v)
-              | "Match" >> Map(matcher))
+      # CoGroup with a one-element None-keyed PCollection so the None group
+      # always exists and the matcher still runs when pcoll is empty.
+      keyed_singleton = pcoll.pipeline | Create([(None, None)])
+      keyed_actual = (
+          pcoll
+          | WindowInto(window.GlobalWindows())
+          | "ToVoidKey" >> Map(lambda v: (None, v)))
+      _ = ((keyed_singleton, keyed_actual)
+           | "Group" >> CoGroupByKey()
+           | "Unkey" >> Map(lambda (k, (_, actual_values)): actual_values)
+           | "Match" >> Map(matcher))
 
     def default_label(self):
       return label

http://git-wip-us.apache.org/repos/asf/beam/blob/c39b02d9/sdks/python/apache_beam/transforms/util_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py
new file mode 100644
index 0000000..9656827
--- /dev/null
+++ b/sdks/python/apache_beam/transforms/util_test.py
@@ -0,0 +1,50 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Unit tests for the util transforms."""
+
+import unittest
+
+from apache_beam import Create
+from apache_beam.test_pipeline import TestPipeline
+from apache_beam.transforms.util import assert_that, equal_to, is_empty
+
+
+class UtilTest(unittest.TestCase):
+  """assert_that passes on a match and fails otherwise, even for empty input."""
+  def test_assert_that_passes(self):
+    with TestPipeline() as p:
+      assert_that(p | Create([1, 2, 3]), equal_to([1, 2, 3]))
+
+  def test_assert_that_fails(self):
+    with self.assertRaises(Exception):
+      with TestPipeline() as p:
+        assert_that(p | Create([1, 10, 100]), equal_to([1, 2, 3]))
+
+  def test_assert_that_fails_on_empty_input(self):
+    with self.assertRaises(Exception):
+      with TestPipeline() as p:
+        assert_that(p | Create([]), equal_to([1, 2, 3]))
+
+  def test_assert_that_fails_on_empty_expected(self):
+    with self.assertRaises(Exception):
+      with TestPipeline() as p:
+        assert_that(p | Create([1, 2, 3]), is_empty())
+
+
+if __name__ == '__main__':  # allow running this test module as a script
+  unittest.main()

Reply via email to