Repository: beam
Updated Branches:
  refs/heads/master de9e8528c -> 8f5f19d11


Updates Python SDK source API so that sources can report limited parallelism
signals.

With this update, the Python BoundedSource/RangeTracker API can report the
consumed and remaining number of split points while performing a source read
operation, similar to Java SDK sources.

These signals can be used by runner implementations, for example, to make
scaling decisions.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/97a7ae44
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/97a7ae44
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/97a7ae44

Branch: refs/heads/master
Commit: 97a7ae449e1ccf6b08a0ee0bc2fc0a1b49924f1f
Parents: de9e852
Author: Chamikara Jayalath <chamik...@google.com>
Authored: Wed Jan 4 19:10:09 2017 -0800
Committer: Robert Bradshaw <rober...@gmail.com>
Committed: Thu Mar 2 17:19:08 2017 -0800

----------------------------------------------------------------------
 sdks/python/apache_beam/io/avroio.py            |  33 +++-
 sdks/python/apache_beam/io/avroio_test.py       |  31 ++++
 sdks/python/apache_beam/io/iobase.py            | 157 ++++++++++++++++++-
 sdks/python/apache_beam/io/range_trackers.py    |  41 ++++-
 .../apache_beam/io/range_trackers_test.py       |  52 ++++++
 sdks/python/apache_beam/io/textio.py            |  17 +-
 sdks/python/apache_beam/io/textio_test.py       |  13 ++
 .../runners/dataflow/native_io/iobase.py        |  13 +-
 8 files changed, 342 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/io/avroio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio.py b/sdks/python/apache_beam/io/avroio.py
index 5dab651..ab98530 100644
--- a/sdks/python/apache_beam/io/avroio.py
+++ b/sdks/python/apache_beam/io/avroio.py
@@ -28,6 +28,7 @@ from avro import schema
 import apache_beam as beam
 from apache_beam.io import filebasedsource
 from apache_beam.io import fileio
+from apache_beam.io import iobase
 from apache_beam.io.iobase import Read
 from apache_beam.transforms import PTransform
 
@@ -135,6 +136,7 @@ class _AvroUtils(object):
       ValueError: If the block cannot be read properly because the file doesn't
         match the specification.
     """
+    offset = f.tell()
     decoder = avroio.BinaryDecoder(f)
     num_records = decoder.read_long()
     block_size = decoder.read_long()
@@ -144,7 +146,8 @@ class _AvroUtils(object):
       raise ValueError('Unexpected sync marker (actual "%s" vs expected "%s"). '
                        'Maybe the underlying avro file is corrupted?',
                        sync_marker, expected_sync_marker)
-    return _AvroBlock(block_bytes, num_records, codec, schema)
+    size = f.tell() - offset
+    return _AvroBlock(block_bytes, num_records, codec, schema, offset, size)
 
   @staticmethod
   def advance_file_past_next_sync_marker(f, sync_marker):
@@ -172,13 +175,22 @@ class _AvroUtils(object):
 class _AvroBlock(object):
   """Represents a block of an Avro file."""
 
-  def __init__(self, block_bytes, num_records, codec, schema_string):
+  def __init__(self, block_bytes, num_records, codec, schema_string,
+               offset, size):
     # Decompress data early on (if needed) and thus decrease the number of
     # parallel copies of the data in memory at any given in time during
     # block iteration.
     self._decompressed_block_bytes = self._decompress_bytes(block_bytes, codec)
     self._num_records = num_records
     self._schema = schema.parse(schema_string)
+    self._offset = offset
+    self._size = size
+
+  def size(self):
+    return self._size
+
+  def offset(self):
+    return self._offset
 
   @staticmethod
   def _decompress_bytes(data, codec):
@@ -232,12 +244,26 @@ class _AvroSource(filebasedsource.FileBasedSource):
   """
 
   def read_records(self, file_name, range_tracker):
