This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new ff28c76  ARROW-2024: [Python] Remove torch serialization from default serialization context.
ff28c76 is described below

commit ff28c7647c1910f1a0d1d97b8ba68b2b554e5ce1
Author: Robert Nishihara <[email protected]>
AuthorDate: Thu Feb 1 12:17:46 2018 -0500

    ARROW-2024: [Python] Remove torch serialization from default serialization context.
    
    See discussion in #1223.
    
    Author: Robert Nishihara <[email protected]>
    
    Closes #1538 from robertnishihara/dontimportpytorch and squashes the following commits:
    
    09a8cfa1 [Robert Nishihara] Fix bug.
    8992e0fb [Robert Nishihara] Fix.
    83ffb70f [Robert Nishihara] Remove torch serialization from default serialization context.
---
 python/pyarrow/__init__.py                  |  5 ++-
 python/pyarrow/serialization.py             | 66 +++++++++++++++++------------
 python/pyarrow/tests/test_convert_pandas.py |  2 +-
 python/pyarrow/tests/test_serialization.py  | 56 +++++++++++++-----------
 4 files changed, 73 insertions(+), 56 deletions(-)

diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py
index a245fe6..8b3cba9 100644
--- a/python/pyarrow/__init__.py
+++ b/python/pyarrow/__init__.py
@@ -124,9 +124,10 @@ from pyarrow.ipc import (Message, MessageReader,
 
 localfs = LocalFileSystem.get_instance()
 
-from pyarrow.serialization import (_default_serialization_context,
+from pyarrow.serialization import (default_serialization_context,
                                    pandas_serialization_context,
-                                   register_default_serialization_handlers)
+                                   register_default_serialization_handlers,
+                                   register_torch_serialization_handlers)
 
 import pyarrow.types as types
 
diff --git a/python/pyarrow/serialization.py b/python/pyarrow/serialization.py
index 61f2e83..c8b72b7 100644
--- a/python/pyarrow/serialization.py
+++ b/python/pyarrow/serialization.py
@@ -22,7 +22,8 @@ import sys
 import numpy as np
 
 from pyarrow.compat import builtin_pickle
-from pyarrow.lib import _default_serialization_context, frombuffer
+from pyarrow.lib import (SerializationContext, _default_serialization_context,
+                         frombuffer)
 
 try:
     import cloudpickle
@@ -102,6 +103,31 @@ def _register_custom_pandas_handlers(context):
         custom_deserializer=_deserialize_pandas_dataframe)
 
 
+def register_torch_serialization_handlers(serialization_context):
+    # ----------------------------------------------------------------------
+    # Set up serialization for pytorch tensors
+
+    try:
+        import torch
+
+        def _serialize_torch_tensor(obj):
+            return obj.numpy()
+
+        def _deserialize_torch_tensor(data):
+            return torch.from_numpy(data)
+
+        for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor,
+                  torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
+                  torch.IntTensor, torch.LongTensor]:
+            serialization_context.register_type(
+                t, "torch." + t.__name__,
+                custom_serializer=_serialize_torch_tensor,
+                custom_deserializer=_deserialize_torch_tensor)
+    except ImportError:
+        # no torch
+        pass
+
+
 def register_default_serialization_handlers(serialization_context):
 
     # ----------------------------------------------------------------------
@@ -154,37 +180,21 @@ def register_default_serialization_handlers(serialization_context):
         custom_serializer=_serialize_numpy_array_list,
         custom_deserializer=_deserialize_numpy_array_list)
 
-    # ----------------------------------------------------------------------
-    # Set up serialization for pytorch tensors
-
-    try:
-        import torch
-
-        def _serialize_torch_tensor(obj):
-            return obj.numpy()
+    _register_custom_pandas_handlers(serialization_context)
 
-        def _deserialize_torch_tensor(data):
-            return torch.from_numpy(data)
 
-        for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor,
-                  torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
-                  torch.IntTensor, torch.LongTensor]:
-            serialization_context.register_type(
-                t, "torch." + t.__name__,
-                custom_serializer=_serialize_torch_tensor,
-                custom_deserializer=_deserialize_torch_tensor)
-    except ImportError:
-        # no torch
-        pass
-
-    _register_custom_pandas_handlers(serialization_context)
+def default_serialization_context():
+    context = SerializationContext()
+    register_default_serialization_handlers(context)
+    return context
 
 
 register_default_serialization_handlers(_default_serialization_context)
 
-pandas_serialization_context = _default_serialization_context.clone()
 
-pandas_serialization_context.register_type(
-    np.ndarray, 'np.array',
-    custom_serializer=_serialize_numpy_array_pickle,
-    custom_deserializer=_deserialize_numpy_array_pickle)
+def pandas_serialization_context():
+    context = default_serialization_context()
+    context.register_type(np.ndarray, 'np.array',
+                          custom_serializer=_serialize_numpy_array_pickle,
+                          custom_deserializer=_deserialize_numpy_array_pickle)
+    return context
diff --git a/python/pyarrow/tests/test_convert_pandas.py b/python/pyarrow/tests/test_convert_pandas.py
index fa265e5..ca2f1e3 100644
--- a/python/pyarrow/tests/test_convert_pandas.py
+++ b/python/pyarrow/tests/test_convert_pandas.py
@@ -1404,7 +1404,7 @@ def _fully_loaded_dataframe_example():
 
 
 def _check_serialize_components_roundtrip(df):
-    ctx = pa.pandas_serialization_context
+    ctx = pa.pandas_serialization_context()
 
     components = ctx.serialize(df).to_components()
     deserialized = ctx.deserialize_components(components)
diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py
index 284c7fb..7a42010 100644
--- a/python/pyarrow/tests/test_serialization.py
+++ b/python/pyarrow/tests/test_serialization.py
@@ -190,8 +190,7 @@ CUSTOM_OBJECTS = [Exception("Test object."), CustomError(), Point(11, y=22),
 
 
 def make_serialization_context():
-
-    context = pa._default_serialization_context
+    context = pa.default_serialization_context()
 
     context.register_type(Foo, "Foo")
     context.register_type(Bar, "Bar")
@@ -207,26 +206,27 @@ def make_serialization_context():
     return context
 
 
-serialization_context = make_serialization_context()
+global_serialization_context = make_serialization_context()
 
 
-def serialization_roundtrip(value, scratch_buffer, ctx=serialization_context):
+def serialization_roundtrip(value, scratch_buffer,
+                            context=global_serialization_context):
     writer = pa.FixedSizeBufferWriter(scratch_buffer)
-    pa.serialize_to(value, writer, ctx)
+    pa.serialize_to(value, writer, context=context)
 
     reader = pa.BufferReader(scratch_buffer)
-    result = pa.deserialize_from(reader, None, ctx)
+    result = pa.deserialize_from(reader, None, context=context)
     assert_equal(value, result)
 
-    _check_component_roundtrip(value)
+    _check_component_roundtrip(value, context=context)
 
 
-def _check_component_roundtrip(value):
+def _check_component_roundtrip(value, context=global_serialization_context):
     # Test to/from components
-    serialized = pa.serialize(value)
+    serialized = pa.serialize(value, context=context)
     components = serialized.to_components()
     from_comp = pa.SerializedPyObject.from_components(components)
-    recons = from_comp.deserialize()
+    recons = from_comp.deserialize(context=context)
     assert_equal(value, recons)
 
 
@@ -252,7 +252,7 @@ def test_primitive_serialization(large_buffer):
     for obj in PRIMITIVE_OBJECTS:
         serialization_roundtrip(obj, large_buffer)
         serialization_roundtrip(obj, large_buffer,
-                                pa.pandas_serialization_context)
+                                pa.pandas_serialization_context())
 
 
 def test_serialize_to_buffer():
@@ -318,22 +318,26 @@ def test_datetime_serialization(large_buffer):
 def test_torch_serialization(large_buffer):
     pytest.importorskip("torch")
     import torch
+
+    serialization_context = pa.default_serialization_context()
+    pa.register_torch_serialization_handlers(serialization_context)
     # These are the only types that are supported for the
     # PyTorch to NumPy conversion
     for t in ["float32", "float64",
               "uint8", "int16", "int32", "int64"]:
         obj = torch.from_numpy(np.random.randn(1000).astype(t))
-        serialization_roundtrip(obj, large_buffer)
+        serialization_roundtrip(obj, large_buffer,
+                                context=serialization_context)
 
 
 def test_numpy_immutable(large_buffer):
     obj = np.zeros([10])
 
     writer = pa.FixedSizeBufferWriter(large_buffer)
-    pa.serialize_to(obj, writer, serialization_context)
+    pa.serialize_to(obj, writer, global_serialization_context)
 
     reader = pa.BufferReader(large_buffer)
-    result = pa.deserialize_from(reader, None, serialization_context)
+    result = pa.deserialize_from(reader, None, global_serialization_context)
     with pytest.raises(ValueError):
         result[0] = 1.0
 
@@ -351,12 +355,12 @@ def test_serialization_callback_numpy():
     def deserialize_dummy_class(serialized_obj):
         return serialized_obj
 
-    pa._default_serialization_context.register_type(
-        DummyClass, "DummyClass",
-        custom_serializer=serialize_dummy_class,
-        custom_deserializer=deserialize_dummy_class)
+    context = pa.default_serialization_context()
+    context.register_type(DummyClass, "DummyClass",
+                          custom_serializer=serialize_dummy_class,
+                          custom_deserializer=deserialize_dummy_class)
 
-    pa.serialize(DummyClass())
+    pa.serialize(DummyClass(), context=context)
 
 
 def test_buffer_serialization():
@@ -370,13 +374,14 @@ def test_buffer_serialization():
     def deserialize_buffer_class(serialized_obj):
         return serialized_obj
 
-    pa._default_serialization_context.register_type(
+    context = pa.default_serialization_context()
+    context.register_type(
         BufferClass, "BufferClass",
         custom_serializer=serialize_buffer_class,
         custom_deserializer=deserialize_buffer_class)
 
-    b = pa.serialize(BufferClass()).to_buffer()
-    assert pa.deserialize(b).to_pybytes() == b"hello"
+    b = pa.serialize(BufferClass(), context=context).to_buffer()
+    assert pa.deserialize(b, context=context).to_pybytes() == b"hello"
 
 
 @pytest.mark.skip(reason="extensive memory requirements")
@@ -485,15 +490,16 @@ def test_serialize_subclasses():
     # with register_type will result in faster and more memory
     # efficient serialization.
 
-    serialization_context.register_type(
+    context = pa.default_serialization_context()
+    context.register_type(
         Serializable, "Serializable",
         custom_serializer=serialize_serializable,
         custom_deserializer=deserialize_serializable)
 
     a = SerializableClass()
-    serialized = pa.serialize(a)
+    serialized = pa.serialize(a, context=context)
 
-    deserialized = serialized.deserialize()
+    deserialized = serialized.deserialize(context=context)
     assert type(deserialized).__name__ == SerializableClass.__name__
     assert deserialized.value == 3
 

-- 
To stop receiving notification emails like this one, please contact
[email protected].

Reply via email to