Repository: incubator-beam
Updated Branches:
  refs/heads/python-sdk 062af66f8 -> b75c0c0d4


Adds test harnesses and a utilities framework for sources.

Helper functions and test harnesses for checking correctness of source
(``iobase.BoundedSource``) and range tracker (``iobase.RangeTracker``)
implementations.

Contains a few lightweight utilities (e.g. for reading items from a
source, such as ``readFromSource()``), as well as heavyweight property
testing and stress testing harnesses that help to get a large amount of
test coverage with little code.

Most notable ones are:
* ``assertSourcesEqualReferenceSource()`` helps testing that the data
read by the union of sources produced by ``BoundedSource.split()`` is
the same as data read by the original source.
* If your source implements dynamic work rebalancing, use the
``assertSplitAtFraction()`` family of functions - they test behavior of
``RangeTracker.try_split()``, in particular, that various consistency
properties are respected and the total set of data read by the source
is preserved when splits happen. Use ``assertSplitAtFractionBehavior()``
to test individual cases of ``RangeTracker.try_split()`` and use
``assertSplitAtFractionExhaustive()`` as a heavy-weight stress test
including concurrency.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/6c2ba81a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/6c2ba81a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/6c2ba81a

Branch: refs/heads/python-sdk
Commit: 6c2ba81aa5a287832ddb75da2b347abfc13bf4e6
Parents: 062af66
Author: Chamikara Jayalath <chamik...@google.com>
Authored: Fri Jul 15 11:07:48 2016 -0700
Committer: Dan Halperin <dhalp...@google.com>
Committed: Tue Jul 19 09:59:18 2016 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/io/avroio_test.py       |  35 +-
 sdks/python/apache_beam/io/source_test_utils.py | 568 +++++++++++++++++++
 .../apache_beam/io/source_test_utils_test.py    | 117 ++++
 3 files changed, 710 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6c2ba81a/sdks/python/apache_beam/io/avroio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py
index d70b3d1..060e211 100644
--- a/sdks/python/apache_beam/io/avroio_test.py
+++ b/sdks/python/apache_beam/io/avroio_test.py
@@ -19,9 +19,10 @@ import logging
 import os
 import tempfile
 import unittest
-
 from apache_beam.io import avroio
 from apache_beam.io import filebasedsource
+from apache_beam.io import source_test_utils
+import avro.datafile
 from avro.datafile import DataFileWriter
 from avro.io import DatumWriter
 import avro.schema as avro_schema
@@ -90,16 +91,15 @@ class TestAvro(unittest.TestCase):
       if len(splits) < 2:
         raise ValueError('Test is trivial. Please adjust it so that at least '
                          'two splits get generated')
-      for split in splits:
-        records = [record for record in split.source.read(
-            split.source.get_range_tracker(split.start_position,
-                                           split.stop_position))]
-        read_records.extend(records)
-    else:
-      range_tracker = source.get_range_tracker(None, None)
-      read_records = [record for record in source.read(range_tracker)]
 
-    self.assertItemsEqual(expected_result, read_records)
+      sources_info = [
+          (split.source, split.start_position, split.stop_position)
+          for split in splits]
+      source_test_utils.assertSourcesEqualReferenceSource(
+          (source, None, None), sources_info)
+    else:
+      read_records = source_test_utils.readFromSource(source, None, None)
+      self.assertItemsEqual(expected_result, read_records)
 
   def test_read_without_splitting(self):
     file_name = self._write_data()
@@ -141,6 +141,21 @@ class TestAvro(unittest.TestCase):
     expected_result = self.RECORDS * 3
     self._run_avro_test(pattern, 100, True, expected_result)
 
