[BEAM-1768] Fix assert_that for empty inputs
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/c39b02d9 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/c39b02d9 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/c39b02d9 Branch: refs/heads/master Commit: c39b02d9ea2d343f348beca1228776204605fb93 Parents: 030528f Author: Robert Bradshaw <rober...@gmail.com> Authored: Mon Mar 20 18:20:01 2017 -0700 Committer: Robert Bradshaw <rober...@google.com> Committed: Tue Mar 21 10:38:37 2017 -0700 ---------------------------------------------------------------------- sdks/python/apache_beam/transforms/util.py | 18 ++++--- sdks/python/apache_beam/transforms/util_test.py | 50 ++++++++++++++++++++ 2 files changed, 62 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/c39b02d9/sdks/python/apache_beam/transforms/util.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index e3f5b85..ac7eb3c 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -22,6 +22,7 @@ from __future__ import absolute_import from apache_beam.transforms import window from apache_beam.transforms.core import CombinePerKey +from apache_beam.transforms.core import Create from apache_beam.transforms.core import Flatten from apache_beam.transforms.core import GroupByKey from apache_beam.transforms.core import Map @@ -222,12 +223,17 @@ def assert_that(actual, matcher, label='assert_that'): class AssertThat(PTransform): def expand(self, pcoll): - return (pcoll - | WindowInto(window.GlobalWindows()) - | "ToVoidKey" >> Map(lambda v: (None, v)) - | "Group" >> GroupByKey() - | "UnKey" >> Map(lambda (k, v): v) - | "Match" >> Map(matcher)) + # We must have at least a single element to ensure the matcher + # code gets run even if the input pcollection is empty. + keyed_singleton = pcoll.pipeline | Create([(None, None)]) + keyed_actual = ( + pcoll + | WindowInto(window.GlobalWindows()) + | "ToVoidKey" >> Map(lambda v: (None, v))) + _ = ((keyed_singleton, keyed_actual) + | "Group" >> CoGroupByKey() + | "Unkey" >> Map(lambda (k, (_, actual_values)): actual_values) + | "Match" >> Map(matcher)) def default_label(self): return label http://git-wip-us.apache.org/repos/asf/beam/blob/c39b02d9/sdks/python/apache_beam/transforms/util_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py new file mode 100644 index 0000000..9656827 --- /dev/null +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for the util transforms.""" + +import unittest + +from apache_beam import Create +from apache_beam.test_pipeline import TestPipeline +from apache_beam.transforms.util import assert_that, equal_to, is_empty + + +class UtilTest(unittest.TestCase): + + def test_assert_that_passes(self): + with TestPipeline() as p: + assert_that(p | Create([1, 2, 3]), equal_to([1, 2, 3])) + + def test_assert_that_fails(self): + with self.assertRaises(Exception): + with TestPipeline() as p: + assert_that(p | Create([1, 10, 100]), equal_to([1, 2, 3])) + + def test_assert_that_fails_on_empty_input(self): + with self.assertRaises(Exception): + with TestPipeline() as p: + assert_that(p | Create([]), equal_to([1, 2, 3])) + + def test_assert_that_fails_on_empty_expected(self): + with self.assertRaises(Exception): + with TestPipeline() as p: + assert_that(p | Create([1, 2, 3]), is_empty()) + + +if __name__ == '__main__': + unittest.main()