This is an automated email from the ASF dual-hosted git repository.

tvalentyn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
     new 6a573e4  BEAM-13189 Python TextIO: add escapechar feature. (#15901)

6a573e4 is described below

commit 6a573e431a2b4e69fdd6a861c6f54517bbfa3175
Author: Eugene Nikolaiev <eugene.nikola...@gmail.com>
AuthorDate: Thu Nov 11 10:32:33 2021 +0200

    BEAM-13189 Python TextIO: add escapechar feature. (#15901)
---
 CHANGES.md                                |   1 +
 sdks/python/apache_beam/io/textio.py      |  70 ++++++++++++---
 sdks/python/apache_beam/io/textio_test.py | 139 +++++++++++++++++++++++++++++-
 3 files changed, 198 insertions(+), 12 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index eab4aec..a25c1e8 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -68,6 +68,7 @@
 
 * X feature added (Java/Python) ([BEAM-X](https://issues.apache.org/jira/browse/BEAM-X)).
 * Add custom delimiters to Python TextIO reads ([BEAM-12730](https://issues.apache.org/jira/browse/BEAM-12730)).
+* Add escapechar parameter to Python TextIO reads ([BEAM-13189](https://issues.apache.org/jira/browse/BEAM-13189)).
 * Splittable reading is enabled by default while reading data with ParquetIO ([BEAM-12070](https://issues.apache.org/jira/browse/BEAM-12070)).
 * DoFn Execution Time metrics added to Go ([BEAM-13001](https://issues.apache.org/jira/browse/BEAM-13001)).
 * Cross-bundle side input caching is now available in the Go SDK for runners that support the feature by setting the EnableSideInputCache hook ([BEAM-11097](https://issues.apache.org/jira/browse/BEAM-11097)).
diff --git a/sdks/python/apache_beam/io/textio.py b/sdks/python/apache_beam/io/textio.py index 7f9ea6e..f53f9b3 100644 --- a/sdks/python/apache_beam/io/textio.py +++ b/sdks/python/apache_beam/io/textio.py @@ -100,7 +100,8 @@ class _TextSource(filebasedsource.FileBasedSource): validate=True, skip_header_lines=0, header_processor_fns=(None, None), - delimiter=None): + delimiter=None, + escapechar=None): """Initialize a _TextSource Args: @@ -116,6 +117,8 @@ class _TextSource(filebasedsource.FileBasedSource): delimiter (bytes) Optional: delimiter to split records. Must not self-overlap, because self-overlapping delimiters cause ambiguous parsing. + escapechar (bytes) Optional: a single byte to escape the records + delimiter, can also escape itself. Raises: ValueError: if skip_lines is negative. @@ -147,6 +150,11 @@ class _TextSource(filebasedsource.FileBasedSource): if self._is_self_overlapping(delimiter): raise ValueError('Delimiter must not self-overlap.') self._delimiter = delimiter + if escapechar is not None: + if not (isinstance(escapechar, bytes) and len(escapechar) == 1): + raise ValueError( + "escapechar must be bytes of size 1: '%s'" % escapechar) + self._escapechar = escapechar def display_data(self): parent_dd = super().display_data() @@ -176,7 +184,7 @@ class _TextSource(filebasedsource.FileBasedSource): start_offset = max(start_offset, position_after_processing_header_lines) if start_offset > position_after_processing_header_lines: # Seeking to one delimiter length before the start index and ignoring - # the current line. If start_position is at beginning if the line, that + # the current line. If start_position is at beginning of the line, that # line belongs to the current bundle, hence ignoring that is incorrect. # Seeking to one delimiter before prevents that. 
@@ -185,6 +193,16 @@ class _TextSource(filebasedsource.FileBasedSource): else: required_position = start_offset - 1 + if self._escapechar is not None: + # Need more bytes to check if the delimiter is escaped. + # Seek until the first escapechar if any. + while required_position > 0: + file_to_read.seek(required_position - 1) + if file_to_read.read(1) == self._escapechar: + required_position -= 1 + else: + break + file_to_read.seek(required_position) read_buffer.reset() sep_bounds = self._find_separator_bounds(file_to_read, read_buffer) @@ -277,11 +295,22 @@ class _TextSource(filebasedsource.FileBasedSource): if next_delim >= 0: if (self._delimiter is None and read_buffer.data[next_delim - 1:next_delim] == b'\r'): - # Accept both '\r\n' and '\n' as a default delimiter. - return (next_delim - 1, next_delim + 1) + if self._escapechar is not None and self._is_escaped(read_buffer, + next_delim - 1): + # Accept '\n' as a default delimiter, because '\r' is escaped. + return (next_delim, next_delim + 1) + else: + # Accept both '\r\n' and '\n' as a default delimiter. + return (next_delim - 1, next_delim + 1) else: - # Found a delimiter. Accepting that as the next delimiter. - return (next_delim, next_delim + delimiter_len) + if self._escapechar is not None and self._is_escaped(read_buffer, + next_delim): + # Skip an escaped delimiter. + current_pos = next_delim + delimiter_len + 1 + continue + else: + # Found a delimiter. Accepting that as the next delimiter. + return (next_delim, next_delim + delimiter_len) elif self._delimiter is not None: # Corner case: custom delimiter is truncated at the end of the buffer. @@ -362,6 +391,17 @@ class _TextSource(filebasedsource.FileBasedSource): return True return False + def _is_escaped(self, read_buffer, position): + # Returns True if byte at position is preceded with an odd number + # of escapechar bytes or False if preceded by 0 or even escapes + # (the even number means that all the escapes are escaped themselves). 
+ escape_count = 0 + for current_pos in reversed(range(0, position)): + if read_buffer.data[current_pos:current_pos + 1] != self._escapechar: + break + escape_count += 1 + return escape_count % 2 == 1 + class _TextSourceWithFilename(_TextSource): def read_records(self, file_name, range_tracker): @@ -467,7 +507,8 @@ def _create_text_source( strip_trailing_newlines=None, coder=None, skip_header_lines=None, - delimiter=None): + delimiter=None, + escapechar=None): return _TextSource( file_pattern=file_pattern, min_bundle_size=min_bundle_size, @@ -476,7 +517,8 @@ def _create_text_source( coder=coder, validate=False, skip_header_lines=skip_header_lines, - delimiter=delimiter) + delimiter=delimiter, + escapechar=escapechar) class ReadAllFromText(PTransform): @@ -508,6 +550,7 @@ class ReadAllFromText(PTransform): skip_header_lines=0, with_filename=False, delimiter=None, + escapechar=None, **kwargs): """Initialize the ``ReadAllFromText`` transform. @@ -535,6 +578,8 @@ class ReadAllFromText(PTransform): delimiter (bytes) Optional: delimiter to split records. Must not self-overlap, because self-overlapping delimiters cause ambiguous parsing. + escapechar (bytes) Optional: a single byte to escape the records + delimiter, can also escape itself. """ super().__init__(**kwargs) source_from_file = partial( @@ -544,7 +589,8 @@ class ReadAllFromText(PTransform): strip_trailing_newlines=strip_trailing_newlines, coder=coder, skip_header_lines=skip_header_lines, - delimiter=delimiter) + delimiter=delimiter, + escapechar=escapechar) self._desired_bundle_size = desired_bundle_size self._min_bundle_size = min_bundle_size self._compression_type = compression_type @@ -585,6 +631,7 @@ class ReadFromText(PTransform): validate=True, skip_header_lines=0, delimiter=None, + escapechar=None, **kwargs): """Initialize the :class:`ReadFromText` transform. @@ -611,6 +658,8 @@ class ReadFromText(PTransform): delimiter (bytes) Optional: delimiter to split records. 
Must not self-overlap, because self-overlapping delimiters cause ambiguous parsing. + escapechar (bytes) Optional: a single byte to escape the records + delimiter, can also escape itself. """ super().__init__(**kwargs) @@ -622,7 +671,8 @@ class ReadFromText(PTransform): coder, validate=validate, skip_header_lines=skip_header_lines, - delimiter=delimiter) + delimiter=delimiter, + escapechar=escapechar) def expand(self, pvalue): return pvalue.pipeline | Read(self._source) diff --git a/sdks/python/apache_beam/io/textio_test.py b/sdks/python/apache_beam/io/textio_test.py index ae53234..f6e0dfb 100644 --- a/sdks/python/apache_beam/io/textio_test.py +++ b/sdks/python/apache_beam/io/textio_test.py @@ -163,14 +163,16 @@ class TextSourceTest(unittest.TestCase): expected_data, buffer_size=DEFAULT_NUM_RECORDS, compression=CompressionTypes.UNCOMPRESSED, - delimiter=None): + delimiter=None, + escapechar=None): # Since each record usually takes more than 1 byte, default buffer size is # smaller than the total size of the file. This is done to # increase test coverage for cases that hit the buffer boundary. 
- kwargs = {} if delimiter: kwargs['delimiter'] = delimiter + if escapechar: + kwargs['escapechar'] = escapechar source = TextSource( file_or_pattern, 0, @@ -1228,6 +1230,139 @@ class TextSourceTest(unittest.TestCase): assert len(expected_data) == 3 self._run_read_test(file_name, expected_data, buffer_size=6) + def test_read_escaped_lf(self): + file_name, expected_data = write_data( + self.DEFAULT_NUM_RECORDS, eol=EOL.LF, line_value=b'li\\\nne') + assert len(expected_data) == self.DEFAULT_NUM_RECORDS + self._run_read_test(file_name, expected_data, escapechar=b'\\') + + def test_read_escaped_crlf(self): + file_name, expected_data = write_data( + TextSource.DEFAULT_READ_BUFFER_SIZE, + eol=EOL.CRLF, + line_value=b'li\\\r\\\nne') + assert len(expected_data) == TextSource.DEFAULT_READ_BUFFER_SIZE + self._run_read_test(file_name, expected_data, escapechar=b'\\') + + def test_read_escaped_cr_before_not_escaped_lf(self): + file_name, expected_data_temp = write_data( + self.DEFAULT_NUM_RECORDS, eol=EOL.CRLF, line_value=b'li\\\r\nne') + expected_data = [] + for line in expected_data_temp: + expected_data += line.split("\n") + assert len(expected_data) == self.DEFAULT_NUM_RECORDS * 2 + self._run_read_test(file_name, expected_data, escapechar=b'\\') + + def test_read_escaped_custom_delimiter_crlf(self): + file_name, expected_data = write_data( + self.DEFAULT_NUM_RECORDS, eol=EOL.CRLF, line_value=b'li\\\r\nne') + assert len(expected_data) == self.DEFAULT_NUM_RECORDS + self._run_read_test( + file_name, expected_data, delimiter=b'\r\n', escapechar=b'\\') + + def test_read_escaped_custom_delimiter(self): + file_name, expected_data = write_data( + TextSource.DEFAULT_READ_BUFFER_SIZE, + eol=EOL.CUSTOM_DELIMITER, + custom_delimiter=b'*|', + line_value=b'li\\*|ne') + assert len(expected_data) == TextSource.DEFAULT_READ_BUFFER_SIZE + self._run_read_test( + file_name, expected_data, delimiter=b'*|', escapechar=b'\\') + + def test_read_escaped_lf_at_buffer_edge(self): + file_name, 
expected_data = write_data(3, eol=EOL.LF, line_value=b'line\\\n') + assert len(expected_data) == 3 + self._run_read_test( + file_name, expected_data, buffer_size=5, escapechar=b'\\') + + def test_read_escaped_crlf_split_by_buffer(self): + file_name, expected_data = write_data( + 3, eol=EOL.CRLF, line_value=b'line\\\r\n') + assert len(expected_data) == 3 + self._run_read_test( + file_name, + expected_data, + buffer_size=6, + delimiter=b'\r\n', + escapechar=b'\\') + + def test_read_escaped_lf_after_splitting(self): + file_name, expected_data = write_data(3, line_value=b'line\\\n') + assert len(expected_data) == 3 + source = TextSource( + file_name, + 0, + CompressionTypes.UNCOMPRESSED, + True, + coders.StrUtf8Coder(), + escapechar=b'\\') + splits = list(source.split(desired_bundle_size=6)) + + reference_source_info = (source, None, None) + sources_info = ([(split.source, split.start_position, split.stop_position) + for split in splits]) + source_test_utils.assert_sources_equal_reference_source( + reference_source_info, sources_info) + + def test_read_escaped_lf_after_splitting_many(self): + file_name, expected_data = write_data( + 3, line_value=b'\\\\\\\\\\\n') # 5 escapes + assert len(expected_data) == 3 + source = TextSource( + file_name, + 0, + CompressionTypes.UNCOMPRESSED, + True, + coders.StrUtf8Coder(), + escapechar=b'\\') + splits = list(source.split(desired_bundle_size=6)) + + reference_source_info = (source, None, None) + sources_info = ([(split.source, split.start_position, split.stop_position) + for split in splits]) + source_test_utils.assert_sources_equal_reference_source( + reference_source_info, sources_info) + + def test_read_escaped_escapechar_after_splitting(self): + file_name, expected_data = write_data(3, line_value=b'line\\\\*|') + assert len(expected_data) == 3 + source = TextSource( + file_name, + 0, + CompressionTypes.UNCOMPRESSED, + True, + coders.StrUtf8Coder(), + delimiter=b'*|', + escapechar=b'\\') + splits = 
list(source.split(desired_bundle_size=8)) + + reference_source_info = (source, None, None) + sources_info = ([(split.source, split.start_position, split.stop_position) + for split in splits]) + source_test_utils.assert_sources_equal_reference_source( + reference_source_info, sources_info) + + def test_read_escaped_escapechar_after_splitting_many(self): + file_name, expected_data = write_data( + 3, line_value=b'\\\\\\\\\\\\*|') # 6 escapes + assert len(expected_data) == 3 + source = TextSource( + file_name, + 0, + CompressionTypes.UNCOMPRESSED, + True, + coders.StrUtf8Coder(), + delimiter=b'*|', + escapechar=b'\\') + splits = list(source.split(desired_bundle_size=8)) + + reference_source_info = (source, None, None) + sources_info = ([(split.source, split.start_position, split.stop_position) + for split in splits]) + source_test_utils.assert_sources_equal_reference_source( + reference_source_info, sources_info) + class TextSinkTest(unittest.TestCase): def setUp(self):