This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 11c9b42  ARROW-2090: [Python] Add context methods to ParquetWriter
11c9b42 is described below

commit 11c9b42a9f9c8b69fa33b1972bc195ea5e5f47ae
Author: Alec Posney <[email protected]>
AuthorDate: Mon Feb 5 16:56:31 2018 -0500

    ARROW-2090: [Python] Add context methods to ParquetWriter
    
    Added the `__enter__` and `__exit__` methods to the ParquetWriter
    class. This allows you to write code in the following style:
    
    ```
    with ParquetWriter(foo, schema) as writer:
        writer.write_table(table)
    ```
    
    The ParquetWriter context object will then handle calling the close
    method for you when the `with` block exits. It propagates errors
    in line with previous behavior. I have also updated the module-level
    `write_table` function to use the new `with` style, and it maintains
    its existing behavior of trying to remove a partial file if an
    exception is encountered.
    
    The reason for this change was a bug I encountered while
    using the `ParquetWriter` object, where I wasn't closing
    the writer correctly in the event of an exception, which resulted
    in partially written Parquet files. Adding the ability to use
    `with` syntax will reduce the chances of a similar misuse
    while still maintaining backwards compatibility.
    
    Author: Alec Posney <[email protected]>
    
    Closes #1559 from Posnet/ARROW-2090 and squashes the following commits:
    
    9a24433d [Alec Posney] ARROW-2090: [Python] Add context methods to ParquetWriter
---
 python/doc/source/parquet.rst        |  9 ++++++
 python/pyarrow/parquet.py            | 33 ++++++++++---------
 python/pyarrow/tests/test_parquet.py | 63 ++++++++++++++++++++++++++++++++++++
 3 files changed, 90 insertions(+), 15 deletions(-)

diff --git a/python/doc/source/parquet.rst b/python/doc/source/parquet.rst
index d466ba1..ac56520 100644
--- a/python/doc/source/parquet.rst
+++ b/python/doc/source/parquet.rst
@@ -139,11 +139,20 @@ We can similarly write a Parquet file with multiple row groups by using
    pf2 = pq.ParquetFile('example2.parquet')
    pf2.num_row_groups
 
+Alternatively, Python ``with`` syntax can also be used:
+
+.. ipython:: python
+
+   with pq.ParquetWriter('example3.parquet', table.schema) as writer:
+       for i in range(3):
+           writer.write_table(table)
+
 .. ipython:: python
    :suppress:
 
    !rm example.parquet
    !rm example2.parquet
+   !rm example3.parquet
 
 Compression, Encoding, and File Compatibility
 ---------------------------------------------
diff --git a/python/pyarrow/parquet.py b/python/pyarrow/parquet.py
index 3a0924a..8820b6b 100644
--- a/python/pyarrow/parquet.py
+++ b/python/pyarrow/parquet.py
@@ -292,6 +292,14 @@ schema : arrow Schema
         if getattr(self, 'is_open', False):
             self.close()
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, *args, **kwargs):
+        self.close()
+        # return false since we want to propagate exceptions
+        return False
+
     def write_table(self, table, row_group_size=None):
         if self.schema_changed:
             table = _sanitize_table(table, self.schema, self.flavor)
@@ -932,29 +940,24 @@ def write_table(table, where, row_group_size=None, version='1.0',
                 flavor=None, **kwargs):
     row_group_size = kwargs.pop('chunk_size', row_group_size)
 
-    writer = None
     try:
-        writer = ParquetWriter(
-            where, table.schema,
-            version=version,
-            flavor=flavor,
-            use_dictionary=use_dictionary,
-            coerce_timestamps=coerce_timestamps,
-            compression=compression,
-            use_deprecated_int96_timestamps=use_deprecated_int96_timestamps,
-            **kwargs)
-        writer.write_table(table, row_group_size=row_group_size)
+        with ParquetWriter(
+                where, table.schema,
+                version=version,
+                flavor=flavor,
+                use_dictionary=use_dictionary,
+                coerce_timestamps=coerce_timestamps,
+                compression=compression,
+                use_deprecated_int96_timestamps= 
use_deprecated_int96_timestamps, # noqa
+                **kwargs) as writer:
+            writer.write_table(table, row_group_size=row_group_size)
     except Exception:
-        if writer is not None:
-            writer.close()
         if isinstance(where, six.string_types):
             try:
                 os.remove(where)
             except os.error:
                 pass
         raise
-    else:
-        writer.close()
 
 
 write_table.__doc__ = """
diff --git a/python/pyarrow/tests/test_parquet.py b/python/pyarrow/tests/test_parquet.py
index 7c2edb3..c49f3d3 100644
--- a/python/pyarrow/tests/test_parquet.py
+++ b/python/pyarrow/tests/test_parquet.py
@@ -1673,3 +1673,66 @@ def test_decimal_roundtrip_negative_scale(tmpdir):
     result_table = _read_table(string_filename)
     result = result_table.to_pandas()
     tm.assert_frame_equal(result, expected)
+
+
+@parquet
+def test_parquet_writer_context_obj(tmpdir):
+
+    import pyarrow.parquet as pq
+
+    df = _test_dataframe(100)
+    df['unique_id'] = 0
+
+    arrow_table = pa.Table.from_pandas(df, preserve_index=False)
+    out = pa.BufferOutputStream()
+
+    with pq.ParquetWriter(out, arrow_table.schema, version='2.0') as writer:
+
+        frames = []
+        for i in range(10):
+            df['unique_id'] = i
+            arrow_table = pa.Table.from_pandas(df, preserve_index=False)
+            writer.write_table(arrow_table)
+
+            frames.append(df.copy())
+
+    buf = out.get_result()
+    result = _read_table(pa.BufferReader(buf))
+
+    expected = pd.concat(frames, ignore_index=True)
+    tm.assert_frame_equal(result.to_pandas(), expected)
+
+
+@parquet
+def test_parquet_writer_context_obj_with_exception(tmpdir):
+
+    import pyarrow.parquet as pq
+
+    df = _test_dataframe(100)
+    df['unique_id'] = 0
+
+    arrow_table = pa.Table.from_pandas(df, preserve_index=False)
+    out = pa.BufferOutputStream()
+    error_text = 'Artificial Error'
+
+    try:
+        with pq.ParquetWriter(out,
+                              arrow_table.schema,
+                              version='2.0') as writer:
+
+            frames = []
+            for i in range(10):
+                df['unique_id'] = i
+                arrow_table = pa.Table.from_pandas(df, preserve_index=False)
+                writer.write_table(arrow_table)
+                frames.append(df.copy())
+                if i == 5:
+                    raise ValueError(error_text)
+    except Exception as e:
+        assert str(e) == error_text
+
+    buf = out.get_result()
+    result = _read_table(pa.BufferReader(buf))
+
+    expected = pd.concat(frames, ignore_index=True)
+    tm.assert_frame_equal(result.to_pandas(), expected)

-- 
To stop receiving notification emails like this one, please contact
[email protected].

Reply via email to