This is an automated email from the ASF dual-hosted git repository.

kszucs pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
     new 7de93af  ARROW-9394: [Python] Support pickling of Scalars
7de93af is described below

commit 7de93af56562ec81a8e0f29446b7cb70458a518a
Author: Krisztián Szűcs <szucs.kriszt...@gmail.com>
AuthorDate: Wed Jul 29 13:14:41 2020 +0200

    ARROW-9394: [Python] Support pickling of Scalars

    Since there are no sequence converters available for Dictionary and Union
    types, we cannot construct them directly, so `pa.scalar` fails as the
    reducer function to reconstruct them. We can add custom reducers for them
    later, so I'm leaving them as NotImplemented for now.

    Closes #7852 from kszucs/ARROW-9394

    Authored-by: Krisztián Szűcs <szucs.kriszt...@gmail.com>
    Signed-off-by: Krisztián Szűcs <szucs.kriszt...@gmail.com>
---
 python/pyarrow/scalar.pxi            |  3 ++
 python/pyarrow/tests/test_scalars.py | 73 +++++++++++++++++++++++++++++++++++-
 2 files changed, 75 insertions(+), 1 deletion(-)

diff --git a/python/pyarrow/scalar.pxi b/python/pyarrow/scalar.pxi
index f607070..7f35419 100644
--- a/python/pyarrow/scalar.pxi
+++ b/python/pyarrow/scalar.pxi
@@ -97,6 +97,9 @@ cdef class Scalar(_Weakrefable):
         cdef CScalarHash hasher
         return hasher(self.wrapped)
 
+    def __reduce__(self):
+        return scalar, (self.as_py(), self.type)
+
     def as_py(self):
         raise NotImplementedError()
 
diff --git a/python/pyarrow/tests/test_scalars.py b/python/pyarrow/tests/test_scalars.py
index 3a19a1c..091ae38 100644
--- a/python/pyarrow/tests/test_scalars.py
+++ b/python/pyarrow/tests/test_scalars.py
@@ -17,6 +17,7 @@
 
 import datetime
 import decimal
+import pickle
 import pytest
 import weakref
 
@@ -41,16 +42,29 @@ import pyarrow as pa
     (1.0, None, pa.DoubleScalar, pa.DoubleValue),
     (np.float16(1.0), pa.float16(), pa.HalfFloatScalar, pa.HalfFloatValue),
     (1.0, pa.float32(), pa.FloatScalar, pa.FloatValue),
+    (decimal.Decimal("1.123"), None, pa.Decimal128Scalar, pa.Decimal128Value),
     ("string", None, pa.StringScalar, pa.StringValue),
     (b"bytes", None, pa.BinaryScalar, pa.BinaryValue),
+    ("largestring", pa.large_string(), pa.LargeStringScalar,
+     pa.LargeStringValue),
+    (b"largebytes", pa.large_binary(), pa.LargeBinaryScalar,
+     pa.LargeBinaryValue),
+    (b"abc", pa.binary(3), pa.FixedSizeBinaryScalar, pa.FixedSizeBinaryValue),
     ([1, 2, 3], None, pa.ListScalar, pa.ListValue),
     ([1, 2, 3, 4], pa.large_list(pa.int8()), pa.LargeListScalar,
      pa.LargeListValue),
+    ([1, 2, 3, 4, 5], pa.list_(pa.int8(), 5), pa.FixedSizeListScalar,
+     pa.FixedSizeListValue),
     (datetime.date.today(), None, pa.Date32Scalar, pa.Date32Value),
+    (datetime.date.today(), pa.date64(), pa.Date64Scalar, pa.Date64Value),
     (datetime.datetime.now(), None, pa.TimestampScalar, pa.TimestampValue),
+    (datetime.datetime.now().time().replace(microsecond=0), pa.time32('s'),
+     pa.Time32Scalar, pa.Time32Value),
     (datetime.datetime.now().time(), None, pa.Time64Scalar, pa.Time64Value),
     (datetime.timedelta(days=1), None, pa.DurationScalar, pa.DurationValue),
-    ({'a': 1, 'b': [1, 2]}, None, pa.StructScalar, pa.StructValue)
+    ({'a': 1, 'b': [1, 2]}, None, pa.StructScalar, pa.StructValue),
+    ([('a', 1), ('b', 2)], pa.map_(pa.string(), pa.int8()), pa.MapScalar,
+     pa.MapValue),
 ])
 def test_basics(value, ty, klass, deprecated):
     s = pa.scalar(value, type=ty)
@@ -69,6 +83,11 @@ def test_basics(value, ty, klass, deprecated):
     assert s.as_py() is None
     assert s != pa.scalar(value, type=ty)
 
+    # test pickle roundtrip
+    restored = pickle.loads(pickle.dumps(s))
+    assert s.equals(restored)
+
+    # test that scalars are weak-referenceable
     wr = weakref.ref(s)
     assert wr() is not None
     del s
@@ -95,6 +114,11 @@ def test_nulls():
         assert v is pa.NA
         assert v.as_py() is None
 
+    # test pickle roundtrip
+    restored = pickle.loads(pickle.dumps(null))
+    assert restored.equals(null)
+
+    # test that scalars are weak-referenceable
     wr = weakref.ref(null)
     assert wr() is not None
     del null
@@ -502,6 +526,9 @@ def test_map():
     with pytest.raises(IndexError):
         s[2]
 
+    restored = pickle.loads(pickle.dumps(s))
+    assert restored.equals(s)
+
 
 def test_dictionary():
     indices = [2, 1, 2, 0]
@@ -522,3 +549,47 @@ def test_dictionary():
             assert s.index_value.as_py() == i
         with pytest.warns(FutureWarning):
             assert s.dictionary_value.as_py() == v
+
+    with pytest.raises(pa.ArrowNotImplementedError):
+        pickle.loads(pickle.dumps(s))
+
+
+def test_union():
+    # sparse
+    arr = pa.UnionArray.from_sparse(
+        pa.array([0, 0, 1, 1], type=pa.int8()),
+        [
+            pa.array(["a", "b", "c", "d"]),
+            pa.array([1, 2, 3, 4])
+        ]
+    )
+    for s in arr:
+        assert isinstance(s, pa.UnionScalar)
+        assert s.type.equals(arr.type)
+        assert s.is_valid is True
+        with pytest.raises(pa.ArrowNotImplementedError):
+            pickle.loads(pickle.dumps(s))
+
+    assert arr[0].as_py() == "a"
+    assert arr[1].as_py() == "b"
+    assert arr[2].as_py() == 3
+    assert arr[3].as_py() == 4
+
+    # dense
+    arr = pa.UnionArray.from_dense(
+        types=pa.array([0, 1, 0, 0, 1, 1, 0], type='int8'),
+        value_offsets=pa.array([0, 0, 2, 1, 1, 2, 3], type='int32'),
+        children=[
+            pa.array([b'a', b'b', b'c', b'd'], type='binary'),
+            pa.array([1, 2, 3], type='int64')
+        ]
+    )
+    for s in arr:
+        assert isinstance(s, pa.UnionScalar)
+        assert s.type.equals(arr.type)
+        assert s.is_valid is True
+        with pytest.raises(pa.ArrowNotImplementedError):
+            pickle.loads(pickle.dumps(s))
+
+    assert arr[0].as_py() == b'a'
+    assert arr[5].as_py() == 3
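
For reference only (not part of the commit): a minimal sketch of what the new
`__reduce__` enables from user code, assuming a pyarrow build that contains
this patch; the `dict_scalar` variable name below is purely illustrative.

    import pickle

    import pyarrow as pa

    # Plain scalars now round-trip through pickle: __reduce__ rebuilds them
    # via pa.scalar(value, type).
    s = pa.scalar(42, type=pa.int64())
    restored = pickle.loads(pickle.dumps(s))
    assert restored.equals(s)
    assert restored.as_py() == 42

    # Dictionary (and union) scalars are left as NotImplemented: there is no
    # sequence converter to reconstruct them, so the round-trip raises
    # pa.ArrowNotImplementedError (as exercised by the tests above).
    dict_scalar = pa.array(["a", "b", "a"]).dictionary_encode()[0]
    try:
        pickle.loads(pickle.dumps(dict_scalar))
    except pa.ArrowNotImplementedError:
        print("dictionary scalars cannot be pickled yet")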