+    next_block_start = -1
+
+    def split_points_unclaimed(stop_position):
+      if next_block_start >= stop_position:
+        # Next block starts at or after the suggested stop position. Hence
+        # there will not be split points to be claimed for the range ending at
+        # suggested stop position.
+        return 0
+
+      return iobase.RangeTracker.SPLIT_POINTS_UNKNOWN
+
+    range_tracker.set_split_points_unclaimed_callback(split_points_unclaimed)
+
     start_offset = range_tracker.start_position()
     if start_offset is None:
       start_offset = 0
 
     with self.open_file(file_name) as f:
-      codec, schema_string, sync_marker = _AvroUtils.read_meta_data_from_file(f)
+      codec, schema_string, sync_marker = _AvroUtils.read_meta_data_from_file(
+          f)
 
       # We have to start at current position if previous bundle ended at the
       # end of a sync marker.
@@ -248,6 +274,7 @@ class _AvroSource(filebasedsource.FileBasedSource):
       while range_tracker.try_claim(f.tell()):
         block = _AvroUtils.read_block_from_file(f, codec, schema_string,
                                                 sync_marker)
+        next_block_start = block.offset() + block.size()
         for record in block.records():
           yield record
 

http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/io/avroio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py
index 233ab69..8b14443 100644
--- a/sdks/python/apache_beam/io/avroio_test.py
+++ b/sdks/python/apache_beam/io/avroio_test.py
@@ -22,6 +22,7 @@ import tempfile
 import unittest
 
 import apache_beam as beam
+from apache_beam.io import iobase
 from apache_beam.io import avroio
 from apache_beam.io import filebasedsource
 from apache_beam.io import source_test_utils
@@ -248,6 +249,36 @@ class TestAvro(unittest.TestCase):
     expected_result = self.RECORDS * 2000
     self._run_avro_test(file_name, 10000, True, expected_result)
 
