Repository: incubator-beam
Updated Branches:
  refs/heads/python-sdk fe1f39609 -> 6e6d89d48


Document that source objects should not be mutated.

Updates textio._TextSource so that it does not get mutated while reading.

Updates source_test_utils so that source objects do not get cloned while
testing. This could help to catch sources that erroneously get modified while
reading.

Adds reentrancy tests for text and Avro sources.


Project: http://git-wip-us.apache.org/repos/asf/incubator-beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-beam/commit/2ab8d62a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-beam/tree/2ab8d62a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-beam/diff/2ab8d62a

Branch: refs/heads/python-sdk
Commit: 2ab8d62ac48481a52fa04c704491f3a5889de27c
Parents: fe1f396
Author: Chamikara Jayalath <chamik...@google.com>
Authored: Wed Oct 12 16:51:20 2016 -0700
Committer: Robert Bradshaw <rober...@google.com>
Committed: Tue Oct 18 12:08:41 2016 -0700

----------------------------------------------------------------------
 sdks/python/apache_beam/io/avroio_test.py       | 28 ++++++-
 sdks/python/apache_beam/io/iobase.py            |  8 ++
 sdks/python/apache_beam/io/source_test_utils.py | 13 +---
 sdks/python/apache_beam/io/textio.py            | 80 +++++++++++---------
 sdks/python/apache_beam/io/textio_test.py       | 37 +++++++++
 5 files changed, 121 insertions(+), 45 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2ab8d62a/sdks/python/apache_beam/io/avroio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/avroio_test.py b/sdks/python/apache_beam/io/avroio_test.py
index 1c96d72..eb2c81c 100644
--- a/sdks/python/apache_beam/io/avroio_test.py
+++ b/sdks/python/apache_beam/io/avroio_test.py
@@ -110,7 +110,7 @@ class TestAvro(unittest.TestCase):
     return file_name_prefix + os.path.sep + 'mytemp*'
 
   def _run_avro_test(self, pattern, desired_bundle_size, perform_splitting,
-                     expected_result):
+                     expected_result, test_reentrancy=False):
     source = AvroSource(pattern)
 
     read_records = []
@@ -128,9 +128,23 @@ class TestAvro(unittest.TestCase):
           (split.source, split.start_position, split.stop_position)
           for split in splits
       ]
+      if test_reentrancy:
+        for source_info in sources_info:
+          reader_iter = source_info[0].read(source_info[0].get_range_tracker(
+              source_info[1], source_info[2]))
+          try:
+            next(reader_iter)
+          except StopIteration:
+            # Ignoring empty bundle
+            pass
+
       source_test_utils.assertSourcesEqualReferenceSource((source, None, None),
                                                           sources_info)
     else:
+      if test_reentrancy:
+        reader_iter = source.read(source.get_range_tracker(None, None))
+        next(reader_iter)
+
       read_records = source_test_utils.readFromSource(source, None, None)
       self.assertItemsEqual(expected_result, read_records)
 
@@ -144,6 +158,18 @@ class TestAvro(unittest.TestCase):
     expected_result = self.RECORDS
     self._run_avro_test(file_name, 100, True, expected_result)
 
+  def test_read_reentrant_without_splitting(self):
+    file_name = self._write_data()
+    expected_result = self.RECORDS
+    self._run_avro_test(file_name, None, False, expected_result,
+                        test_reentrancy=True)
+
+  def test_read_reantrant_with_splitting(self):
+    file_name = self._write_data()
+    expected_result = self.RECORDS
+    self._run_avro_test(file_name, 100, True, expected_result,
+                        test_reentrancy=True)
+
   def test_read_without_splitting_multiple_blocks(self):
     file_name = self._write_data(count=12000)
     expected_result = self.RECORDS * 2000

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2ab8d62a/sdks/python/apache_beam/io/iobase.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/iobase.py b/sdks/python/apache_beam/io/iobase.py
index 8239e26..edd3524 100644
--- a/sdks/python/apache_beam/io/iobase.py
+++ b/sdks/python/apache_beam/io/iobase.py
@@ -91,6 +91,14 @@ class BoundedSource(object):
       positions passed to the method ``get_range_tracker()`` are ``None``
   (2) Method read() will be invoked with the ``RangeTracker`` obtained in the
       previous step.
