This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 6a573e4  BEAM-13189 Python TextIO: add escapechar feature. (#15901)
6a573e4 is described below

commit 6a573e431a2b4e69fdd6a861c6f54517bbfa3175
Author: Eugene Nikolaiev <eugene.nikola...@gmail.com>
AuthorDate: Thu Nov 11 10:32:33 2021 +0200

    BEAM-13189 Python TextIO: add escapechar feature. (#15901)
---
 CHANGES.md                                |   1 +
 sdks/python/apache_beam/io/textio.py      |  70 ++++++++++++---
 sdks/python/apache_beam/io/textio_test.py | 139 +++++++++++++++++++++++++++++-
 3 files changed, 198 insertions(+), 12 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index eab4aec..a25c1e8 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -68,6 +68,7 @@
 
 * X feature added (Java/Python) ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
 * Add custom delimiters to Python TextIO reads ([BEAM-12730](https://issues.apache.org/jira/browse/BEAM-12730)).
+* Add escapechar parameter to Python TextIO reads ([BEAM-13189](https://issues.apache.org/jira/browse/BEAM-13189)).
 * Splittable reading is enabled by default while reading data with ParquetIO ([BEAM-12070](https://issues.apache.org/jira/browse/BEAM-12070)).
 * DoFn Execution Time metrics added to Go ([BEAM-13001](https://issues.apache.org/jira/browse/BEAM-13001)).
 * Cross-bundle side input caching is now available in the Go SDK for runners that support the feature by setting the EnableSideInputCache hook ([BEAM-11097](https://issues.apache.org/jira/browse/BEAM-11097)).
diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py
index 7f9ea6e..f53f9b3 100644
--- a/sdks/python/apache_beam/io/textio.py
+++ b/sdks/python/apache_beam/io/textio.py
@@ -100,7 +100,8 @@ class _TextSource(filebasedsource.FileBasedSource):
                validate=True,
                skip_header_lines=0,
                header_processor_fns=(None, None),
-               delimiter=None):
+               delimiter=None,
+               escapechar=None):
     """Initialize a _TextSource
 
     Args:
@@ -116,6 +117,8 @@ class _TextSource(filebasedsource.FileBasedSource):
       delimiter (bytes) Optional: delimiter to split records.
         Must not self-overlap, because self-overlapping delimiters cause
         ambiguous parsing.
+      escapechar (bytes) Optional: a single byte to escape the records
+        delimiter, can also escape itself.
     Raises:
       ValueError: if skip_lines is negative.
 
@@ -147,6 +150,11 @@ class _TextSource(filebasedsource.FileBasedSource):
       if self._is_self_overlapping(delimiter):
         raise ValueError('Delimiter must not self-overlap.')
     self._delimiter = delimiter
+    if escapechar is not None:
+      if not (isinstance(escapechar, bytes) and len(escapechar) == 1):
+        raise ValueError(
+            "escapechar must be bytes of size 1: '%s'" % escapechar)
+    self._escapechar = escapechar
 
   def display_data(self):
     parent_dd = super().display_data()
@@ -176,7 +184,7 @@ class _TextSource(filebasedsource.FileBasedSource):
       start_offset = max(start_offset, position_after_processing_header_lines)
       if start_offset > position_after_processing_header_lines:
         # Seeking to one delimiter length before the start index and ignoring
-        # the current line. If start_position is at beginning if the line, that
+        # the current line. If start_position is at beginning of the line, that
         # line belongs to the current bundle, hence ignoring that is incorrect.
         # Seeking to one delimiter before prevents that.
 
@@ -185,6 +193,16 @@ class _TextSource(filebasedsource.FileBasedSource):
         else:
           required_position = start_offset - 1
 
+        if self._escapechar is not None:
+          # Need more bytes to check if the delimiter is escaped.
+          # Seek until the first escapechar if any.
+          while required_position > 0:
+            file_to_read.seek(required_position - 1)
+            if file_to_read.read(1) == self._escapechar:
+              required_position -= 1
+            else:
+              break
+
         file_to_read.seek(required_position)
         read_buffer.reset()
         sep_bounds = self._find_separator_bounds(file_to_read, read_buffer)
@@ -277,11 +295,22 @@ class _TextSource(filebasedsource.FileBasedSource):
       if next_delim >= 0:
         if (self._delimiter is None and
             read_buffer.data[next_delim - 1:next_delim] == b'\r'):
-          # Accept both '\r\n' and '\n' as a default delimiter.
-          return (next_delim - 1, next_delim + 1)
+          if self._escapechar is not None and self._is_escaped(read_buffer,
+                                                               next_delim - 1):
+            # Accept '\n' as a default delimiter, because '\r' is escaped.
+            return (next_delim, next_delim + 1)
+          else:
+            # Accept both '\r\n' and '\n' as a default delimiter.
+            return (next_delim - 1, next_delim + 1)
         else:
-          # Found a delimiter. Accepting that as the next delimiter.
-          return (next_delim, next_delim + delimiter_len)
+          if self._escapechar is not None and self._is_escaped(read_buffer,
+                                                               next_delim):
+            # Skip an escaped delimiter.
+            current_pos = next_delim + delimiter_len + 1
+            continue
+          else:
+            # Found a delimiter. Accepting that as the next delimiter.
+            return (next_delim, next_delim + delimiter_len)
 
       elif self._delimiter is not None:
         # Corner case: custom delimiter is truncated at the end of the buffer.
@@ -362,6 +391,17 @@ class _TextSource(filebasedsource.FileBasedSource):
         return True
     return False
 
+  def _is_escaped(self, read_buffer, position):
+    # Returns True if byte at position is preceded with an odd number
+    # of escapechar bytes or False if preceded by 0 or even escapes
+    # (the even number means that all the escapes are escaped themselves).
+    escape_count = 0
+    for current_pos in reversed(range(0, position)):
+      if read_buffer.data[current_pos:current_pos + 1] != self._escapechar:
+        break
+      escape_count += 1
+    return escape_count % 2 == 1
+
 
 class _TextSourceWithFilename(_TextSource):
   def read_records(self, file_name, range_tracker):
@@ -467,7 +507,8 @@ def _create_text_source(
     strip_trailing_newlines=None,
     coder=None,
     skip_header_lines=None,
-    delimiter=None):
+    delimiter=None,
+    escapechar=None):
   return _TextSource(
       file_pattern=file_pattern,
       min_bundle_size=min_bundle_size,
@@ -476,7 +517,8 @@ def _create_text_source(
       coder=coder,
       validate=False,
       skip_header_lines=skip_header_lines,
-      delimiter=delimiter)
+      delimiter=delimiter,
+      escapechar=escapechar)
 
 
 class ReadAllFromText(PTransform):
@@ -508,6 +550,7 @@ class ReadAllFromText(PTransform):
       skip_header_lines=0,
       with_filename=False,
       delimiter=None,
+      escapechar=None,
       **kwargs):
     """Initialize the ``ReadAllFromText`` transform.
 
@@ -535,6 +578,8 @@ class ReadAllFromText(PTransform):
       delimiter (bytes) Optional: delimiter to split records.
         Must not self-overlap, because self-overlapping delimiters cause
         ambiguous parsing.
+      escapechar (bytes) Optional: a single byte to escape the records
+        delimiter, can also escape itself.
     """
     super().__init__(**kwargs)
     source_from_file = partial(
@@ -544,7 +589,8 @@ class ReadAllFromText(PTransform):
         strip_trailing_newlines=strip_trailing_newlines,
         coder=coder,
         skip_header_lines=skip_header_lines,
-        delimiter=delimiter)
+        delimiter=delimiter,
+        escapechar=escapechar)
     self._desired_bundle_size = desired_bundle_size
     self._min_bundle_size = min_bundle_size
     self._compression_type = compression_type
@@ -585,6 +631,7 @@ class ReadFromText(PTransform):
       validate=True,
       skip_header_lines=0,
       delimiter=None,
+      escapechar=None,
       **kwargs):
     """Initialize the :class:`ReadFromText` transform.
 
@@ -611,6 +658,8 @@ class ReadFromText(PTransform):
       delimiter (bytes) Optional: delimiter to split records.
         Must not self-overlap, because self-overlapping delimiters cause
         ambiguous parsing.
+      escapechar (bytes) Optional: a single byte to escape the records
+        delimiter, can also escape itself.
     """
 
     super().__init__(**kwargs)
@@ -622,7 +671,8 @@ class ReadFromText(PTransform):
         coder,
         validate=validate,
         skip_header_lines=skip_header_lines,
-        delimiter=delimiter)
+        delimiter=delimiter,
+        escapechar=escapechar)
 
   def expand(self, pvalue):
     return pvalue.pipeline | Read(self._source)
diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py
index ae53234..f6e0dfb 100644
--- a/sdks/python/apache_beam/io/textio_test.py
+++ b/sdks/python/apache_beam/io/textio_test.py
@@ -163,14 +163,16 @@ class TextSourceTest(unittest.TestCase):
       expected_data,
       buffer_size=DEFAULT_NUM_RECORDS,
       compression=CompressionTypes.UNCOMPRESSED,
-      delimiter=None):
+      delimiter=None,
+      escapechar=None):
     # Since each record usually takes more than 1 byte, default buffer size is
     # smaller than the total size of the file. This is done to
     # increase test coverage for cases that hit the buffer boundary.
-
     kwargs = {}
     if delimiter:
       kwargs['delimiter'] = delimiter
+    if escapechar:
+      kwargs['escapechar'] = escapechar
     source = TextSource(
         file_or_pattern,
         0,
@@ -1228,6 +1230,139 @@ class TextSourceTest(unittest.TestCase):
     assert len(expected_data) == 3
     self._run_read_test(file_name, expected_data, buffer_size=6)
 
+  def test_read_escaped_lf(self):
+    file_name, expected_data = write_data(
+      self.DEFAULT_NUM_RECORDS, eol=EOL.LF, line_value=b'li\\\nne')
+    assert len(expected_data) == self.DEFAULT_NUM_RECORDS
+    self._run_read_test(file_name, expected_data, escapechar=b'\\')
+
+  def test_read_escaped_crlf(self):
+    file_name, expected_data = write_data(
+      TextSource.DEFAULT_READ_BUFFER_SIZE,
+      eol=EOL.CRLF,
+      line_value=b'li\\\r\\\nne')
+    assert len(expected_data) == TextSource.DEFAULT_READ_BUFFER_SIZE
+    self._run_read_test(file_name, expected_data, escapechar=b'\\')
+
+  def test_read_escaped_cr_before_not_escaped_lf(self):
+    file_name, expected_data_temp = write_data(
+      self.DEFAULT_NUM_RECORDS, eol=EOL.CRLF, line_value=b'li\\\r\nne')
+    expected_data = []
+    for line in expected_data_temp:
+      expected_data += line.split("\n")
+    assert len(expected_data) == self.DEFAULT_NUM_RECORDS * 2
+    self._run_read_test(file_name, expected_data, escapechar=b'\\')
+
+  def test_read_escaped_custom_delimiter_crlf(self):
+    file_name, expected_data = write_data(
+      self.DEFAULT_NUM_RECORDS, eol=EOL.CRLF, line_value=b'li\\\r\nne')
+    assert len(expected_data) == self.DEFAULT_NUM_RECORDS
+    self._run_read_test(
+        file_name, expected_data, delimiter=b'\r\n', escapechar=b'\\')
+
+  def test_read_escaped_custom_delimiter(self):
+    file_name, expected_data = write_data(
+      TextSource.DEFAULT_READ_BUFFER_SIZE,
+      eol=EOL.CUSTOM_DELIMITER,
+      custom_delimiter=b'*|',
+      line_value=b'li\\*|ne')
+    assert len(expected_data) == TextSource.DEFAULT_READ_BUFFER_SIZE
+    self._run_read_test(
+        file_name, expected_data, delimiter=b'*|', escapechar=b'\\')
+
+  def test_read_escaped_lf_at_buffer_edge(self):
+    file_name, expected_data = write_data(3, eol=EOL.LF, line_value=b'line\\\n')
+    assert len(expected_data) == 3
+    self._run_read_test(
+        file_name, expected_data, buffer_size=5, escapechar=b'\\')
+
+  def test_read_escaped_crlf_split_by_buffer(self):
+    file_name, expected_data = write_data(
+      3, eol=EOL.CRLF, line_value=b'line\\\r\n')
+    assert len(expected_data) == 3
+    self._run_read_test(
+        file_name,
+        expected_data,
+        buffer_size=6,
+        delimiter=b'\r\n',
+        escapechar=b'\\')
+
+  def test_read_escaped_lf_after_splitting(self):
+    file_name, expected_data = write_data(3, line_value=b'line\\\n')
+    assert len(expected_data) == 3
+    source = TextSource(
+        file_name,
+        0,
+        CompressionTypes.UNCOMPRESSED,
+        True,
+        coders.StrUtf8Coder(),
+        escapechar=b'\\')
+    splits = list(source.split(desired_bundle_size=6))
+
+    reference_source_info = (source, None, None)
+    sources_info = ([(split.source, split.start_position, split.stop_position)
+                     for split in splits])
+    source_test_utils.assert_sources_equal_reference_source(
+        reference_source_info, sources_info)
+
+  def test_read_escaped_lf_after_splitting_many(self):
+    file_name, expected_data = write_data(
+      3, line_value=b'\\\\\\\\\\\n')  # 5 escapes
+    assert len(expected_data) == 3
+    source = TextSource(
+        file_name,
+        0,
+        CompressionTypes.UNCOMPRESSED,
+        True,
+        coders.StrUtf8Coder(),
+        escapechar=b'\\')
+    splits = list(source.split(desired_bundle_size=6))
+
+    reference_source_info = (source, None, None)
+    sources_info = ([(split.source, split.start_position, split.stop_position)
+                     for split in splits])
+    source_test_utils.assert_sources_equal_reference_source(
+        reference_source_info, sources_info)
+
+  def test_read_escaped_escapechar_after_splitting(self):
+    file_name, expected_data = write_data(3, line_value=b'line\\\\*|')
+    assert len(expected_data) == 3
+    source = TextSource(
+        file_name,
+        0,
+        CompressionTypes.UNCOMPRESSED,
+        True,
+        coders.StrUtf8Coder(),
+        delimiter=b'*|',
+        escapechar=b'\\')
+    splits = list(source.split(desired_bundle_size=8))
+
+    reference_source_info = (source, None, None)
+    sources_info = ([(split.source, split.start_position, split.stop_position)
+                     for split in splits])
+    source_test_utils.assert_sources_equal_reference_source(
+        reference_source_info, sources_info)
+
+  def test_read_escaped_escapechar_after_splitting_many(self):
+    file_name, expected_data = write_data(
+      3, line_value=b'\\\\\\\\\\\\*|')  # 6 escapes
+    assert len(expected_data) == 3
+    source = TextSource(
+        file_name,
+        0,
+        CompressionTypes.UNCOMPRESSED,
+        True,
+        coders.StrUtf8Coder(),
+        delimiter=b'*|',
+        escapechar=b'\\')
+    splits = list(source.split(desired_bundle_size=8))
+
+    reference_source_info = (source, None, None)
+    sources_info = ([(split.source, split.start_position, split.stop_position)
+                     for split in splits])
+    source_test_utils.assert_sources_equal_reference_source(
+        reference_source_info, sources_info)
+
 
 class TextSinkTest(unittest.TestCase):
   def setUp(self):

Reply via email to