+  def test_split_points(self):
+    file_name = self._write_data(count=12000)
+    source = AvroSource(file_name)
+
+    splits = [
+        split
+        for split in source.split(desired_bundle_size=float('inf'))
+    ]
+    assert len(splits) == 1
+
+    range_tracker = splits[0].source.get_range_tracker(
+        splits[0].start_position, splits[0].stop_position)
+
+    split_points_report = []
+
+    for _ in splits[0].source.read(range_tracker):
+      split_points_report.append(range_tracker.split_points())
+
+    # There are a total of three blocks. Each block has more than 10 records.
+
+    # When reading records of the first block, range_tracker.split_points()
+    # should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
+    self.assertEquals(
+        split_points_report[:10],
+        [(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10)
+
+    # When reading records of last block, range_tracker.split_points() should
+    # return (2, 1)
+    self.assertEquals(split_points_report[-10:], [(2, 1)] * 10)
+
   def test_read_without_splitting_compressed_deflate(self):
     file_name = self._write_data(codec='deflate')
     expected_result = self.RECORDS

http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/io/iobase.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index e139a24..bd40a3e 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -300,6 +300,8 @@ class RangeTracker(object):
   the current reader and by a reader of the task starting at 43).
   """
 
+  SPLIT_POINTS_UNKNOWN = object()
+
   def start_position(self):
     """Returns the starting position of the current range, inclusive."""
     raise NotImplementedError(type(self))
@@ -317,8 +319,8 @@ class RangeTracker(object):
 
     ** Thread safety **
 
-    This method along with several other methods of this class may be invoked by
-    multiple threads, hence must be made thread-safe, e.g. by using a single
+    Methods of the class ``RangeTracker`` including this method may get invoked
+    by different threads, hence must be made thread-safe, e.g. by using a single
     lock object.
 
     Args:
@@ -352,8 +354,8 @@ class RangeTracker(object):
 
     ** Thread safety **
 
-    This method along with several other methods of this class may be invoked by
-    multiple threads, hence must be made thread-safe, e.g. by using a single
+    Methods of the class ``RangeTracker`` including this method may get invoked
+    by different threads, hence must be made thread-safe, e.g. by using a single
     lock object.
 
     Args:
@@ -387,8 +389,8 @@ class RangeTracker(object):
 
     ** Thread safety **
 
-    This method along with several other methods of this class may be invoked by
-    multiple threads, hence must be made thread-safe, e.g. by using a single
+    Methods of the class ``RangeTracker`` including this method may get invoked
+    by different threads, hence must be made thread-safe, e.g. by using a single
     lock object.
 
     Args:
@@ -405,8 +407,8 @@ class RangeTracker(object):
 
     ** Thread safety **
 
-    This method along with several other methods of this class may be invoked by
-    multiple threads, hence must be made thread-safe, e.g. by using a single
+    Methods of the class ``RangeTracker`` including this method may get invoked
+    by different threads, hence must be made thread-safe, e.g. by using a single
     lock object.
 
     Returns:
@@ -416,6 +418,145 @@ class RangeTracker(object):
     """
     raise NotImplementedError
 
+  def split_points(self):
+    """Gives the number of split points consumed and remaining.
+
+    For a ``RangeTracker`` used by a ``BoundedSource`` (within a
+    ``BoundedSource.read()`` invocation) this method produces a 2-tuple that
+    gives the number of split points consumed by the ``BoundedSource`` and the
+    number of split points remaining within the range of the ``RangeTracker``
+    that has not been consumed by the ``BoundedSource``.
+
+    More specifically, given that the position of the current record being read
+    by ``BoundedSource`` is current_position this method produces a tuple that
+    consists of
+    (1) number of split points in the range [self.start_position(),
+    current_position) without including the split point that is currently being
+    consumed. This represents the total amount of parallelism in the consumed
+    part of the source.
+    (2) number of split points within the range
+    [current_position, self.stop_position()) including the split point that is
+    currently being consumed. This represents the total amount of parallelism in
+    the unconsumed part of the source.
+
+    Methods of the class ``RangeTracker`` including this method may get invoked
+    by different threads, hence must be made thread-safe, e.g. by using a single
+    lock object.
+
+    ** General information about consumed and remaining number of split
+       points returned by this method. **
+
+      * Before a source read (``BoundedSource.read()`` invocation) claims the
+        first split point, number of consumed split points is 0. This condition
+        holds independent of whether the input is "splittable". A splittable
+        source is a source that has more than one split point.
+      * Any source read that has only claimed one split point has 0 consumed
+        split points since the first split point is the current split point and
+        is still being processed. This condition holds independent of whether
+        the input is splittable.
+      * For an empty source read which never invokes
+        ``RangeTracker.try_claim()``, the consumed number of split points is 0.
+        This condition holds independent of whether the input is splittable.
+      * For a source read which has invoked ``RangeTracker.try_claim()`` n
+        times, the consumed number of split points is n - 1.
+      * If a ``BoundedSource`` sets a callback through function
+        ``set_split_points_unclaimed_callback()``, ``RangeTracker`` can use that
+        callback when determining remaining number of split points.
+      * Remaining split points should include the split point that is currently
+        being consumed by the source read. Hence if the above callback returns
+        an integer value n, remaining number of split points should be (n + 1).
+      * After the last split point is claimed, the number of remaining split
+        points becomes 1, because this unfinished read itself represents an
+        unfinished split point.
+      * After all records of the source have been consumed, remaining number of
+        split points becomes 0 and consumed number of split points becomes equal
+        to the total number of split points within the range being read by the
+        source. This method does not address this condition and will continue to
+        report number of consumed split points as
+        ("total number of split points" - 1) and number of remaining split
+        points as 1. A runner that performs the reading of the source can
+        detect when all records have been consumed and adjust remaining and
+        consumed number of split points accordingly.
+
+    ** Examples **
+
+    (1) A "perfectly splittable" input which can be read in parallel down to the
+        individual records.
+
+        Consider a perfectly splittable input that consists of 50 split points.
+
+      * Before a source read (``BoundedSource.read()`` invocation) claims the
+        first split point, number of consumed split points is 0 and number of
+        remaining split points is 50.
+      * After claiming first split point, consumed number of split points is 0
+        and remaining number of split points is 50.
+      * After claiming split point #30, consumed number of split points is 29
+        and remaining number of split points is 21.
+      * After claiming all 50 split points, consumed number of split points is
+        49 and remaining number of split points is 1.
+
+    (2) A "block-compressed" file format such as ``avroio``, in which a block of
+        records has to be read as a whole, but different blocks can be read in
+        parallel.
+
+        Consider a block compressed input that consists of 5 blocks.
+
+      * Before a source read (``BoundedSource.read()`` invocation) claims the
+        first split point (first block), number of consumed split points is 0
+        and number of remaining split points is 5.
+      * After claiming first split point, consumed number of split points is 0
+        and remaining number of split points is 5.
+      * After claiming split point #3, consumed number of split points is 2
+        and remaining number of split points is 3.
+      * After claiming all 5 split points, consumed number of split points is
+        4 and remaining number of split points is 1.
+
+    (3) An "unsplittable" input such as a cursor in a database or a gzip
+        compressed file.
+
+        Such an input is considered to have only a single split point. Number of
+        consumed split points is always 0 and number of remaining split points
+        is always 1.
+
+    By default ``RangeTracker`` returns ``RangeTracker.SPLIT_POINTS_UNKNOWN``
+    for both consumed and remaining number of split points, which indicates
+    that the number of split points consumed and remaining is unknown.
+
+    Returns:
+      A pair that gives consumed and remaining number of split points. Consumed
+      number of split points should be an integer larger than or equal to zero
+      or ``RangeTracker.SPLIT_POINTS_UNKNOWN``. Remaining number of split points
+      should be an integer larger than zero or
+      ``RangeTracker.SPLIT_POINTS_UNKNOWN``.
+    """
+    return (RangeTracker.SPLIT_POINTS_UNKNOWN,
+            RangeTracker.SPLIT_POINTS_UNKNOWN)
+
+  def set_split_points_unclaimed_callback(self, callback):
+    """Sets a callback for determining the unclaimed number of split points.
+
+    By invoking this function, a ``BoundedSource`` can set a callback function
+    that may get invoked by the ``RangeTracker`` to determine the number of
+    unclaimed split points. A split point is unclaimed if
+    ``RangeTracker.try_claim()`` method has not been successfully invoked for
+    that particular split point. The callback function accepts a single
+    parameter, a stop position for the BoundedSource (stop_position). If the
+    record currently being consumed by the ``BoundedSource`` is at position
+    current_position, callback should return the number of split points within
+    the range (current_position, stop_position). Note that, this should not
+    include the split point that is currently being consumed by the source.
+
+    This function must be implemented by subclasses before being used.
+
+    Args:
+      callback: a function that takes a single parameter, a stop position,
+                and returns unclaimed number of split points for the source read
+                operation that is calling this function. Value returned from
+                callback should be either an integer larger than or equal to
+                zero or ``RangeTracker.SPLIT_POINTS_UNKNOWN``.
+    """
+    raise NotImplementedError
+
 
 class Sink(HasDisplayData):
   """A resource that can be written to using the ``df.io.Write`` transform.

http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/io/range_trackers.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/range_trackers.py b/sdks/python/apache_beam/io/range_trackers.py
index 4c8f7eb..059b6ca 100644
--- a/sdks/python/apache_beam/io/range_trackers.py
+++ b/sdks/python/apache_beam/io/range_trackers.py
@@ -55,6 +55,9 @@ class OffsetRangeTracker(iobase.RangeTracker):
     self._offset_of_last_split_point = -1
     self._lock = threading.Lock()
 
+    self._split_points_seen = 0
+    self._split_points_unclaimed_callback = None
+
   def start_position(self):
     return self._start_offset
 
@@ -106,6 +109,7 @@ class OffsetRangeTracker(iobase.RangeTracker):
         return False
       self._offset_of_last_split_point = record_start
       self._last_record_start = record_start
+      self._split_points_seen += 1
       return True
 
   def set_current_position(self, record_start):
@@ -167,6 +171,24 @@ class OffsetRangeTracker(iobase.RangeTracker):
     return int(math.ceil(self.start_position() + fraction * (
         self.stop_position() - self.start_position())))
 
+  def split_points(self):
+    with self._lock:
+      split_points_consumed = (
+          0 if self._split_points_seen == 0 else self._split_points_seen - 1)
+      split_points_unclaimed = (
+          self._split_points_unclaimed_callback(self.stop_position())
+          if self._split_points_unclaimed_callback
+          else iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
+      split_points_remaining = (
+          iobase.RangeTracker.SPLIT_POINTS_UNKNOWN if
+          split_points_unclaimed == iobase.RangeTracker.SPLIT_POINTS_UNKNOWN
+          else (split_points_unclaimed + 1))
+
+      return (split_points_consumed, split_points_remaining)
+
+  def set_split_points_unclaimed_callback(self, callback):
+    self._split_points_unclaimed_callback = callback
+
 
 class GroupedShuffleRangeTracker(iobase.RangeTracker):
   """A 'RangeTracker' for positions used by'GroupedShuffleReader'.
@@ -184,6 +206,7 @@ class GroupedShuffleRangeTracker(iobase.RangeTracker):
     self._decoded_stop_pos = decoded_stop_pos
     self._decoded_last_group_start = None
     self._last_group_was_at_a_split_point = False
+    self._split_points_seen = 0
     self._lock = threading.Lock()
 
   def start_position(self):
@@ -240,6 +263,7 @@ class GroupedShuffleRangeTracker(iobase.RangeTracker):
 
       self._decoded_last_group_start = decoded_group_start
       self._last_group_was_at_a_split_point = True
+      self._split_points_seen += 1
       return True
 
   def set_current_position(self, decoded_group_start):
@@ -285,6 +309,14 @@ class GroupedShuffleRangeTracker(iobase.RangeTracker):
                        ' consumed due to positions being opaque strings'
                        ' that are interpreted by the service')
 
