This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new ff28c76 ARROW-2024: [Python] Remove torch serialization from default serialization context.
ff28c76 is described below
commit ff28c7647c1910f1a0d1d97b8ba68b2b554e5ce1
Author: Robert Nishihara <[email protected]>
AuthorDate: Thu Feb 1 12:17:46 2018 -0500
ARROW-2024: [Python] Remove torch serialization from default serialization
context.
See discussion in #1223.
Author: Robert Nishihara <[email protected]>
Closes #1538 from robertnishihara/dontimportpytorch and squashes the following commits:
09a8cfa1 [Robert Nishihara] Fix bug.
8992e0fb [Robert Nishihara] Fix.
83ffb70f [Robert Nishihara] Remove torch serialization from default
serialization context.
---
python/pyarrow/__init__.py | 5 ++-
python/pyarrow/serialization.py | 66 +++++++++++++++++------------
python/pyarrow/tests/test_convert_pandas.py | 2 +-
python/pyarrow/tests/test_serialization.py | 56 +++++++++++++-----------
4 files changed, 73 insertions(+), 56 deletions(-)
diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py
index a245fe6..8b3cba9 100644
--- a/python/pyarrow/__init__.py
+++ b/python/pyarrow/__init__.py
@@ -124,9 +124,10 @@ from pyarrow.ipc import (Message, MessageReader,
localfs = LocalFileSystem.get_instance()
-from pyarrow.serialization import (_default_serialization_context,
+from pyarrow.serialization import (default_serialization_context,
pandas_serialization_context,
- register_default_serialization_handlers)
+ register_default_serialization_handlers,
+ register_torch_serialization_handlers)
import pyarrow.types as types
diff --git a/python/pyarrow/serialization.py b/python/pyarrow/serialization.py
index 61f2e83..c8b72b7 100644
--- a/python/pyarrow/serialization.py
+++ b/python/pyarrow/serialization.py
@@ -22,7 +22,8 @@ import sys
import numpy as np
from pyarrow.compat import builtin_pickle
-from pyarrow.lib import _default_serialization_context, frombuffer
+from pyarrow.lib import (SerializationContext, _default_serialization_context,
+ frombuffer)
try:
import cloudpickle
@@ -102,6 +103,31 @@ def _register_custom_pandas_handlers(context):
custom_deserializer=_deserialize_pandas_dataframe)
+def register_torch_serialization_handlers(serialization_context):
+ # ----------------------------------------------------------------------
+ # Set up serialization for pytorch tensors
+
+ try:
+ import torch
+
+ def _serialize_torch_tensor(obj):
+ return obj.numpy()
+
+ def _deserialize_torch_tensor(data):
+ return torch.from_numpy(data)
+
+ for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor,
+ torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
+ torch.IntTensor, torch.LongTensor]:
+ serialization_context.register_type(
+ t, "torch." + t.__name__,
+ custom_serializer=_serialize_torch_tensor,
+ custom_deserializer=_deserialize_torch_tensor)
+ except ImportError:
+ # no torch
+ pass
+
+
def register_default_serialization_handlers(serialization_context):
# ----------------------------------------------------------------------
@@ -154,37 +180,21 @@ def register_default_serialization_handlers(serialization_context):
custom_serializer=_serialize_numpy_array_list,
custom_deserializer=_deserialize_numpy_array_list)
- # ----------------------------------------------------------------------
- # Set up serialization for pytorch tensors
-
- try:
- import torch
-
- def _serialize_torch_tensor(obj):
- return obj.numpy()
+ _register_custom_pandas_handlers(serialization_context)
- def _deserialize_torch_tensor(data):
- return torch.from_numpy(data)
- for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor,
- torch.ByteTensor, torch.CharTensor, torch.ShortTensor,
- torch.IntTensor, torch.LongTensor]:
- serialization_context.register_type(
- t, "torch." + t.__name__,
- custom_serializer=_serialize_torch_tensor,
- custom_deserializer=_deserialize_torch_tensor)
- except ImportError:
- # no torch
- pass
-
- _register_custom_pandas_handlers(serialization_context)
+def default_serialization_context():
+ context = SerializationContext()
+ register_default_serialization_handlers(context)
+ return context
register_default_serialization_handlers(_default_serialization_context)
-pandas_serialization_context = _default_serialization_context.clone()
-pandas_serialization_context.register_type(
- np.ndarray, 'np.array',
- custom_serializer=_serialize_numpy_array_pickle,
- custom_deserializer=_deserialize_numpy_array_pickle)
+def pandas_serialization_context():
+ context = default_serialization_context()
+ context.register_type(np.ndarray, 'np.array',
+ custom_serializer=_serialize_numpy_array_pickle,
+ custom_deserializer=_deserialize_numpy_array_pickle)
+ return context
diff --git a/python/pyarrow/tests/test_convert_pandas.py b/python/pyarrow/tests/test_convert_pandas.py
index fa265e5..ca2f1e3 100644
--- a/python/pyarrow/tests/test_convert_pandas.py
+++ b/python/pyarrow/tests/test_convert_pandas.py
@@ -1404,7 +1404,7 @@ def _fully_loaded_dataframe_example():
def _check_serialize_components_roundtrip(df):
- ctx = pa.pandas_serialization_context
+ ctx = pa.pandas_serialization_context()
components = ctx.serialize(df).to_components()
deserialized = ctx.deserialize_components(components)
diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py
index 284c7fb..7a42010 100644
--- a/python/pyarrow/tests/test_serialization.py
+++ b/python/pyarrow/tests/test_serialization.py
@@ -190,8 +190,7 @@ CUSTOM_OBJECTS = [Exception("Test object."), CustomError(), Point(11, y=22),
def make_serialization_context():
-
- context = pa._default_serialization_context
+ context = pa.default_serialization_context()
context.register_type(Foo, "Foo")
context.register_type(Bar, "Bar")
@@ -207,26 +206,27 @@ def make_serialization_context():
return context
-serialization_context = make_serialization_context()
+global_serialization_context = make_serialization_context()
-def serialization_roundtrip(value, scratch_buffer, ctx=serialization_context):
+def serialization_roundtrip(value, scratch_buffer,
+ context=global_serialization_context):
writer = pa.FixedSizeBufferWriter(scratch_buffer)
- pa.serialize_to(value, writer, ctx)
+ pa.serialize_to(value, writer, context=context)
reader = pa.BufferReader(scratch_buffer)
- result = pa.deserialize_from(reader, None, ctx)
+ result = pa.deserialize_from(reader, None, context=context)
assert_equal(value, result)
- _check_component_roundtrip(value)
+ _check_component_roundtrip(value, context=context)
-def _check_component_roundtrip(value):
+def _check_component_roundtrip(value, context=global_serialization_context):
# Test to/from components
- serialized = pa.serialize(value)
+ serialized = pa.serialize(value, context=context)
components = serialized.to_components()
from_comp = pa.SerializedPyObject.from_components(components)
- recons = from_comp.deserialize()
+ recons = from_comp.deserialize(context=context)
assert_equal(value, recons)
@@ -252,7 +252,7 @@ def test_primitive_serialization(large_buffer):
for obj in PRIMITIVE_OBJECTS:
serialization_roundtrip(obj, large_buffer)
serialization_roundtrip(obj, large_buffer,
- pa.pandas_serialization_context)
+ pa.pandas_serialization_context())
def test_serialize_to_buffer():
@@ -318,22 +318,26 @@ def test_datetime_serialization(large_buffer):
def test_torch_serialization(large_buffer):
pytest.importorskip("torch")
import torch
+
+ serialization_context = pa.default_serialization_context()
+ pa.register_torch_serialization_handlers(serialization_context)
# These are the only types that are supported for the
# PyTorch to NumPy conversion
for t in ["float32", "float64",
"uint8", "int16", "int32", "int64"]:
obj = torch.from_numpy(np.random.randn(1000).astype(t))
- serialization_roundtrip(obj, large_buffer)
+ serialization_roundtrip(obj, large_buffer,
+ context=serialization_context)
def test_numpy_immutable(large_buffer):
obj = np.zeros([10])
writer = pa.FixedSizeBufferWriter(large_buffer)
- pa.serialize_to(obj, writer, serialization_context)
+ pa.serialize_to(obj, writer, global_serialization_context)
reader = pa.BufferReader(large_buffer)
- result = pa.deserialize_from(reader, None, serialization_context)
+ result = pa.deserialize_from(reader, None, global_serialization_context)
with pytest.raises(ValueError):
result[0] = 1.0
@@ -351,12 +355,12 @@ def test_serialization_callback_numpy():
def deserialize_dummy_class(serialized_obj):
return serialized_obj
- pa._default_serialization_context.register_type(
- DummyClass, "DummyClass",
- custom_serializer=serialize_dummy_class,
- custom_deserializer=deserialize_dummy_class)
+ context = pa.default_serialization_context()
+ context.register_type(DummyClass, "DummyClass",
+ custom_serializer=serialize_dummy_class,
+ custom_deserializer=deserialize_dummy_class)
- pa.serialize(DummyClass())
+ pa.serialize(DummyClass(), context=context)
def test_buffer_serialization():
@@ -370,13 +374,14 @@ def test_buffer_serialization():
def deserialize_buffer_class(serialized_obj):
return serialized_obj
- pa._default_serialization_context.register_type(
+ context = pa.default_serialization_context()
+ context.register_type(
BufferClass, "BufferClass",
custom_serializer=serialize_buffer_class,
custom_deserializer=deserialize_buffer_class)
- b = pa.serialize(BufferClass()).to_buffer()
- assert pa.deserialize(b).to_pybytes() == b"hello"
+ b = pa.serialize(BufferClass(), context=context).to_buffer()
+ assert pa.deserialize(b, context=context).to_pybytes() == b"hello"
@pytest.mark.skip(reason="extensive memory requirements")
@@ -485,15 +490,16 @@ def test_serialize_subclasses():
# with register_type will result in faster and more memory
# efficient serialization.
- serialization_context.register_type(
+ context = pa.default_serialization_context()
+ context.register_type(
Serializable, "Serializable",
custom_serializer=serialize_serializable,
custom_deserializer=deserialize_serializable)
a = SerializableClass()
- serialized = pa.serialize(a)
+ serialized = pa.serialize(a, context=context)
- deserialized = serialized.deserialize()
+ deserialized = serialized.deserialize(context=context)
assert type(deserialized).__name__ == SerializableClass.__name__
assert deserialized.value == 3
--
To stop receiving notification emails like this one, please contact
[email protected].