+  def test_dynamic_work_rebalancing_exhaustive(self):
+    # Adjusting block size so that we can perform an exhaustive dynamic
+    # work rebalancing test that completes within an acceptable amount of time.
+    old_sync_interval = avro.datafile.SYNC_INTERVAL
+    try:
+      avro.datafile.SYNC_INTERVAL = 5
+      file_name = self._write_data(count=20)
+      source = avroio.AvroSource(file_name)
+      splits = [split for split in source.split(
+          desired_bundle_size=float('inf'))]
+      assert len(splits) == 1
+      source_test_utils.assertSplitAtFractionExhaustive(splits[0].source)
+    finally:
+      avro.datafile.SYNC_INTERVAL = old_sync_interval
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6c2ba81a/sdks/python/apache_beam/io/source_test_utils.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/source_test_utils.py b/sdks/python/apache_beam/io/source_test_utils.py
new file mode 100644
index 0000000..1a2c21a
--- /dev/null
+++ b/sdks/python/apache_beam/io/source_test_utils.py
@@ -0,0 +1,568 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Helper functions and test harnesses for source implementations.
+
+This module contains helper functions and test harnesses for checking
+correctness of source (a subclass of ``iobase.BoundedSource``) and range
+tracker (a subclass of ``iobase.RangeTracker``) implementations.
+
+Contains a few lightweight utilities (e.g. for reading items from a source,
+such as ``readFromSource()``), as well as heavyweight property testing and
+stress testing harnesses that help to get a large amount of test coverage with
+little code.
+
+Most notable ones are:
+* ``assertSourcesEqualReferenceSource()`` helps testing that the data read by
+the union of sources produced by ``BoundedSource.split()`` is the same as data
+read by the original source.
+* If your source implements dynamic work rebalancing, use the
+``assertSplitAtFraction()`` family of functions - they test behavior of
+``RangeTracker.try_split()``, in particular, that various consistency
+properties are respected and the total set of data read by the source is
+preserved when splits happen. Use ``assertSplitAtFractionBehavior()`` to test
+individual cases of ``RangeTracker.try_split()`` and use
+``assertSplitAtFractionExhaustive()`` as a heavy-weight stress test including
+concurrency. We strongly recommend to use both.
+
+For example usages, see the unit tests of modules such as
+ * apache_beam.io.source_test_utils_test.py
+ * apache_beam.io.avroio_test.py
+"""
+
+from collections import namedtuple
+import logging
+
+from multiprocessing.pool import ThreadPool
+from apache_beam.io import iobase
+
+
+class ExpectedSplitOutcome(object):  # Outcomes a split assertion can expect.
+  MUST_SUCCEED_AND_BE_CONSISTENT = 1  # Split must succeed and be consistent.
+  MUST_FAIL = 2  # Split must be rejected.
+  MUST_BE_CONSISTENT_IF_SUCCEEDS = 3  # Either; consistent if it succeeds.
+
+
+SplitAtFractionResult = namedtuple(  # Item counts of the two split ranges.
+    'SplitAtFractionResult', 'num_primary_items num_residual_items')
+
+SplitFractionStatistics = namedtuple(  # Lists of fractions, filled in place.
+    'SplitFractionStatistics',
+    'successful_fractions non_trivial_fractions')
+
+
+def readFromSource(source, start_position, stop_position):
+  """Reads elements from the given ``BoundedSource``.
+
+  Only reads elements within the given position range.
+  Args:
+    source: ``iobase.BoundedSource`` implementation.
+    start_position: start position for reading.
+    stop_position: stop position for reading.
+
+  Returns:
+    a list of the values read from the source, in the order they were read.
+  """
+  values = []
+  range_tracker = source.get_range_tracker(start_position, stop_position)
+  assert isinstance(range_tracker, iobase.RangeTracker)
+  reader = source.read(range_tracker)
+  for value in reader:
+    values.append(value)
+
+  return values
+
+
+def assertSourcesEqualReferenceSource(reference_source_info, sources_info):
+  """Tests if a reference source is equal to a given set of sources.
+
+  Given a reference source (a ``BoundedSource`` and a position range) and a
+  list of sources, assert that the union of the records
+  read from the list of sources is equal to the records read from the
+  reference source.
+
+  Args:
+    reference_source_info: a three-tuple that gives the reference
+                           ``iobase.BoundedSource``, position to start reading
+                           at, and position to stop reading at.
+    sources_info: a set of sources. Each source is a three-tuple that is of
+                  the same format described above.
+
+  Raises:
+    ValueError: if an argument is malformed, if a source uses a different
+                coder than the reference source, or if the records read from
+                the sources differ from those of the reference source.
+  """
+
+  if not (isinstance(reference_source_info, tuple) and
+          len(reference_source_info) == 3 and
+          isinstance(reference_source_info[0], iobase.BoundedSource)):
+    raise ValueError('reference_source_info must be a three-tuple where first '
+                     'item of the tuple gives a '
+                     'iobase.BoundedSource. Received: %r'
+                     % (reference_source_info,))
+  reference_records = readFromSource(
+      *reference_source_info)
+
+  source_records = []
+  for source_info in sources_info:
+    # Validate with ValueError (not assert) so that the check survives -O and
+    # malformed entries are reported with the documented exception type.
+    if not (isinstance(source_info, tuple) and
+            len(source_info) == 3 and
+            isinstance(source_info[0], iobase.BoundedSource)):
+      raise ValueError('source_info must be a three-tuple where first '
+                       'item of the tuple gives a '
+                       'iobase.BoundedSource. Received: %r'
+                       % (source_info,))
+    if (type(reference_source_info[0].default_output_coder()) !=
+        type(source_info[0].default_output_coder())):
+      raise ValueError(
+          'Reference source %r and the source %r must use the same coder. '
+          'They are using %r and %r respectively instead.' % (
+              reference_source_info[0], source_info[0],
+              type(reference_source_info[0].default_output_coder()),
+              type(source_info[0].default_output_coder())))
+    source_records.extend(readFromSource(*source_info))
+
+  if len(reference_records) != len(source_records):
+    raise ValueError(
+        'Reference source must produce the same number of records as the '
+        'list of sources. Number of records were %d and %d instead.' %
+        (len(reference_records), len(source_records)))
+
+  if sorted(reference_records) != sorted(source_records):
+    raise ValueError(
+        'Reference source and provided list of sources must produce the '
+        'same set of records.')
+
+
+def assertSplitAtFractionBehavior(source, num_items_to_read_before_split,
+                                  split_fraction, expected_outcome):
+  """Verifies the behavior of splitting a source at a given fraction.
+
+  Asserts that splitting a ``BoundedSource`` either fails after reading
+  ``num_items_to_read_before_split`` items, or succeeds in a way that is
+  consistent according to ``assertSplitAtFractionSucceedsAndConsistent()``.
+
+  Args:
+    source: the source to perform dynamic splitting on.
+    num_items_to_read_before_split: number of items to read before splitting.
+    split_fraction: fraction to split at.
+    expected_outcome: a value from ``ExpectedSplitOutcome``.
+
+  Returns:
+    a tuple (number of primary items, number of residual items) obtained by
+    reading the two ranges produced by dynamic splitting. If splitting did
+    not occur, the first value is the number of records read by the whole
+    source and the second value is -1.
+  """
+  assert isinstance(source, iobase.BoundedSource)
+  expected_items = readFromSource(source, None, None)
+  return _assertSplitAtFractionBehavior(
+      source, expected_items, num_items_to_read_before_split, split_fraction,
+      expected_outcome)
+
+
+def _assertSplitAtFractionBehavior(
+    source, expected_items, num_items_to_read_before_split,
+    split_fraction, expected_outcome):
+  """Attempts one split after reading N items and validates the outcome."""
+  range_tracker = source.get_range_tracker(None, None)
+  assert isinstance(range_tracker, iobase.RangeTracker)
+  current_items = []
+  reader = source.read(range_tracker)
+  # Reading 'num_items_to_read_before_split' items.
+  reader_iter = iter(reader)
+  for _ in range(num_items_to_read_before_split):
+    current_items.append(next(reader_iter))
+
+  suggested_split_position = range_tracker.position_at_fraction(
+      split_fraction)
+
+  stop_position_before_split = range_tracker.stop_position()
+  split_result = range_tracker.try_split(suggested_split_position)
+
+  if split_result is not None:
+    if len(split_result) != 2:
+      raise ValueError('Split result must be a tuple that contains split '
+                       'position and split fraction. Received: %r' %
+                       (split_result,))
+
+    if range_tracker.stop_position() != split_result[0]:
+      raise ValueError('After a successful split, the stop position of the '
+                       'RangeTracker must be the same as the returned split '
+                       'position. Observed %r and %r which are different.' %
+                       (range_tracker.stop_position(), split_result[0]))
+
+    if split_result[1] < 0 or split_result[1] > 1:
+      raise ValueError('Split fraction must be within the range [0,1]. '
+                       'Observed split fraction was %r.' % (split_result[1],))
+
+  stop_position_after_split = range_tracker.stop_position()
+  if split_result and stop_position_after_split == stop_position_before_split:
+    raise ValueError('Stop position %r did not change after a successful '
+                     'split of source %r at fraction %r.' %
+                     (stop_position_before_split, source, split_fraction))
+
+  if expected_outcome == ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT:
+    if not split_result:
+      raise ValueError('Expected split of source %r at fraction %r to be '
+                       'successful after reading %d elements. But '
+                       'the split failed.' %
+                       (source, split_fraction,
+                        num_items_to_read_before_split))
+  elif expected_outcome == ExpectedSplitOutcome.MUST_FAIL:
+    if split_result:
+      raise ValueError('Expected split of source %r at fraction %r after '
+                       'reading %d elements to fail. But splitting '
+                       'succeeded with result %r.' %
+                       (source, split_fraction,
+                        num_items_to_read_before_split, split_result))
+
+  elif (expected_outcome !=
+        ExpectedSplitOutcome.MUST_BE_CONSISTENT_IF_SUCCEEDS):
+    raise ValueError('Unknown type of expected outcome: %r' %
+                     (expected_outcome,))
+  current_items.extend([value for value in reader_iter])
+
+  residual_range = (
+      split_result[0], stop_position_before_split) if split_result else None
+
+  return _verifySingleSplitFractionResult(
+      source, expected_items, current_items,
+      split_result,
+      (range_tracker.start_position(), range_tracker.stop_position()),
+      residual_range, split_fraction)
+
+
+def _range_to_str(start, stop):  # Renders the range as '[start,stop)'.
+  return '[' + (str(start) + ',' + str(stop) + ')')
+
+
+def _verifySingleSplitFractionResult(
+    source, expected_items, current_items, split_successful, primary_range,
+    residual_range, split_fraction):
+  """Checks that primary + residual ranges reproduce the expected items."""
+  assert primary_range
+  primary_items = readFromSource(source, *primary_range)
+
+  if not split_successful:
+    # For unsuccessful splits, residual_range should be None.
+    assert not residual_range
+
+  residual_items = (
+      readFromSource(source, *residual_range)
+      if split_successful else [])
+
+  total_items = primary_items + residual_items
+
+  if current_items != primary_items:
+    raise ValueError('Current source %r and a source created using the '
+                     'range of the primary source %r determined '
+                     'by performing dynamic work rebalancing at fraction '
+                     '%r produced different values. Expected '
+                     'these sources to produce the same list of values.' %
+                     (source,
+                      _range_to_str(*primary_range),
+                      split_fraction)
+                    )
+
+  if expected_items != total_items:
+    raise ValueError('Items obtained by reading the source %r for primary '
+                     'and residual ranges %s and %s did not produce the '
+                     'expected list of values.' %
+                     (source, _range_to_str(*primary_range),
+                      _range_to_str(*residual_range)
+                      if residual_range else None))
+
+  result = (len(primary_items),
+            len(residual_items) if split_successful else -1)
+  return result
+
+
+def assertSplitAtFractionSucceedsAndConsistent(source,
+                                               num_items_to_read_before_split,
+                                               split_fraction):
+  """Verifies some consistency properties of dynamic work rebalancing.
+
+  Equivalent to the following pseudocode:
+
+  original_range_tracker = source.get_range_tracker(None, None)
+  original_reader = source.read(original_range_tracker)
+  items_before_split = read N items from original_reader
+  suggested_split_position = original_range_tracker.position_at_fraction(
+    split_fraction)
+  original_stop_position = original_range_tracker.stop_position()
+  split_result = original_range_tracker.try_split(suggested_split_position)
+  split_position, split_fraction = split_result
+  primary_range_tracker = source.get_range_tracker(
+    original_range_tracker.start_position(), split_position)
+  residual_range_tracker = source.get_range_tracker(split_position,
+    original_stop_position)
+
+  assert that: items when reading source.read(primary_range_tracker) ==
+    items_before_split + items from continuing to read 'original_reader'
+  assert that: items when reading source.read(original_range_tracker) ==
+    items when reading source.read(primary_range_tracker) + items when reading
+    source.read(residual_range_tracker)
+
+  Args:
+    source: source to perform dynamic work rebalancing on.
+    num_items_to_read_before_split: number of items to read before splitting.
+    split_fraction: fraction to split at.
+  """
+
+  assertSplitAtFractionBehavior(
+      source, num_items_to_read_before_split, split_fraction,
+      ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT)
+
+
+def assertSplitAtFractionFails(source, num_items_to_read_before_split,
+                               split_fraction):
+  """Asserts that dynamic work rebalancing at a given fraction fails.
+
+  Raises ``ValueError`` if a split attempted after reading
+  ``num_items_to_read_before_split`` items from the source succeeds.
+
+  Args:
+    source: source to perform dynamic splitting on.
+    num_items_to_read_before_split: number of items to read before splitting.
+    split_fraction: fraction to split at.
+  """
+
+  assertSplitAtFractionBehavior(
+      source, num_items_to_read_before_split, split_fraction,
+      ExpectedSplitOutcome.MUST_FAIL)
+
+
+def assertSplitAtFractionBinary(source, expected_items,
+                                num_items_to_read_before_split, left_fraction,
+                                left_result,
+                                right_fraction, right_result, stats):
+  """Performs dynamic work rebalancing for fractions within a given range.
+
+  Asserts that given a start position, a source can be split at every
+  interesting fraction (halfway between two fractions that differ by at
+  least one item) and the results are consistent if a split succeeds.
+
+  Args:
+    source: source to perform dynamic splitting on.
+    expected_items: total set of items expected when reading the source.
+    num_items_to_read_before_split: number of items to read before splitting.
+    left_fraction: left fraction for binary splitting.
+    left_result: result of splitting at left fraction; computed if None.
+    right_fraction: right fraction for binary splitting.
+    right_result: result of splitting at right fraction; computed if None.
+    stats: a ``SplitFractionStatistics`` for storing results.
+  """
+  assert right_fraction > left_fraction
+
+  if right_fraction - left_fraction < 0.001:
+    # This prevents infinite recursion.
+    return
+
+  middle_fraction = (left_fraction + right_fraction) / 2.0  # True division.
+
+  if left_result is None:
+    left_result = _assertSplitAtFractionBehavior(
+        source, expected_items, num_items_to_read_before_split, left_fraction,
+        ExpectedSplitOutcome.MUST_BE_CONSISTENT_IF_SUCCEEDS)
+
+  if right_result is None:
+    right_result = _assertSplitAtFractionBehavior(
+        source, expected_items, num_items_to_read_before_split,
+        right_fraction, ExpectedSplitOutcome.MUST_BE_CONSISTENT_IF_SUCCEEDS)
+
+  middle_result = _assertSplitAtFractionBehavior(
+      source, expected_items, num_items_to_read_before_split, middle_fraction,
+      ExpectedSplitOutcome.MUST_BE_CONSISTENT_IF_SUCCEEDS)
+
+  if middle_result[1] != -1:
+    stats.successful_fractions.append(middle_fraction)
+  if middle_result[1] > 0:
+    stats.non_trivial_fractions.append(middle_fraction)
+
+  # Two split results are equivalent if primary and residual ranges of them
+  # produce the same number of records (simply checking the size of primary
+  # is enough since the total number of records is constant).
+
+  if left_result[0] != middle_result[0]:
+    assertSplitAtFractionBinary(
+        source, expected_items, num_items_to_read_before_split, left_fraction,
+        left_result, middle_fraction, middle_result, stats)
+
+  # We special case right_fraction=1.0 since that could fail due to being out
+  # of range. (even if a dynamic split fails at 'middle_fraction' and at
+  # fraction 1.0, there might be fractions in range ('middle_fraction', 1.0)
+  # where dynamic splitting succeeds).
+  if right_fraction == 1.0 or middle_result[0] != right_result[0]:
+    assertSplitAtFractionBinary(
+        source, expected_items, num_items_to_read_before_split,
+        middle_fraction, middle_result, right_fraction, right_result, stats)
+
+MAX_CONCURRENT_SPLITTING_TRIALS_PER_ITEM = 100  # Trial cap per start item.
+MAX_CONCURRENT_SPLITTING_TRIALS_TOTAL = 1000  # Trial cap across all items.
+
+
+def assertSplitAtFractionExhaustive(source, perform_multi_threaded_test=True):
+  """Performs and tests dynamic work rebalancing exhaustively.
+
+  Asserts that for each possible start position, a source can be split at
+  every interesting fraction (halfway between two fractions that differ by at
+  least one item) and the results are consistent if a split succeeds.
+  Verifies multi threaded splitting as well.
+
+  Args:
+    source: the source to perform dynamic splitting on.
+    perform_multi_threaded_test: if true performs a multi-threaded test
+                                 otherwise this test is skipped.
+
+  Raises:
+    ValueError: if the source reads fewer than two items or a check fails.
+  """
+
+  expected_items = readFromSource(source, None, None)
+  if not expected_items:
+    raise ValueError('Source %r is empty.' % (source,))
+
+  if len(expected_items) == 1:
+    raise ValueError('Source %r only reads a single item.' % (source,))
+
+  all_non_trivial_fractions = []
+
+  any_successful_fractions = False
+  any_non_trivial_fractions = False
+
+  for i in range(len(expected_items)):
+    stats = SplitFractionStatistics([], [])
+
+    assertSplitAtFractionBinary(
+        source, expected_items, i, 0.0, None, 1.0, None, stats)
+
+    if stats.successful_fractions:
+      any_successful_fractions = True
+    if stats.non_trivial_fractions:
+      any_non_trivial_fractions = True
+
+    all_non_trivial_fractions.append(stats.non_trivial_fractions)
+
+  if not any_successful_fractions:
+    raise ValueError('SplitAtFraction test completed vacuously: no '
+                     'successful split fractions found')
+
+  if not any_non_trivial_fractions:
+    raise ValueError(
+        'SplitAtFraction test completed vacuously: no non-trivial split '
+        'fractions found')
+
+  if not perform_multi_threaded_test:
+    return
+
+  num_total_trials = 0
+  for i in range(len(expected_items)):
+    non_trivial_fractions = [2.0]  # 2.0 is larger than any valid fraction.
+    non_trivial_fractions.extend(all_non_trivial_fractions[i])
+    min_non_trivial_fraction = min(non_trivial_fractions)
+
+    if min_non_trivial_fraction == 2.0:
+      # This will not happen all the time. Otherwise previous test will fail
+      # due to vacuousness.
+      continue
+
+    num_trials = 0
+    have_success = False
+    have_failure = False
+
+    while True:
+      num_trials += 1
+      if (num_trials >
+          MAX_CONCURRENT_SPLITTING_TRIALS_PER_ITEM):
+        logging.warning(
+            'After %d concurrent splitting trials at item #%d, observed '
+            'only %s, giving up on this item',
+            num_trials,
+            i,
+            'success' if have_success else 'failure'
+        )
+        break
+
+      if _assertSplitAtFractionConcurrent(
+          source, expected_items, i, min_non_trivial_fraction):
+        have_success = True
+      else:
+        have_failure = True
+
+      if have_success and have_failure:
+        logging.info('%d trials to observe both success and failure of '
+                     'concurrent splitting at item #%d', num_trials, i)
+        break
+
+    num_total_trials += num_trials
+
+    if num_total_trials > MAX_CONCURRENT_SPLITTING_TRIALS_TOTAL:
+      logging.warning('After %d total concurrent splitting trials, considered '
+                      'only %d items, giving up.', num_total_trials, i)
+      break
+
+  logging.info('%d total concurrent splitting trials for %d items',
+               num_total_trials, len(expected_items))
+
+
+def _assertSplitAtFractionConcurrent(
+    source, expected_items, num_items_to_read_before_splitting,
+    split_fraction):
+  """Races a read against a split; True if the residual is non-empty."""
+  range_tracker = source.get_range_tracker(None, None)
+  stop_position_before_split = range_tracker.stop_position()
+  reader = source.read(range_tracker)
+  reader_iter = iter(reader)
+
+  current_items = []
+  for _ in range(num_items_to_read_before_splitting):
+    current_items.append(next(reader_iter))
+
+  def read_or_split(test_params):
+    if test_params[0]:
+      return [val for val in test_params[1]]
+    else:
+      position = test_params[1].position_at_fraction(test_params[2])
+      result = test_params[1].try_split(position)
+      return result
+
+  inputs = [[True, reader_iter],
+            [False, range_tracker, split_fraction]]
+
+  pool = ThreadPool(2)
+  try:
+    results = pool.map(read_or_split, inputs)
+  finally:
+    pool.close()  # Release the worker threads even if a task raised.
+
+  current_items.extend(results[0])
+  primary_range = (
+      range_tracker.start_position(), range_tracker.stop_position())
+
+  split_result = results[1]
+  residual_range = (
+      split_result[0], stop_position_before_split) if split_result else None
+
+  res = _verifySingleSplitFractionResult(
+      source, expected_items, current_items, split_result,
+      primary_range, residual_range, split_fraction)
+
+  return res[1] > 0

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/6c2ba81a/sdks/python/apache_beam/io/source_test_utils_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/source_test_utils_test.py b/sdks/python/apache_beam/io/source_test_utils_test.py
new file mode 100644
index 0000000..3aff07f
--- /dev/null
+++ b/sdks/python/apache_beam/io/source_test_utils_test.py
@@ -0,0 +1,117 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import tempfile
+import unittest
+from filebasedsource_test import LineSource
+import source_test_utils
+
+
+class SourceTestUtilsTest(unittest.TestCase):  # Tests source_test_utils.
+
+  def _create_file_with_data(self, lines):
+    assert isinstance(lines, list)
+    with tempfile.NamedTemporaryFile(delete=False) as f:
+      for line in lines:
+        f.write(line + '\n')
+
+      return f.name
+
+  def _create_data(self, num_lines):
+    return ['line ' + str(i) for i in range(num_lines)]
+
+  def _create_source(self, data):
+    source = LineSource(self._create_file_with_data(data))
+    # By performing initial splitting, we can get a source for a single file.
+    # This source, that uses OffsetRangeTracker, is better for testing purposes,
+    # than using the original source for a file-pattern.
+    for bundle in source.split(float('inf')):
+      return bundle.source
+
+  def test_read_from_source(self):
+    data = self._create_data(100)
+    source = self._create_source(data)
+    self.assertItemsEqual(
+        data, source_test_utils.readFromSource(source, None, None))
+
+  def test_source_equals_reference_source(self):
+    data = self._create_data(100)
+    reference_source = self._create_source(data)
+    sources_info = [(split.source, split.start_position, split.stop_position)
+                    for split in reference_source.split(desired_bundle_size=50)]
+    if len(sources_info) < 2:
+      raise ValueError('Test is too trivial since splitting only generated %d '
+                       'bundles. Please adjust the test so that at least '
+                       'two splits get generated.' % len(sources_info))
+
+    source_test_utils.assertSourcesEqualReferenceSource(
+        (reference_source, None, None), sources_info)
+
+  def test_split_at_fraction_successful(self):
+    data = self._create_data(100)
+    source = self._create_source(data)
+    result1 = source_test_utils.assertSplitAtFractionBehavior(
+        source, 10, 0.5,
+        source_test_utils.ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT)
+    result2 = source_test_utils.assertSplitAtFractionBehavior(
+        source, 20, 0.5,
+        source_test_utils.ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT)
+    self.assertEquals(result1, result2)
+    self.assertEquals(100, result1[0] + result1[1])
+
+    result3 = source_test_utils.assertSplitAtFractionBehavior(
+        source, 30, 0.8,
+        source_test_utils.ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT)
+    result4 = source_test_utils.assertSplitAtFractionBehavior(
+        source, 50, 0.8,
+        source_test_utils.ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT)
+    self.assertEquals(result3, result4)
+    self.assertEquals(100, result3[0] + result4[1])
+
+    self.assertTrue(result1[0] < result3[0])
+    self.assertTrue(result1[1] > result3[1])
+
+  def test_split_at_fraction_fails(self):
+    data = self._create_data(100)
+    source = self._create_source(data)
+
+    result = source_test_utils.assertSplitAtFractionBehavior(
+        source, 90, 0.1, source_test_utils.ExpectedSplitOutcome.MUST_FAIL)
+    self.assertEquals(result[0], 100)
+    self.assertEquals(result[1], -1)
+
+    with self.assertRaises(ValueError):
+      source_test_utils.assertSplitAtFractionBehavior(
+          source, 10, 0.5, source_test_utils.ExpectedSplitOutcome.MUST_FAIL)
+
+  def test_split_at_fraction_binary(self):
+    data = self._create_data(100)
+    source = self._create_source(data)
+
+    stats = source_test_utils.SplitFractionStatistics([], [])
+    source_test_utils.assertSplitAtFractionBinary(
+        source, data, 10, 0.5, None, 0.8, None, stats)
+
+    # These lists should not be empty now.
+    self.assertTrue(stats.successful_fractions)
+    self.assertTrue(stats.non_trivial_fractions)
+
+  def test_split_at_fraction_exhaustive(self):
+    data = self._create_data(20)
+    source = self._create_source(data)
+
+    source_test_utils.assertSplitAtFractionExhaustive(source)

Reply via email to