+
+  **Mutability**
+
+  A ``BoundedSource`` object should be fully mutated before being submitted
+  for reading. A ``BoundedSource`` object should not be mutated while
+  its methods (for example, ``read()``) are being invoked by a runner. Runner
+  implementations may invoke methods of ``BoundedSource`` objects through
+  multi-threaded and/or re-entrant execution modes.
   """
 
   def estimate_size(self):

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2ab8d62a/sdks/python/apache_beam/io/source_test_utils.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/source_test_utils.py b/sdks/python/apache_beam/io/source_test_utils.py
index 13b1e91..33ab083 100644
--- a/sdks/python/apache_beam/io/source_test_utils.py
+++ b/sdks/python/apache_beam/io/source_test_utils.py
@@ -48,7 +48,6 @@ from collections import namedtuple
 import logging
 
 from multiprocessing.pool import ThreadPool
-from apache_beam.internal import pickler
 from apache_beam.io import iobase
 
 
@@ -81,7 +80,7 @@ def readFromSource(source, start_position=None, stop_position=None):
   values = []
   range_tracker = source.get_range_tracker(start_position, stop_position)
   assert isinstance(range_tracker, iobase.RangeTracker)
-  reader = _copy_source(source).read(range_tracker)
+  reader = source.read(range_tracker)
   for value in reader:
     values.append(value)
 
@@ -173,7 +172,7 @@ def assertSplitAtFractionBehavior(source, num_items_to_read_before_split,
     source while the second value of the tuple will be '-1'.
   """
   assert isinstance(source, iobase.BoundedSource)
-  expected_items = readFromSource(_copy_source(source), None, None)
+  expected_items = readFromSource(source, None, None)
   return _assertSplitAtFractionBehavior(
       source, expected_items, num_items_to_read_before_split, split_fraction,
       expected_outcome)
