This is an automated email from the ASF dual-hosted git repository.

jorisvandenbossche pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new a63ead769a GH-34884: [Python]: Support pickling pyarrow.dataset 
PartitioningFactory objects (#36550)
a63ead769a is described below

commit a63ead769ae5c7e4f2a32f528bcb39c5aae2ab49
Author: Joris Van den Bossche <[email protected]>
AuthorDate: Mon Jul 10 08:32:40 2023 +0200

    GH-34884: [Python]: Support pickling pyarrow.dataset PartitioningFactory 
objects (#36550)
    
    ### Rationale for this change
    
    https://github.com/apache/arrow/pull/36462 already added support for 
pickling Partitioning objects, but not yet the PartitioningFactory objects.
    
    The problem for PartitioningFactory is that we currently don't really 
expose the full class hierarchy in python, just the base class 
PartitioningFactory. We also don't expose creating those factory objects, 
except through the `discover` methods of the Partitioning classes.
    I think it would be nice to keep this minimal binding, but that means that, 
if we want to make them serializable with pickle without adding custom 
serialization code on the C++ side, we need another way to do that.
    
    In this PR, I went for the route of essentially storing the constructor 
(the discover static method) and the arguments that were passed to the 
constructor, on the factory object, so we can use this info for pickling. Not 
the nicest code, but the simplest solution I could think of.
    
    ### Are these changes tested?
    
    Yes
    * Closes: #34884
    
    Authored-by: Joris Van den Bossche <[email protected]>
    Signed-off-by: Joris Van den Bossche <[email protected]>
---
 python/pyarrow/_dataset.pxd          |  5 ++-
 python/pyarrow/_dataset.pyx          | 39 +++++++++++++++++++++---
 python/pyarrow/_dataset_parquet.pyx  |  2 +-
 python/pyarrow/tests/test_dataset.py | 59 +++++++++++++++++++++++++-----------
 4 files changed, 80 insertions(+), 25 deletions(-)

diff --git a/python/pyarrow/_dataset.pxd b/python/pyarrow/_dataset.pxd
index d626b42e23..210e555800 100644
--- a/python/pyarrow/_dataset.pxd
+++ b/python/pyarrow/_dataset.pxd
@@ -160,11 +160,14 @@ cdef class PartitioningFactory(_Weakrefable):
     cdef:
         shared_ptr[CPartitioningFactory] wrapped
         CPartitioningFactory* factory
+        object constructor
+        object options
 
     cdef init(self, const shared_ptr[CPartitioningFactory]& sp)
 
     @staticmethod
-    cdef wrap(const shared_ptr[CPartitioningFactory]& sp)
+    cdef wrap(const shared_ptr[CPartitioningFactory]& sp,
+              object constructor, object options)
 
     cdef inline shared_ptr[CPartitioningFactory] unwrap(self)
 
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 2ab8ffb798..c5f0a663a8 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -2374,16 +2374,22 @@ cdef class PartitioningFactory(_Weakrefable):
         self.factory = sp.get()
 
     @staticmethod
-    cdef wrap(const shared_ptr[CPartitioningFactory]& sp):
+    cdef wrap(const shared_ptr[CPartitioningFactory]& sp,
+              object constructor, object options):
         cdef PartitioningFactory self = PartitioningFactory.__new__(
             PartitioningFactory
         )
         self.init(sp)
+        self.constructor = constructor
+        self.options = options
         return self
 
     cdef inline shared_ptr[CPartitioningFactory] unwrap(self):
         return self.wrapped
 
+    def __reduce__(self):
+        return self.constructor, self.options
+
     @property
     def type_name(self):
         return frombytes(self.factory.type_name())
@@ -2454,6 +2460,10 @@ cdef class KeyValuePartitioning(Partitioning):
         return res
 
 
+def _constructor_directory_partitioning_factory(*args):
+    return DirectoryPartitioning.discover(*args)
+
+
 cdef class DirectoryPartitioning(KeyValuePartitioning):
     """
     A Partitioning based on a specified Schema.
@@ -2571,7 +2581,15 @@ cdef class DirectoryPartitioning(KeyValuePartitioning):
         c_options.segment_encoding = _get_segment_encoding(segment_encoding)
 
         return PartitioningFactory.wrap(
-            CDirectoryPartitioning.MakeFactory(c_field_names, c_options))
+            CDirectoryPartitioning.MakeFactory(c_field_names, c_options),
+            _constructor_directory_partitioning_factory,
+            (field_names, infer_dictionary, max_partition_dictionary_size,
+             schema, segment_encoding)
+        )
+
+
+def _constructor_hive_partitioning_factory(*args):
+    return HivePartitioning.discover(*args)
 
 
 cdef class HivePartitioning(KeyValuePartitioning):
@@ -2714,7 +2732,15 @@ cdef class HivePartitioning(KeyValuePartitioning):
         c_options.segment_encoding = _get_segment_encoding(segment_encoding)
 
         return PartitioningFactory.wrap(
-            CHivePartitioning.MakeFactory(c_options))
+            CHivePartitioning.MakeFactory(c_options),
+            _constructor_hive_partitioning_factory,
+            (infer_dictionary, max_partition_dictionary_size, null_fallback,
+             schema, segment_encoding),
+        )
+
+
+def _constructor_filename_partitioning_factory(*args):
+    return FilenamePartitioning.discover(*args)
 
 
 cdef class FilenamePartitioning(KeyValuePartitioning):
@@ -2823,7 +2849,10 @@ cdef class FilenamePartitioning(KeyValuePartitioning):
         c_options.segment_encoding = _get_segment_encoding(segment_encoding)
 
         return PartitioningFactory.wrap(
-            CFilenamePartitioning.MakeFactory(c_field_names, c_options))
+            CFilenamePartitioning.MakeFactory(c_field_names, c_options),
+            _constructor_filename_partitioning_factory,
+            (field_names, infer_dictionary, schema, segment_encoding)
+        )
 
 
 cdef class DatasetFactory(_Weakrefable):
@@ -2988,7 +3017,7 @@ cdef class FileSystemFactoryOptions(_Weakrefable):
         c_factory = self.options.partitioning.factory()
         if c_factory.get() == nullptr:
             return None
-        return PartitioningFactory.wrap(c_factory)
+        return PartitioningFactory.wrap(c_factory, None, None)
 
     @partitioning_factory.setter
     def partitioning_factory(self, PartitioningFactory value):
diff --git a/python/pyarrow/_dataset_parquet.pyx 
b/python/pyarrow/_dataset_parquet.pyx
index bc4786b9cd..4ad0caec30 100644
--- a/python/pyarrow/_dataset_parquet.pyx
+++ b/python/pyarrow/_dataset_parquet.pyx
@@ -811,7 +811,7 @@ cdef class ParquetFactoryOptions(_Weakrefable):
         c_factory = self.options.partitioning.factory()
         if c_factory.get() == nullptr:
             return None
-        return PartitioningFactory.wrap(c_factory)
+        return PartitioningFactory.wrap(c_factory, None, None)
 
     @partitioning_factory.setter
     def partitioning_factory(self, PartitioningFactory value):
diff --git a/python/pyarrow/tests/test_dataset.py 
b/python/pyarrow/tests/test_dataset.py
index 814454861e..2f9b6a0922 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -1642,12 +1642,15 @@ def test_fragments_repr(tempdir, dataset):
 
 
 @pytest.mark.parquet
-def test_partitioning_factory(mockfs):
+@pytest.mark.parametrize(
+    "pickled", [lambda x: x, lambda x: pickle.loads(pickle.dumps(x))])
+def test_partitioning_factory(mockfs, pickled):
     paths_or_selector = fs.FileSelector('subdir', recursive=True)
     format = ds.ParquetFileFormat()
 
     options = ds.FileSystemFactoryOptions('subdir')
     partitioning_factory = ds.DirectoryPartitioning.discover(['group', 'key'])
+    partitioning_factory = pickled(partitioning_factory)
     assert isinstance(partitioning_factory, ds.PartitioningFactory)
     options.partitioning_factory = partitioning_factory
 
@@ -1673,13 +1676,16 @@ def test_partitioning_factory(mockfs):
 
 @pytest.mark.parquet
 @pytest.mark.parametrize('infer_dictionary', [False, True])
-def test_partitioning_factory_dictionary(mockfs, infer_dictionary):
+@pytest.mark.parametrize(
+    "pickled", [lambda x: x, lambda x: pickle.loads(pickle.dumps(x))])
+def test_partitioning_factory_dictionary(mockfs, infer_dictionary, pickled):
     paths_or_selector = fs.FileSelector('subdir', recursive=True)
     format = ds.ParquetFileFormat()
     options = ds.FileSystemFactoryOptions('subdir')
 
-    options.partitioning_factory = ds.DirectoryPartitioning.discover(
+    partitioning_factory = ds.DirectoryPartitioning.discover(
         ['group', 'key'], infer_dictionary=infer_dictionary)
+    options.partitioning_factory = pickled(partitioning_factory)
 
     factory = ds.FileSystemDatasetFactory(
         mockfs, paths_or_selector, format, options)
@@ -1703,7 +1709,9 @@ def test_partitioning_factory_dictionary(mockfs, 
infer_dictionary):
         assert inferred_schema.field('key').type == pa.string()
 
 
-def test_partitioning_factory_segment_encoding():
+@pytest.mark.parametrize(
+    "pickled", [lambda x: x, lambda x: pickle.loads(pickle.dumps(x))])
+def test_partitioning_factory_segment_encoding(pickled):
     mockfs = fs._MockFileSystem()
     format = ds.IpcFileFormat()
     schema = pa.schema([("i64", pa.int64())])
@@ -1726,8 +1734,9 @@ def test_partitioning_factory_segment_encoding():
     # Directory
     selector = fs.FileSelector("directory", recursive=True)
     options = ds.FileSystemFactoryOptions("directory")
-    options.partitioning_factory = ds.DirectoryPartitioning.discover(
+    partitioning_factory = ds.DirectoryPartitioning.discover(
         schema=partition_schema)
+    options.partitioning_factory = pickled(partitioning_factory)
     factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
     inferred_schema = factory.inspect()
     assert inferred_schema == full_schema
@@ -1736,24 +1745,27 @@ def test_partitioning_factory_segment_encoding():
     })
     assert actual[0][0].as_py() == 1620086400
 
-    options.partitioning_factory = ds.DirectoryPartitioning.discover(
+    partitioning_factory = ds.DirectoryPartitioning.discover(
         ["date", "string"], segment_encoding="none")
+    options.partitioning_factory = pickled(partitioning_factory)
     factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
     fragments = list(factory.finish().get_fragments())
     assert fragments[0].partition_expression.equals(
         (ds.field("date") == "2021-05-04 00%3A00%3A00") &
         (ds.field("string") == "%24"))
 
-    options.partitioning = ds.DirectoryPartitioning(
+    partitioning = ds.DirectoryPartitioning(
         string_partition_schema, segment_encoding="none")
+    options.partitioning = pickled(partitioning)
     factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
     fragments = list(factory.finish().get_fragments())
     assert fragments[0].partition_expression.equals(
         (ds.field("date") == "2021-05-04 00%3A00%3A00") &
         (ds.field("string") == "%24"))
 
-    options.partitioning_factory = ds.DirectoryPartitioning.discover(
+    partitioning_factory = ds.DirectoryPartitioning.discover(
         schema=partition_schema, segment_encoding="none")
+    options.partitioning_factory = pickled(partitioning_factory)
     factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
     with pytest.raises(pa.ArrowInvalid,
                        match="Could not cast segments for partition field"):
@@ -1762,8 +1774,9 @@ def test_partitioning_factory_segment_encoding():
     # Hive
     selector = fs.FileSelector("hive", recursive=True)
     options = ds.FileSystemFactoryOptions("hive")
-    options.partitioning_factory = ds.HivePartitioning.discover(
+    partitioning_factory = ds.HivePartitioning.discover(
         schema=partition_schema)
+    options.partitioning_factory = pickled(partitioning_factory)
     factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
     inferred_schema = factory.inspect()
     assert inferred_schema == full_schema
@@ -1772,8 +1785,9 @@ def test_partitioning_factory_segment_encoding():
     })
     assert actual[0][0].as_py() == 1620086400
 
-    options.partitioning_factory = ds.HivePartitioning.discover(
+    partitioning_factory = ds.HivePartitioning.discover(
         segment_encoding="none")
+    options.partitioning_factory = pickled(partitioning_factory)
     factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
     fragments = list(factory.finish().get_fragments())
     assert fragments[0].partition_expression.equals(
@@ -1788,15 +1802,18 @@ def test_partitioning_factory_segment_encoding():
         (ds.field("date") == "2021-05-04 00%3A00%3A00") &
         (ds.field("string") == "%24"))
 
-    options.partitioning_factory = ds.HivePartitioning.discover(
+    partitioning_factory = ds.HivePartitioning.discover(
         schema=partition_schema, segment_encoding="none")
+    options.partitioning_factory = pickled(partitioning_factory)
     factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
     with pytest.raises(pa.ArrowInvalid,
                        match="Could not cast segments for partition field"):
         inferred_schema = factory.inspect()
 
 
-def test_partitioning_factory_hive_segment_encoding_key_encoded():
+@pytest.mark.parametrize(
+    "pickled", [lambda x: x, lambda x: pickle.loads(pickle.dumps(x))])
+def test_partitioning_factory_hive_segment_encoding_key_encoded(pickled):
     mockfs = fs._MockFileSystem()
     format = ds.IpcFileFormat()
     schema = pa.schema([("i64", pa.int64())])
@@ -1825,8 +1842,9 @@ def 
test_partitioning_factory_hive_segment_encoding_key_encoded():
     # Hive
     selector = fs.FileSelector("hive", recursive=True)
     options = ds.FileSystemFactoryOptions("hive")
-    options.partitioning_factory = ds.HivePartitioning.discover(
+    partitioning_factory = ds.HivePartitioning.discover(
         schema=partition_schema)
+    options.partitioning_factory = pickled(partitioning_factory)
     factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
     inferred_schema = factory.inspect()
     assert inferred_schema == full_schema
@@ -1835,40 +1853,45 @@ def 
test_partitioning_factory_hive_segment_encoding_key_encoded():
     })
     assert actual[0][0].as_py() == 1620086400
 
-    options.partitioning_factory = ds.HivePartitioning.discover(
+    partitioning_factory = ds.HivePartitioning.discover(
         segment_encoding="uri")
+    options.partitioning_factory = pickled(partitioning_factory)
     factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
     fragments = list(factory.finish().get_fragments())
     assert fragments[0].partition_expression.equals(
         (ds.field("test'; date") == "2021-05-04 00:00:00") &
         (ds.field("test';[ string'") == "$"))
 
-    options.partitioning = ds.HivePartitioning(
+    partitioning = ds.HivePartitioning(
         string_partition_schema, segment_encoding="uri")
+    options.partitioning = pickled(partitioning)
     factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
     fragments = list(factory.finish().get_fragments())
     assert fragments[0].partition_expression.equals(
         (ds.field("test'; date") == "2021-05-04 00:00:00") &
         (ds.field("test';[ string'") == "$"))
 
-    options.partitioning_factory = ds.HivePartitioning.discover(
+    partitioning_factory = ds.HivePartitioning.discover(
         segment_encoding="none")
+    options.partitioning_factory = pickled(partitioning_factory)
     factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
     fragments = list(factory.finish().get_fragments())
     assert fragments[0].partition_expression.equals(
         (ds.field("test%27%3B%20date") == "2021-05-04 00%3A00%3A00") &
         (ds.field("test%27%3B%5B%20string%27") == "%24"))
 
-    options.partitioning = ds.HivePartitioning(
+    partitioning = ds.HivePartitioning(
         string_partition_schema_en, segment_encoding="none")
+    options.partitioning = pickled(partitioning)
     factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
     fragments = list(factory.finish().get_fragments())
     assert fragments[0].partition_expression.equals(
         (ds.field("test%27%3B%20date") == "2021-05-04 00%3A00%3A00") &
         (ds.field("test%27%3B%5B%20string%27") == "%24"))
 
-    options.partitioning_factory = ds.HivePartitioning.discover(
+    partitioning_factory = ds.HivePartitioning.discover(
         schema=partition_schema_en, segment_encoding="none")
+    options.partitioning_factory = pickled(partitioning_factory)
     factory = ds.FileSystemDatasetFactory(mockfs, selector, format, options)
     with pytest.raises(pa.ArrowInvalid,
                        match="Could not cast segments for partition field"):

Reply via email to