+  def split_points(self):
+    with self._lock:
+      splits_points_consumed = (
+          0 if self._split_points_seen <= 1 else (self._split_points_seen - 1))
+
+      return (splits_points_consumed,
+              iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
+
 
 class OrderedPositionRangeTracker(iobase.RangeTracker):
   """
@@ -380,7 +412,7 @@ class UnsplittableRangeTracker(iobase.RangeTracker):
       range_tracker: a ``RangeTracker`` to which all method calls expect calls
       to ``try_split()`` will be delegated.
     """
-    assert range_tracker
+    assert isinstance(range_tracker, iobase.RangeTracker)
     self._range_tracker = range_tracker
 
   def start_position(self):
@@ -404,6 +436,13 @@ class UnsplittableRangeTracker(iobase.RangeTracker):
   def fraction_consumed(self):
     return self._range_tracker.fraction_consumed()
 
+  def split_points(self):
+    # An unsplittable range only contains a single split point.
+    return (0, 1)
+
+  def set_split_points_unclaimed_callback(self, callback):
+    self._range_tracker.set_split_points_unclaimed_callback(callback)
+
 
 class LexicographicKeyRangeTracker(OrderedPositionRangeTracker):
   """

http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/io/range_trackers_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/range_trackers_test.py b/sdks/python/apache_beam/io/range_trackers_test.py
index b80d1f3..edb6386 100644
--- a/sdks/python/apache_beam/io/range_trackers_test.py
+++ b/sdks/python/apache_beam/io/range_trackers_test.py
@@ -24,6 +24,7 @@ import math
 import unittest
 
 
+from apache_beam.io import iobase
 from apache_beam.io import range_trackers
 
 
@@ -158,6 +159,35 @@ class OffsetRangeTrackerTest(unittest.TestCase):
     with self.assertRaises(Exception):
       tracker.try_claim(110)
 
+  def test_try_split_points(self):
+    tracker = range_trackers.OffsetRangeTracker(100, 400)
+
+    def dummy_callback(stop_position):
+      return int(stop_position / 5)
+
+    tracker.set_split_points_unclaimed_callback(dummy_callback)
+
+    self.assertEqual(tracker.split_points(),
+                     (0, 81))
+    self.assertTrue(tracker.try_claim(120))
+    self.assertEqual(tracker.split_points(),
+                     (0, 81))
+    self.assertTrue(tracker.try_claim(140))
+    self.assertEqual(tracker.split_points(),
+                     (1, 81))
+    tracker.try_split(200)
+    self.assertEqual(tracker.split_points(),
+                     (1, 41))
+    self.assertTrue(tracker.try_claim(150))
+    self.assertEqual(tracker.split_points(),
+                     (2, 41))
+    self.assertTrue(tracker.try_claim(180))
+    self.assertEqual(tracker.split_points(),
+                     (3, 41))
+    self.assertFalse(tracker.try_claim(210))
+    self.assertEqual(tracker.split_points(),
+                     (3, 41))
+
 
 class GroupedShuffleRangeTrackerTest(unittest.TestCase):
 
@@ -319,6 +349,28 @@ class GroupedShuffleRangeTrackerTest(unittest.TestCase):
     self.assertFalse(tracker.try_claim(
         self.bytes_to_position([3, 2, 1])))
 
+  def test_split_points(self):
+    tracker = range_trackers.GroupedShuffleRangeTracker(
+        self.bytes_to_position([1, 0, 0]),
+        self.bytes_to_position([5, 0, 0]))
+    self.assertEqual(tracker.split_points(),
+                     (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+    self.assertTrue(tracker.try_claim(self.bytes_to_position([1, 2, 3])))
+    self.assertEqual(tracker.split_points(),
+                     (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+    self.assertTrue(tracker.try_claim(self.bytes_to_position([1, 2, 5])))
+    self.assertEqual(tracker.split_points(),
+                     (1, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+    self.assertTrue(tracker.try_claim(self.bytes_to_position([3, 6, 8])))
+    self.assertEqual(tracker.split_points(),
+                     (2, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+    self.assertTrue(tracker.try_claim(self.bytes_to_position([4, 255, 255])))
+    self.assertEqual(tracker.split_points(),
+                     (3, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+    self.assertFalse(tracker.try_claim(self.bytes_to_position([5, 1, 0])))
+    self.assertEqual(tracker.split_points(),
+                     (3, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN))
+
 
 class OrderedPositionRangeTrackerTest(unittest.TestCase):
 

http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/io/textio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py
index 19980cb..5bb1a9d 100644
--- a/sdks/python/apache_beam/io/textio.py
+++ b/sdks/python/apache_beam/io/textio.py
@@ -24,6 +24,7 @@ import logging
 from apache_beam import coders
 from apache_beam.io import filebasedsource
 from apache_beam.io import fileio
+from apache_beam.io import iobase
 from apache_beam.io.iobase import Read
 from apache_beam.io.iobase import Write
 from apache_beam.transforms import PTransform
@@ -116,6 +117,14 @@ class _TextSource(filebasedsource.FileBasedSource):
     start_offset = range_tracker.start_position()
     read_buffer = _TextSource.ReadBuffer('', 0)
 
+    next_record_start_position = -1
+
+    def split_points_unclaimed(stop_position):
+      return (0 if stop_position <= next_record_start_position
+              else iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
+
+    range_tracker.set_split_points_unclaimed_callback(split_points_unclaimed)
+
     with self.open_file(file_name) as file_to_read:
       position_after_skipping_header_lines = self._skip_lines(
           file_to_read, read_buffer,
@@ -153,10 +162,14 @@ class _TextSource(filebasedsource.FileBasedSource):
         if len(record) == 0 and num_bytes_to_next_record < 0:
           break
 
+        # Record separator must be larger than zero bytes.
+        assert num_bytes_to_next_record != 0
+        if num_bytes_to_next_record > 0:
+          next_record_start_position += num_bytes_to_next_record
+
         yield self._coder.decode(record)
         if num_bytes_to_next_record < 0:
           break
-        next_record_start_position += num_bytes_to_next_record
 
   def _find_separator_bounds(self, file_to_read, read_buffer):
     # Determines the start and end positions within 'read_buffer.data' of the
@@ -220,7 +233,7 @@ class _TextSource(filebasedsource.FileBasedSource):
 
   def _read_record(self, file_to_read, read_buffer):
     # Returns a tuple containing the current_record and number of bytes to the
-    # next record starting from 'self._next_position_in_buffer'. If EOF is
+    # next record starting from 'read_buffer.position'. If EOF is
     # reached, returns a tuple containing the current record and -1.
 
     if read_buffer.position > self._buffer_size:

http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/io/textio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py
index f3ce843..04cf44c 100644
--- a/sdks/python/apache_beam/io/textio_test.py
+++ b/sdks/python/apache_beam/io/textio_test.py
@@ -27,6 +27,7 @@ import tempfile
 import unittest
 
 import apache_beam as beam
+from apache_beam.io import iobase
 import apache_beam.io.source_test_utils as source_test_utils
 
 # Importing following private classes for testing.
@@ -265,13 +266,25 @@ class TextSourceTest(_TestCaseWithTempDirCleanUp):
     splits = [split for split in source.split(desired_bundle_size=100000)]
     assert len(splits) == 1
     fraction_consumed_report = []
+    split_points_report = []
     range_tracker = splits[0].source.get_range_tracker(
         splits[0].start_position, splits[0].stop_position)
     for _ in splits[0].source.read(range_tracker):
       fraction_consumed_report.append(range_tracker.fraction_consumed())
+      split_points_report.append(range_tracker.split_points())
 
     self.assertEqual(
         [float(i) / 10 for i in range(0, 10)], fraction_consumed_report)
+    expected_split_points_report = [
+        ((i - 1), iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
+        for i in range(1, 10)]
+
+    # At the last split point the unclaimed split points callback returns 0
+    # since the next record would start at the stop position; hence (9, 1).
+    expected_split_points_report.append((9, 1))
+
+    self.assertEqual(
+        expected_split_points_report, split_points_report)
 
   def test_read_reentrant_without_splitting(self):
     file_name, expected_data = write_data(10)

http://git-wip-us.apache.org/repos/asf/beam/blob/97a7ae44/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py b/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py
index 529d414..26ebe08 100644
--- a/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py
+++ b/sdks/python/apache_beam/runners/dataflow/native_io/iobase.py
@@ -136,7 +136,8 @@ class NativeSourceReader(object):
 class ReaderProgress(object):
   """A representation of how far a NativeSourceReader has read."""
 
-  def __init__(self, position=None, percent_complete=None, remaining_time=None):
+  def __init__(self, position=None, percent_complete=None, remaining_time=None,
+               consumed_split_points=None, remaining_split_points=None):
 
     self._position = position
 
@@ -149,6 +150,8 @@ class ReaderProgress(object):
     self._percent_complete = percent_complete
 
     self._remaining_time = remaining_time
+    self._consumed_split_points = consumed_split_points
+    self._remaining_split_points = remaining_split_points
 
   @property
   def position(self):
@@ -172,6 +175,14 @@ class ReaderProgress(object):
     """Returns progress, represented as an estimated time remaining."""
     return self._remaining_time
 
+  @property
+  def consumed_split_points(self):
+    return self._consumed_split_points
+
+  @property
+  def remaining_split_points(self):
+    return self._remaining_split_points
+
 
 class ReaderPosition(object):
   """A representation of position in an iteration of a 'NativeSourceReader'."""

Reply via email to