Repository: beam Updated Branches: refs/heads/master de9e8528c -> 8f5f19d11
Updates Python SDK source API so that sources can report limited parallelism signals. With this update, the Python BoundedSource/RangeTracker API can report the consumed and remaining number of split points while performing source read operations, similar to Java SDK sources. These signals can be used by runner implementations, for example, to perform scaling decisions. Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/97a7ae44 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/97a7ae44 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/97a7ae44 Branch: refs/heads/master Commit: 97a7ae449e1ccf6b08a0ee0bc2fc0a1b49924f1f Parents: de9e852 Author: Chamikara Jayalath <chamik...@google.com> Authored: Wed Jan 4 19:10:09 2017 -0800 Committer: Robert Bradshaw <rober...@gmail.com> Committed: Thu Mar 2 17:19:08 2017 -0800 ---------------------------------------------------------------------- sdks/python/apache_beam/io/avroio.py | 33 +++- sdks/python/apache_beam/io/avroio_test.py | 31 ++++ sdks/python/apache_beam/io/iobase.py | 157 ++++++++++++++++++- sdks/python/apache_beam/io/range_trackers.py | 41 ++++- .../apache_beam/io/range_trackers_test.py | 52 ++++++ sdks/python/apache_beam/io/textio.py | 17 +- sdks/python/apache_beam/io/textio_test.py | 13 ++ .../runners/dataflow/native_io/iobase.py | 13 +- 8 files changed, 342 insertions(+), 15 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/io/avroio.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/avroio.py b/sdks/python/apache_beam/io/avroio.py index 5dab651..ab98530 100644 --- a/sdks/python/apache_beam/io/avroio.py +++ b/sdks/python/apache_beam/io/avroio.py @@ -28,6 +28,7 @@ from avro import schema import apache_beam as beam from apache_beam.io import filebasedsource from 
apache_beam.io import fileio +from apache_beam.io import iobase from apache_beam.io.iobase import Read from apache_beam.transforms import PTransform @@ -135,6 +136,7 @@ class _AvroUtils(object): ValueError: If the block cannot be read properly because the file doesn't match the specification. """ + offset = f.tell() decoder = avroio.BinaryDecoder(f) num_records = decoder.read_long() block_size = decoder.read_long() @@ -144,7 +146,8 @@ class _AvroUtils(object): raise ValueError('Unexpected sync marker (actual "%s" vs expected "%s"). ' 'Maybe the underlying avro file is corrupted?', sync_marker, expected_sync_marker) - return _AvroBlock(block_bytes, num_records, codec, schema) + size = f.tell() - offset + return _AvroBlock(block_bytes, num_records, codec, schema, offset, size) @staticmethod def advance_file_past_next_sync_marker(f, sync_marker): @@ -172,13 +175,22 @@ class _AvroUtils(object): class _AvroBlock(object): """Represents a block of an Avro file.""" - def __init__(self, block_bytes, num_records, codec, schema_string): + def __init__(self, block_bytes, num_records, codec, schema_string, + offset, size): # Decompress data early on (if needed) and thus decrease the number of # parallel copies of the data in memory at any given in time during # block iteration. self._decompressed_block_bytes = self._decompress_bytes(block_bytes, codec) self._num_records = num_records self._schema = schema.parse(schema_string) + self._offset = offset + self._size = size + + def size(self): + return self._size + + def offset(self): + return self._offset @staticmethod def _decompress_bytes(data, codec): @@ -232,12 +244,26 @@ class _AvroSource(filebasedsource.FileBasedSource): """ def read_records(self, file_name, range_tracker): + next_block_start = -1 + + def split_points_unclaimed(stop_position): + if next_block_start >= stop_position: + # Next block starts at or after the suggested stop position. 
Hence + # there will not be split points to be claimed for the range ending at + # suggested stop position. + return 0 + + return iobase.RangeTracker.SPLIT_POINTS_UNKNOWN + + range_tracker.set_split_points_unclaimed_callback(split_points_unclaimed) + start_offset = range_tracker.start_position() if start_offset is None: start_offset = 0 with self.open_file(file_name) as f: - codec, schema_string, sync_marker = _AvroUtils.read_meta_data_from_file(f) + codec, schema_string, sync_marker = _AvroUtils.read_meta_data_from_file( + f) # We have to start at current position if previous bundle ended at the # end of a sync marker. @@ -248,6 +274,7 @@ class _AvroSource(filebasedsource.FileBasedSource): while range_tracker.try_claim(f.tell()): block = _AvroUtils.read_block_from_file(f, codec, schema_string, sync_marker) + next_block_start = block.offset() + block.size() for record in block.records(): yield record http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/io/avroio_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py index 233ab69..8b14443 100644 --- a/sdks/python/apache_beam/io/avroio_test.py +++ b/sdks/python/apache_beam/io/avroio_test.py @@ -22,6 +22,7 @@ import tempfile import unittest import apache_beam as beam +from apache_beam.io import iobase from apache_beam.io import avroio from apache_beam.io import filebasedsource from apache_beam.io import source_test_utils @@ -248,6 +249,36 @@ class TestAvro(unittest.TestCase): expected_result = self.RECORDS * 2000 self._run_avro_test(file_name, 10000, True, expected_result) + def test_split_points(self): + file_name = self._write_data(count=12000) + source = AvroSource(file_name) + + splits = [ + split + for split in source.split(desired_bundle_size=float('inf')) + ] + assert len(splits) == 1 + + range_tracker = splits[0].source.get_range_tracker( + 
splits[0].start_position, splits[0].stop_position) + + split_points_report = [] + + for _ in splits[0].source.read(range_tracker): + split_points_report.append(range_tracker.split_points()) + + # There are a total of three blocks. Each block has more than 10 records. + + # When reading records of the first block, range_tracker.split_points() + # should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN) + self.assertEquals( + split_points_report[:10], + [(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10) + + # When reading records of last block, range_tracker.split_points() should + # return (2, 1) + self.assertEquals(split_points_report[-10:], [(2, 1)] * 10) + def test_read_without_splitting_compressed_deflate(self): file_name = self._write_data(codec='deflate') expected_result = self.RECORDS http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/io/iobase.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py index e139a24..bd40a3e 100644 --- a/sdks/python/apache_beam/io/iobase.py +++ b/sdks/python/apache_beam/io/iobase.py @@ -300,6 +300,8 @@ class RangeTracker(object): the current reader and by a reader of the task starting at 43). """ + SPLIT_POINTS_UNKNOWN = object() + def start_position(self): """Returns the starting position of the current range, inclusive.""" raise NotImplementedError(type(self)) @@ -317,8 +319,8 @@ class RangeTracker(object): ** Thread safety ** - This method along with several other methods of this class may be invoked by - multiple threads, hence must be made thread-safe, e.g. by using a single + Methods of the class ``RangeTracker`` including this method may get invoked + by different threads, hence must be made thread-safe, e.g. by using a single lock object. 
Args: @@ -352,8 +354,8 @@ class RangeTracker(object): ** Thread safety ** - This method along with several other methods of this class may be invoked by - multiple threads, hence must be made thread-safe, e.g. by using a single + Methods of the class ``RangeTracker`` including this method may get invoked + by different threads, hence must be made thread-safe, e.g. by using a single lock object. Args: @@ -387,8 +389,8 @@ class RangeTracker(object): ** Thread safety ** - This method along with several other methods of this class may be invoked by - multiple threads, hence must be made thread-safe, e.g. by using a single + Methods of the class ``RangeTracker`` including this method may get invoked + by different threads, hence must be made thread-safe, e.g. by using a single lock object. Args: @@ -405,8 +407,8 @@ class RangeTracker(object): ** Thread safety ** - This method along with several other methods of this class may be invoked by - multiple threads, hence must be made thread-safe, e.g. by using a single + Methods of the class ``RangeTracker`` including this method may get invoked + by different threads, hence must be made thread-safe, e.g. by using a single lock object. Returns: @@ -416,6 +418,145 @@ class RangeTracker(object): """ raise NotImplementedError + def split_points(self): + """Gives the number of split points consumed and remaining. + + For a ``RangeTracker`` used by a ``BoundedSource`` (within a + ``BoundedSource.read()`` invocation) this method produces a 2-tuple that + gives the number of split points consumed by the ``BoundedSource`` and the + number of split points remaining within the range of the ``RangeTracker`` + that has not been consumed by the ``BoundedSource``. 
+ + More specifically, given that the position of the current record being read + by ``BoundedSource`` is current_position this method produces a tuple that + consists of + (1) number of split points in the range [self.start_position(), + current_position) without including the split point that is currently being + consumed. This represents the total amount of parallelism in the consumed + part of the source. + (2) number of split points within the range + [current_position, self.stop_position()) including the split point that is + currently being consumed. This represents the total amount of parallelism in + the unconsumed part of the source. + + Methods of the class ``RangeTracker`` including this method may get invoked + by different threads, hence must be made thread-safe, e.g. by using a single + lock object. + + ** General information about consumed and remaining number of split + points returned by this method. ** + + * Before a source read (``BoundedSource.read()`` invocation) claims the + first split point, number of consumed split points is 0. This condition + holds independent of whether the input is "splittable". A splittable + source is a source that has more than one split point. + * Any source read that has only claimed one split point has 0 consumed + split points since the first split point is the current split point and + is still being processed. This condition holds independent of whether + the input is splittable. + * For an empty source read which never invokes + ``RangeTracker.try_claim()``, the consumed number of split points is 0. + This condition holds independent of whether the input is splittable. + * For a source read which has invoked ``RangeTracker.try_claim()`` n + times, the consumed number of split points is n -1. + * If a ``BoundedSource`` sets a callback through function + ``set_split_points_unclaimed_callback()``, ``RangeTracker`` can use that + callback when determining remaining number of split points. 
+ * Remaining split points should include the split point that is currently + being consumed by the source read. Hence if the above callback returns + an integer value n, remaining number of split points should be (n + 1). + * After last split point is claimed remaining split points becomes 1, + because this unfinished read itself represents an unfinished split + point. + * After all records of the source have been consumed, remaining number of + split points becomes 0 and consumed number of split points becomes equal + to the total number of split points within the range being read by the + source. This method does not address this condition and will continue to + report number of consumed split points as + ("total number of split points" - 1) and number of remaining split + points as 1. A runner that performs the reading of the source can + detect when all records have been consumed and adjust remaining and + consumed number of split points accordingly. + + ** Examples ** + + (1) A "perfectly splittable" input which can be read in parallel down to the + individual records. + + Consider a perfectly splittable input that consists of 50 split points. + + * Before a source read (``BoundedSource.read()`` invocation) claims the + first split point, number of consumed split points is 0 and number of + remaining split points is 50. + * After claiming first split point, consumed number of split points is 0 + and remaining number of split points is 50. + * After claiming split point #30, consumed number of split points is 29 + and remaining number of split points is 21. + * After claiming all 50 split points, consumed number of split points is + 49 and remaining number of split points is 1. + + (2) a "block-compressed" file format such as ``avroio``, in which a block of + records has to be read as a whole, but different blocks can be read in + parallel. + + Consider a block compressed input that consists of 5 blocks. 
+ + * Before a source read (``BoundedSource.read()`` invocation) claims the + first split point (first block), number of consumed split points is 0 and + number of remaining split points is 5. + * After claiming first split point, consumed number of split points is 0 + and remaining number of split points is 5. + * After claiming split point #3, consumed number of split points is 2 + and remaining number of split points is 3. + * After claiming all 5 split points, consumed number of split points is + 4 and remaining number of split points is 1. + + (3) an "unsplittable" input such as a cursor in a database or a gzip + compressed file. + + Such an input is considered to have only a single split point. Number of + consumed split points is always 0 and number of remaining split points + is always 1. + + By default ``RangeTracker`` returns ``RangeTracker.SPLIT_POINTS_UNKNOWN`` for + both consumed and remaining number of split points, which indicates that the + number of split points consumed and remaining is unknown. + + Returns: + A pair that gives consumed and remaining number of split points. Consumed + number of split points should be an integer larger than or equal to zero + or ``RangeTracker.SPLIT_POINTS_UNKNOWN``. Remaining number of split points + should be an integer larger than zero or + ``RangeTracker.SPLIT_POINTS_UNKNOWN``. + """ + return (RangeTracker.SPLIT_POINTS_UNKNOWN, + RangeTracker.SPLIT_POINTS_UNKNOWN) + + def set_split_points_unclaimed_callback(self, callback): + """Sets a callback for determining the unclaimed number of split points. + + By invoking this function, a ``BoundedSource`` can set a callback function + that may get invoked by the ``RangeTracker`` to determine the number of + unclaimed split points. A split point is unclaimed if + ``RangeTracker.try_claim()`` method has not been successfully invoked for + that particular split point. The callback function accepts a single + parameter, a stop position for the BoundedSource (stop_position). 
If the + record currently being consumed by the ``BoundedSource`` is at position + current_position, callback should return the number of split points within + the range (current_position, stop_position). Note that, this should not + include the split point that is currently being consumed by the source. + + This function must be implemented by subclasses before being used. + + Args: + callback: a function that takes a single parameter, a stop position, + and returns unclaimed number of split points for the source read + operation that is calling this function. Value returned from + callback should be either an integer larger than or equal to + zero or ``RangeTracker.SPLIT_POINTS_UNKNOWN``. + """ + raise NotImplementedError + class Sink(HasDisplayData): """A resource that can be written to using the ``df.io.Write`` transform. http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/io/range_trackers.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py index 4c8f7eb..059b6ca 100644 --- a/sdks/python/apache_beam/io/range_trackers.py +++ b/sdks/python/apache_beam/io/range_trackers.py @@ -55,6 +55,9 @@ class OffsetRangeTracker(iobase.RangeTracker): self._offset_of_last_split_point = -1 self._lock = threading.Lock() + self._split_points_seen = 0 + self._split_points_unclaimed_callback = None + def start_position(self): return self._start_offset @@ -106,6 +109,7 @@ class OffsetRangeTracker(iobase.RangeTracker): return False self._offset_of_last_split_point = record_start self._last_record_start = record_start + self._split_points_seen += 1 return True def set_current_position(self, record_start): @@ -167,6 +171,24 @@ class OffsetRangeTracker(iobase.RangeTracker): return int(math.ceil(self.start_position() + fraction * ( self.stop_position() - self.start_position()))) + def split_points(self): + with self._lock: + 
split_points_consumed = ( + 0 if self._split_points_seen == 0 else self._split_points_seen - 1) + split_points_unclaimed = ( + self._split_points_unclaimed_callback(self.stop_position()) + if self._split_points_unclaimed_callback + else iobase.RangeTracker.SPLIT_POINTS_UNKNOWN) + split_points_remaining = ( + iobase.RangeTracker.SPLIT_POINTS_UNKNOWN if + split_points_unclaimed == iobase.RangeTracker.SPLIT_POINTS_UNKNOWN + else (split_points_unclaimed + 1)) + + return (split_points_consumed, split_points_remaining) + + def set_split_points_unclaimed_callback(self, callback): + self._split_points_unclaimed_callback = callback + class GroupedShuffleRangeTracker(iobase.RangeTracker): """A 'RangeTracker' for positions used by'GroupedShuffleReader'. @@ -184,6 +206,7 @@ class GroupedShuffleRangeTracker(iobase.RangeTracker): self._decoded_stop_pos = decoded_stop_pos self._decoded_last_group_start = None self._last_group_was_at_a_split_point = False + self._split_points_seen = 0 self._lock = threading.Lock() def start_position(self): @@ -240,6 +263,7 @@ class GroupedShuffleRangeTracker(iobase.RangeTracker): self._decoded_last_group_start = decoded_group_start self._last_group_was_at_a_split_point = True + self._split_points_seen += 1 return True def set_current_position(self, decoded_group_start): @@ -285,6 +309,14 @@ class GroupedShuffleRangeTracker(iobase.RangeTracker): ' consumed due to positions being opaque strings' ' that are interpreted by the service') + def split_points(self): + with self._lock: + splits_points_consumed = ( + 0 if self._split_points_seen <= 1 else (self._split_points_seen - 1)) + + return (splits_points_consumed, + iobase.RangeTracker.SPLIT_POINTS_UNKNOWN) + class OrderedPositionRangeTracker(iobase.RangeTracker): """ @@ -380,7 +412,7 @@ class UnsplittableRangeTracker(iobase.RangeTracker): range_tracker: a ``RangeTracker`` to which all method calls expect calls to ``try_split()`` will be delegated. 
""" - assert range_tracker + assert isinstance(range_tracker, iobase.RangeTracker) self._range_tracker = range_tracker def start_position(self): @@ -404,6 +436,13 @@ class UnsplittableRangeTracker(iobase.RangeTracker): def fraction_consumed(self): return self._range_tracker.fraction_consumed() + def split_points(self): + # An unsplittable range only contains a single split point. + return (0, 1) + + def set_split_points_unclaimed_callback(self, callback): + self._range_tracker.set_split_points_unclaimed_callback(callback) + class LexicographicKeyRangeTracker(OrderedPositionRangeTracker): """ http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/io/range_trackers_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/range_trackers_test.py b/sdks/python/apache_beam/io/range_trackers_test.py index b80d1f3..edb6386 100644 --- a/sdks/python/apache_beam/io/range_trackers_test.py +++ b/sdks/python/apache_beam/io/range_trackers_test.py @@ -24,6 +24,7 @@ import math import unittest +from apache_beam.io import iobase from apache_beam.io import range_trackers @@ -158,6 +159,35 @@ class OffsetRangeTrackerTest(unittest.TestCase): with self.assertRaises(Exception): tracker.try_claim(110) + def test_try_split_points(self): + tracker = range_trackers.OffsetRangeTracker(100, 400) + + def dummy_callback(stop_position): + return int(stop_position / 5) + + tracker.set_split_points_unclaimed_callback(dummy_callback) + + self.assertEqual(tracker.split_points(), + (0, 81)) + self.assertTrue(tracker.try_claim(120)) + self.assertEqual(tracker.split_points(), + (0, 81)) + self.assertTrue(tracker.try_claim(140)) + self.assertEqual(tracker.split_points(), + (1, 81)) + tracker.try_split(200) + self.assertEqual(tracker.split_points(), + (1, 41)) + self.assertTrue(tracker.try_claim(150)) + self.assertEqual(tracker.split_points(), + (2, 41)) + self.assertTrue(tracker.try_claim(180)) + 
self.assertEqual(tracker.split_points(), + (3, 41)) + self.assertFalse(tracker.try_claim(210)) + self.assertEqual(tracker.split_points(), + (3, 41)) + class GroupedShuffleRangeTrackerTest(unittest.TestCase): @@ -319,6 +349,28 @@ class GroupedShuffleRangeTrackerTest(unittest.TestCase): self.assertFalse(tracker.try_claim( self.bytes_to_position([3, 2, 1]))) + def test_split_points(self): + tracker = range_trackers.GroupedShuffleRangeTracker( + self.bytes_to_position([1, 0, 0]), + self.bytes_to_position([5, 0, 0])) + self.assertEqual(tracker.split_points(), + (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)) + self.assertTrue(tracker.try_claim(self.bytes_to_position([1, 2, 3]))) + self.assertEqual(tracker.split_points(), + (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)) + self.assertTrue(tracker.try_claim(self.bytes_to_position([1, 2, 5]))) + self.assertEqual(tracker.split_points(), + (1, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)) + self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 6, 8]))) + self.assertEqual(tracker.split_points(), + (2, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)) + self.assertTrue(tracker.try_claim(self.bytes_to_position([4, 255, 255]))) + self.assertEqual(tracker.split_points(), + (3, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)) + self.assertFalse(tracker.try_claim(self.bytes_to_position([5, 1, 0]))) + self.assertEqual(tracker.split_points(), + (3, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)) + class OrderedPositionRangeTrackerTest(unittest.TestCase): http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/io/textio.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py index 19980cb..5bb1a9d 100644 --- a/sdks/python/apache_beam/io/textio.py +++ b/sdks/python/apache_beam/io/textio.py @@ -24,6 +24,7 @@ import logging from apache_beam import coders from apache_beam.io import filebasedsource from apache_beam.io import 
fileio +from apache_beam.io import iobase from apache_beam.io.iobase import Read from apache_beam.io.iobase import Write from apache_beam.transforms import PTransform @@ -116,6 +117,14 @@ class _TextSource(filebasedsource.FileBasedSource): start_offset = range_tracker.start_position() read_buffer = _TextSource.ReadBuffer('', 0) + next_record_start_position = -1 + + def split_points_unclaimed(stop_position): + return (0 if stop_position <= next_record_start_position + else iobase.RangeTracker.SPLIT_POINTS_UNKNOWN) + + range_tracker.set_split_points_unclaimed_callback(split_points_unclaimed) + with self.open_file(file_name) as file_to_read: position_after_skipping_header_lines = self._skip_lines( file_to_read, read_buffer, @@ -153,10 +162,14 @@ class _TextSource(filebasedsource.FileBasedSource): if len(record) == 0 and num_bytes_to_next_record < 0: break + # Record separator must be larger than zero bytes. + assert num_bytes_to_next_record != 0 + if num_bytes_to_next_record > 0: + next_record_start_position += num_bytes_to_next_record + yield self._coder.decode(record) if num_bytes_to_next_record < 0: break - next_record_start_position += num_bytes_to_next_record def _find_separator_bounds(self, file_to_read, read_buffer): # Determines the start and end positions within 'read_buffer.data' of the @@ -220,7 +233,7 @@ class _TextSource(filebasedsource.FileBasedSource): def _read_record(self, file_to_read, read_buffer): # Returns a tuple containing the current_record and number of bytes to the - # next record starting from 'self._next_position_in_buffer'. If EOF is + # next record starting from 'read_buffer.position'. If EOF is # reached, returns a tuple containing the current record and -1. 
if read_buffer.position > self._buffer_size: http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/io/textio_test.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py index f3ce843..04cf44c 100644 --- a/sdks/python/apache_beam/io/textio_test.py +++ b/sdks/python/apache_beam/io/textio_test.py @@ -27,6 +27,7 @@ import tempfile import unittest import apache_beam as beam +from apache_beam.io import iobase import apache_beam.io.source_test_utils as source_test_utils # Importing following private classes for testing. @@ -265,13 +266,25 @@ class TextSourceTest(_TestCaseWithTempDirCleanUp): splits = [split for split in source.split(desired_bundle_size=100000)] assert len(splits) == 1 fraction_consumed_report = [] + split_points_report = [] range_tracker = splits[0].source.get_range_tracker( splits[0].start_position, splits[0].stop_position) for _ in splits[0].source.read(range_tracker): fraction_consumed_report.append(range_tracker.fraction_consumed()) + split_points_report.append(range_tracker.split_points()) self.assertEqual( [float(i) / 10 for i in range(0, 10)], fraction_consumed_report) + expected_split_points_report = [ + ((i - 1), iobase.RangeTracker.SPLIT_POINTS_UNKNOWN) + for i in range(1, 10)] + + # At last split point, the remaining split points callback returns 1 since + # the expected position of next record becomes equal to the stop position. 
+ expected_split_points_report.append((9, 1)) + + self.assertEqual( + expected_split_points_report, split_points_report) def test_read_reentrant_without_splitting(self): file_name, expected_data = write_data(10) http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py ---------------------------------------------------------------------- diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py b/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py index 529d414..26ebe08 100644 --- a/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py +++ b/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py @@ -136,7 +136,8 @@ class NativeSourceReader(object): class ReaderProgress(object): """A representation of how far a NativeSourceReader has read.""" - def __init__(self, position=None, percent_complete=None, remaining_time=None): + def __init__(self, position=None, percent_complete=None, remaining_time=None, + consumed_split_points=None, remaining_split_points=None): self._position = position @@ -149,6 +150,8 @@ class ReaderProgress(object): self._percent_complete = percent_complete self._remaining_time = remaining_time + self._consumed_split_points = consumed_split_points + self._remaining_split_points = remaining_split_points @property def position(self): @@ -172,6 +175,14 @@ class ReaderProgress(object): """Returns progress, represented as an estimated time remaining.""" return self._remaining_time + @property + def consumed_split_points(self): + return self._consumed_split_points + + @property + def remaining_split_points(self): + return self._remaining_split_points + class ReaderPosition(object): """A representation of position in an iteration of a 'NativeSourceReader'."""