@@ -186,7 +185,7 @@ def _assertSplitAtFractionBehavior(
   range_tracker = source.get_range_tracker(start_position, stop_position)
   assert isinstance(range_tracker, iobase.RangeTracker)
   current_items = []
-  reader = _copy_source(source).read(range_tracker)
+  reader = source.read(range_tracker)
   # Reading 'num_items_to_read_before_split' items.
   reader_iter = iter(reader)
   for _ in range(num_items_to_read_before_split):
@@ -536,7 +535,7 @@ def _assertSplitAtFractionConcurrent(
 
   range_tracker = source.get_range_tracker(None, None)
   stop_position_before_split = range_tracker.stop_position()
-  reader = _copy_source(source).read(range_tracker)
+  reader = source.read(range_tracker)
   reader_iter = iter(reader)
 
   current_items = []
@@ -575,7 +574,3 @@ def _assertSplitAtFractionConcurrent(
       primary_range, residual_range, split_fraction)
 
   return res[1] > 0
-
-
-def _copy_source(source):
-  return pickler.loads(pickler.dumps(source))

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2ab8d62a/sdks/python/apache_beam/io/textio.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py
index f1f5a25..dcaceef 100644
--- a/sdks/python/apache_beam/io/textio.py
+++ b/sdks/python/apache_beam/io/textio.py
@@ -41,14 +41,20 @@ class _TextSource(filebasedsource.FileBasedSource):
 
   DEFAULT_READ_BUFFER_SIZE = 8192
 
+  class ReadBuffer(object):
+    # A buffer that gives the buffered data and next position in the
+    # buffer that should be read.
+
+    def __init__(self, data, position):
+      self.data = data
+      self.position = position
+
   def __init__(self, file_pattern, min_bundle_size,
                compression_type, strip_trailing_newlines, coder,
                buffer_size=DEFAULT_READ_BUFFER_SIZE):
     super(_TextSource, self).__init__(file_pattern, min_bundle_size,
                                       compression_type=compression_type)
-    self._buffer = ''
-    self._next_position_in_buffer = 0
-    self._file = None
+
     self._strip_trailing_newlines = strip_trailing_newlines
     self._compression_type = compression_type
     self._coder = coder
@@ -57,7 +63,9 @@ class _TextSource(filebasedsource.FileBasedSource):
   def read_records(self, file_name, range_tracker):
     start_offset = range_tracker.start_position()
 
-    self._file = self.open_file(file_name)
+    read_buffer = _TextSource.ReadBuffer('', 0)
+    file_to_read = self.open_file(file_name)
+
     try:
       if start_offset > 0:
         # Seeking to one position before the start index and ignoring the
@@ -65,98 +73,100 @@ class _TextSource(filebasedsource.FileBasedSource):
         # belongs to the current bundle, hence ignoring that is incorrect.
         # Seeking to one byte before prevents that.
 
-        self._file.seek(start_offset - 1)
-        sep_bounds = self._find_separator_bounds()
+        file_to_read.seek(start_offset - 1)
+        sep_bounds = self._find_separator_bounds(file_to_read, read_buffer)
         if not sep_bounds:
           # Could not find a separator after (start_offset - 1). This means that
           # none of the records within the file belongs to the current source.
           return
 
         _, sep_end = sep_bounds
-        self._buffer = self._buffer[sep_end:]
+        read_buffer.data = read_buffer.data[sep_end:]
         next_record_start_position = start_offset -1 + sep_end
       else:
         next_record_start_position = 0
 
       while range_tracker.try_claim(next_record_start_position):
-        record, num_bytes_to_next_record = self._read_record()
+        record, num_bytes_to_next_record = self._read_record(file_to_read,
+                                                             read_buffer)
         yield self._coder.decode(record)
         if num_bytes_to_next_record < 0:
           break
         next_record_start_position += num_bytes_to_next_record
     finally:
-      self._file.close()
+      file_to_read.close()
 
-  def _find_separator_bounds(self):
-    # Determines the start and end positions within 'self._buffer' of the next
-    # separator starting from 'self._next_position_in_buffer'.
+  def _find_separator_bounds(self, file_to_read, read_buffer):
+    # Determines the start and end positions within 'read_buffer.data' of the
+    # next separator starting from position 'read_buffer.position'.
     # Currently supports following separators.
     # * '\n'
     # * '\r\n'
     # This method may increase the size of buffer but it will not decrease the
     # size of it.
 
-    current_pos = self._next_position_in_buffer
+    current_pos = read_buffer.position
 
     while True:
-      if current_pos >= len(self._buffer):
+      if current_pos >= len(read_buffer.data):
         # Ensuring that there are enough bytes to determine if there is a '\n'
         # at current_pos.
-        if not self._try_to_ensure_num_bytes_in_buffer(current_pos + 1):
+        if not self._try_to_ensure_num_bytes_in_buffer(
+            file_to_read, read_buffer, current_pos + 1):
           return
 
       # Using find() here is more efficient than a linear scan of the byte
       # array.
-      next_lf = self._buffer.find('\n', current_pos)
+      next_lf = read_buffer.data.find('\n', current_pos)
       if next_lf >= 0:
-        if self._buffer[next_lf - 1] == '\r':
+        if read_buffer.data[next_lf - 1] == '\r':
           return (next_lf - 1, next_lf + 1)
         else:
           return (next_lf, next_lf + 1)
 
-      current_pos = len(self._buffer)
+      current_pos = len(read_buffer.data)
 
-  def _try_to_ensure_num_bytes_in_buffer(self, num_bytes):
+  def _try_to_ensure_num_bytes_in_buffer(
+      self, file_to_read, read_buffer, num_bytes):
     # Tries to ensure that there are at least num_bytes bytes in the buffer.
     # Returns True if this can be fulfilled, returned False if this cannot be
     # fulfilled due to reaching EOF.
-    while len(self._buffer) < num_bytes:
-      read_data = self._file.read(self._buffer_size)
+    while len(read_buffer.data) < num_bytes:
+      read_data = file_to_read.read(self._buffer_size)
       if not read_data:
         return False
 
-      self._buffer += read_data
+      read_buffer.data += read_data
 
     return True
 
-  def _read_record(self):
+  def _read_record(self, file_to_read, read_buffer):
     # Returns a tuple containing the current_record and number of bytes to the
     # next record starting from 'self._next_position_in_buffer'. If EOF is
     # reached, returns a tuple containing the current record and -1.
 
-    if self._next_position_in_buffer > self._buffer_size:
-      # Buffer is too large. Truncating it and adjusting
-      # self._next_position_in_buffer.
-      self._buffer = self._buffer[self._next_position_in_buffer:]
-      self._next_position_in_buffer = 0
+    if read_buffer.position > self._buffer_size:
+      # read_buffer is too large. Truncating and adjusting it.
+      read_buffer.data = read_buffer.data[read_buffer.position:]
+      read_buffer.position = 0
 
-    record_start_position_in_buffer = self._next_position_in_buffer
-    sep_bounds = self._find_separator_bounds()
-    self._next_position_in_buffer = sep_bounds[1] if sep_bounds else len(
-        self._buffer)
+    record_start_position_in_buffer = read_buffer.position
+    sep_bounds = self._find_separator_bounds(file_to_read, read_buffer)
+    read_buffer.position = sep_bounds[1] if sep_bounds else len(
+        read_buffer.data)
 
     if not sep_bounds:
       # Reached EOF. Bytes up to the EOF is the next record. Returning '-1' for
       # the starting position of the next record.
-      return (self._buffer[record_start_position_in_buffer:], -1)
+      return (read_buffer.data[record_start_position_in_buffer:], -1)
 
     if self._strip_trailing_newlines:
       # Current record should not contain the separator.
-      return (self._buffer[record_start_position_in_buffer:sep_bounds[0]],
+      return (read_buffer.data[record_start_position_in_buffer:sep_bounds[0]],
               sep_bounds[1] - record_start_position_in_buffer)
     else:
       # Current record should contain the separator.
-      return (self._buffer[record_start_position_in_buffer:sep_bounds[1]],
+      return (read_buffer.data[record_start_position_in_buffer:sep_bounds[1]],
               sep_bounds[1] - record_start_position_in_buffer)
 
 

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/2ab8d62a/sdks/python/apache_beam/io/textio_test.py
----------------------------------------------------------------------
diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py
index 109506a..90ff3cc 100644
--- a/sdks/python/apache_beam/io/textio_test.py
+++ b/sdks/python/apache_beam/io/textio_test.py
@@ -198,6 +198,43 @@ class TextSourceTest(unittest.TestCase):
     self.assertEqual(
         [float(i) / 10 for i in range(0, 10)], fraction_consumed_report)
 
+  def test_read_reentrant_without_splitting(self):
+    file_name, expected_data = write_data(10)
+    assert len(expected_data) == 10
+    source1 = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
+                         coders.StrUtf8Coder())
+    reader_iter = source1.read(source1.get_range_tracker(None, None))
+    next(reader_iter)
+    next(reader_iter)
+
+    source2 = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
+                         coders.StrUtf8Coder())
+    source_test_utils.assertSourcesEqualReferenceSource((source1, None, None),
+                                                        [(source2, None, None)])
+
+  def test_read_reentrant_after_splitting(self):
+    file_name, expected_data = write_data(10)
+    assert len(expected_data) == 10
+    source = TextSource(file_name, 0, CompressionTypes.UNCOMPRESSED, True,
+                        coders.StrUtf8Coder())
+    splits1 = [split for split in source.split(desired_bundle_size=100000)]
+    assert len(splits1) == 1
+    reader_iter = splits1[0].source.read(
+        splits1[0].source.get_range_tracker(
+            splits1[0].start_position, splits1[0].stop_position))
+    next(reader_iter)
+    next(reader_iter)
+
+    splits2 = [split for split in source.split(desired_bundle_size=100000)]
+    assert len(splits2) == 1
+    source_test_utils.assertSourcesEqualReferenceSource(
+        (splits1[0].source,
+         splits1[0].start_position,
+         splits1[0].stop_position),
+        [(splits2[0].source,
+          splits2[0].start_position,
+          splits2[0].stop_position)])
+
   def test_dynamic_work_rebalancing(self):
     file_name, expected_data = write_data(15)
     assert len(expected_data) == 15

Reply via email to