This is an automated email from the ASF dual-hosted git repository. wesm pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push: new eb5dd50 ARROW-840: [Python] Expose extension types eb5dd50 is described below commit eb5dd508ee3f592bf1c2a04cce09ee95e137e89b Author: Antoine Pitrou <anto...@python.org> AuthorDate: Fri Jun 14 07:53:40 2019 -0500 ARROW-840: [Python] Expose extension types Add infrastructure to consume C++ extension types and extension arrays from Python. Also allow creating Python-specific extension types by subclassing `ExtensionType`, and creating extension arrays by passing the type and storage array to `ExtensionArray.from_storage`. Author: Antoine Pitrou <anto...@python.org> Closes #4532 from pitrou/ARROW-840-py-ext-types and squashes the following commits: 95ca6148e <Antoine Pitrou> Add IPC tests 44ac0a156 <Antoine Pitrou> ARROW-840: Expose extension types --- cpp/src/arrow/array.cc | 11 +- cpp/src/arrow/extension_type.cc | 18 +++ cpp/src/arrow/extension_type.h | 9 +- cpp/src/arrow/python/CMakeLists.txt | 1 + cpp/src/arrow/python/extension_type.cc | 196 +++++++++++++++++++++++++ cpp/src/arrow/python/extension_type.h | 77 ++++++++++ cpp/src/arrow/python/pyarrow.h | 1 + python/pyarrow/__init__.py | 4 +- python/pyarrow/array.pxi | 42 +++++- python/pyarrow/includes/libarrow.pxd | 32 ++++ python/pyarrow/lib.pxd | 20 ++- python/pyarrow/public-api.pxi | 12 +- python/pyarrow/tests/test_extension_type.py | 219 ++++++++++++++++++++++++++++ python/pyarrow/types.pxi | 150 +++++++++++++++++-- 14 files changed, 775 insertions(+), 17 deletions(-) diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc index 7a3d36e..9d37b45 100644 --- a/cpp/src/arrow/array.cc +++ b/cpp/src/arrow/array.cc @@ -1259,7 +1259,16 @@ struct ValidateVisitor { return Status::OK(); } - Status Visit(const ExtensionArray& array) { return ValidateArray(*array.storage()); } + Status Visit(const ExtensionArray& array) { + const auto& ext_type = checked_cast<const ExtensionType&>(*array.type()); + + if (!array.storage()->type()->Equals(*ext_type.storage_type())) { + return Status::Invalid("Extension array of type '", array.type()->ToString(), + "' has storage array of incompatible type '", + array.storage()->type()->ToString(), "'"); + } + return ValidateArray(*array.storage()); + } protected: template <typename ArrayType> diff --git a/cpp/src/arrow/extension_type.cc b/cpp/src/arrow/extension_type.cc index e104c03..25945f3 100644 --- a/cpp/src/arrow/extension_type.cc +++ b/cpp/src/arrow/extension_type.cc @@ -27,10 +27,14 @@ #include "arrow/array.h" #include "arrow/status.h" #include "arrow/type.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/logging.h" #include "arrow/util/visibility.h" namespace arrow { +using internal::checked_cast; + DataTypeLayout ExtensionType::layout() const { return storage_type_->layout(); } std::string ExtensionType::ToString() const { @@ -41,7 +45,21 @@ std::string ExtensionType::ToString() const { std::string ExtensionType::name() const { return "extension"; } +ExtensionArray::ExtensionArray(const std::shared_ptr<ArrayData>& data) { SetData(data); } + +ExtensionArray::ExtensionArray(const std::shared_ptr<DataType>& type, + const std::shared_ptr<Array>& storage) { + DCHECK_EQ(type->id(), Type::EXTENSION); + DCHECK( + storage->type()->Equals(*checked_cast<const ExtensionType&>(*type).storage_type())); + auto data = storage->data()->Copy(); + // XXX This pointer is reverted below in SetData()... + data->type = type; + SetData(data); +} + void ExtensionArray::SetData(const std::shared_ptr<ArrayData>& data) { + DCHECK_EQ(data->type->id(), Type::EXTENSION); this->Array::SetData(data); auto storage_data = data->Copy(); diff --git a/cpp/src/arrow/extension_type.h b/cpp/src/arrow/extension_type.h index b3df2b3..6a1ca0b 100644 --- a/cpp/src/arrow/extension_type.h +++ b/cpp/src/arrow/extension_type.h @@ -84,7 +84,14 @@ class ARROW_EXPORT ExtensionType : public DataType { /// \brief Base array class for user-defined extension types class ARROW_EXPORT ExtensionArray : public Array { public: - explicit ExtensionArray(const std::shared_ptr<ArrayData>& data) { SetData(data); } + /// \brief Construct an ExtensionArray from an ArrayData. + /// + /// The ArrayData must have the right ExtensionType. + explicit ExtensionArray(const std::shared_ptr<ArrayData>& data); + + /// \brief Construct an ExtensionArray from a type and the underlying storage. + ExtensionArray(const std::shared_ptr<DataType>& type, + const std::shared_ptr<Array>& storage); /// \brief The physical storage for the extension array std::shared_ptr<Array> storage() const { return storage_; } diff --git a/cpp/src/arrow/python/CMakeLists.txt b/cpp/src/arrow/python/CMakeLists.txt index d6376f5..0d17a9f 100644 --- a/cpp/src/arrow/python/CMakeLists.txt +++ b/cpp/src/arrow/python/CMakeLists.txt @@ -34,6 +34,7 @@ set(ARROW_PYTHON_SRCS config.cc decimal.cc deserialize.cc + extension_type.cc helpers.cc inference.cc init.cc diff --git a/cpp/src/arrow/python/extension_type.cc b/cpp/src/arrow/python/extension_type.cc new file mode 100644 index 0000000..b130030 --- /dev/null +++ b/cpp/src/arrow/python/extension_type.cc @@ -0,0 +1,196 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <memory> + +#include "arrow/python/extension_type.h" +#include "arrow/python/helpers.h" +#include "arrow/python/pyarrow.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/logging.h" + +namespace arrow { + +using internal::checked_cast; + +namespace py { + +namespace { + +// Serialize a Python ExtensionType instance +Status SerializeExtInstance(PyObject* type_instance, std::string* out) { + OwnedRef res(PyObject_CallMethod(type_instance, "__arrow_ext_serialize__", nullptr)); + if (!res) { + return ConvertPyError(); + } + if (!PyBytes_Check(res.obj())) { + return Status::TypeError( + "__arrow_ext_serialize__ should return bytes object, " + "got ", + internal::PyObject_StdStringRepr(res.obj())); + } + *out = internal::PyBytes_AsStdString(res.obj()); + return Status::OK(); +} + +// Deserialize a Python ExtensionType instance +PyObject* DeserializeExtInstance(PyObject* type_class, + std::shared_ptr<DataType> storage_type, + const std::string& serialized_data) { + OwnedRef storage_ref(wrap_data_type(storage_type)); + if (!storage_ref) { + return nullptr; + } + OwnedRef data_ref(PyBytes_FromStringAndSize( + serialized_data.data(), static_cast<Py_ssize_t>(serialized_data.size()))); + if (!data_ref) { + return nullptr; + } + + return PyObject_CallMethod(type_class, "__arrow_ext_deserialize__", "OO", + storage_ref.obj(), data_ref.obj()); +} + +} // namespace + +static const char* kExtensionName = "arrow.py_extension_type"; + +PyExtensionType::PyExtensionType(std::shared_ptr<DataType> storage_type, PyObject* typ, + PyObject* inst) + : ExtensionType(storage_type), type_class_(typ), type_instance_(inst) {} + +std::string PyExtensionType::extension_name() const { return kExtensionName; } + +bool PyExtensionType::ExtensionEquals(const ExtensionType& other) const { + PyAcquireGIL lock; + + if (other.extension_name() != extension_name()) { + return false; + } + const auto& other_ext = checked_cast<const PyExtensionType&>(other); + int res = -1; + if (!type_instance_) { + if (other_ext.type_instance_) { + return false; + } + // Compare Python types + res = PyObject_RichCompareBool(type_class_.obj(), other_ext.type_class_.obj(), Py_EQ); + } else { + if (!other_ext.type_instance_) { + return false; + } + // Compare Python instances + OwnedRef left(GetInstance()); + OwnedRef right(other_ext.GetInstance()); + if (!left || !right) { + goto error; + } + res = PyObject_RichCompareBool(left.obj(), right.obj(), Py_EQ); + } + if (res == -1) { + goto error; + } + return res == 1; + +error: + // Cannot propagate error + PyErr_WriteUnraisable(nullptr); + return false; +} + +std::shared_ptr<Array> PyExtensionType::MakeArray(std::shared_ptr<ArrayData> data) const { + DCHECK_EQ(data->type->id(), Type::EXTENSION); + DCHECK_EQ(kExtensionName, + checked_cast<const ExtensionType&>(*data->type).extension_name()); + return std::make_shared<ExtensionArray>(data); +} + +std::string PyExtensionType::Serialize() const { + DCHECK(type_instance_); + return serialized_; +} + +Status PyExtensionType::Deserialize(std::shared_ptr<DataType> storage_type, + const std::string& serialized_data, + std::shared_ptr<DataType>* out) const { + PyAcquireGIL lock; + + if (import_pyarrow()) { + return ConvertPyError(); + } + OwnedRef res(DeserializeExtInstance(type_class_.obj(), storage_type, serialized_data)); + if (!res) { + return ConvertPyError(); + } + return unwrap_data_type(res.obj(), out); +} + +PyObject* PyExtensionType::GetInstance() const { + if (!type_instance_) { + PyErr_SetString(PyExc_TypeError, "Not an instance"); + return nullptr; + } + DCHECK(PyWeakref_CheckRef(type_instance_.obj())); + PyObject* inst = PyWeakref_GET_OBJECT(type_instance_.obj()); + if (inst != Py_None) { + // Cached instance still alive + Py_INCREF(inst); + return inst; + } else { + // Must reconstruct from serialized form + // XXX cache again? + return DeserializeExtInstance(type_class_.obj(), storage_type_, serialized_); + } +} + +Status PyExtensionType::SetInstance(PyObject* inst) const { + // Check we have the right type + PyObject* typ = reinterpret_cast<PyObject*>(Py_TYPE(inst)); + if (typ != type_class_.obj()) { + return Status::TypeError("Unexpected Python ExtensionType class ", + internal::PyObject_StdStringRepr(typ), " expected ", + internal::PyObject_StdStringRepr(type_class_.obj())); + } + + PyObject* wr = PyWeakref_NewRef(inst, nullptr); + if (wr == NULL) { + return ConvertPyError(); + } + type_instance_.reset(wr); + return SerializeExtInstance(inst, &serialized_); +} + +Status PyExtensionType::FromClass(std::shared_ptr<DataType> storage_type, PyObject* typ, + std::shared_ptr<ExtensionType>* out) { + Py_INCREF(typ); + out->reset(new PyExtensionType(storage_type, typ)); + return Status::OK(); +} + +Status RegisterPyExtensionType(const std::shared_ptr<DataType>& type) { + DCHECK_EQ(type->id(), Type::EXTENSION); + auto ext_type = std::dynamic_pointer_cast<ExtensionType>(type); + DCHECK_EQ(ext_type->extension_name(), kExtensionName); + return RegisterExtensionType(ext_type); +} + +Status UnregisterPyExtensionType() { return UnregisterExtensionType(kExtensionName); } + +std::string PyExtensionName() { return kExtensionName; } + +} // namespace py +} // namespace arrow diff --git a/cpp/src/arrow/python/extension_type.h b/cpp/src/arrow/python/extension_type.h new file mode 100644 index 0000000..12f9108 --- /dev/null +++ b/cpp/src/arrow/python/extension_type.h @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include <memory> +#include <string> + +#include "arrow/extension_type.h" +#include "arrow/python/common.h" +#include "arrow/python/visibility.h" +#include "arrow/util/macros.h" + +namespace arrow { +namespace py { + +class ARROW_PYTHON_EXPORT PyExtensionType : public ExtensionType { + public: + // Implement extensionType API + std::string extension_name() const override; + + bool ExtensionEquals(const ExtensionType& other) const override; + + std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override; + + Status Deserialize(std::shared_ptr<DataType> storage_type, + const std::string& serialized_data, + std::shared_ptr<DataType>* out) const override; + + std::string Serialize() const override; + + // For use from Cython + static Status FromClass(std::shared_ptr<DataType> storage_type, PyObject* typ, + std::shared_ptr<ExtensionType>* out); + + // Return new ref + PyObject* GetInstance() const; + Status SetInstance(PyObject*) const; + + protected: + PyExtensionType(std::shared_ptr<DataType> storage_type, PyObject* typ, + PyObject* inst = NULLPTR); + + // These fields are mutable because of two-step initialization. + mutable OwnedRefNoGIL type_class_; + // A weakref or null. Storing a strong reference to the Python extension type + // instance would create an unreclaimable reference cycle between Python and C++ + // (the Python instance has to keep a strong reference to the C++ ExtensionType + // in other direction). Instead, we store a weakref to the instance. + // If the weakref is dead, we reconstruct the instance from its serialized form. + mutable OwnedRefNoGIL type_instance_; + // Empty if type_instance_ is null + mutable std::string serialized_; +}; + +ARROW_PYTHON_EXPORT std::string PyExtensionName(); + +ARROW_PYTHON_EXPORT Status RegisterPyExtensionType(const std::shared_ptr<DataType>&); + +ARROW_PYTHON_EXPORT Status UnregisterPyExtensionType(); + +} // namespace py +} // namespace arrow diff --git a/cpp/src/arrow/python/pyarrow.h b/cpp/src/arrow/python/pyarrow.h index a5a3910..5e42333 100644 --- a/cpp/src/arrow/python/pyarrow.h +++ b/cpp/src/arrow/python/pyarrow.h @@ -39,6 +39,7 @@ class Tensor; namespace py { +// Returns 0 on success, -1 on error. ARROW_PYTHON_EXPORT int import_pyarrow(); ARROW_PYTHON_EXPORT bool is_buffer(PyObject* buffer); diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index f9ba819..556b87d 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -58,6 +58,8 @@ from pyarrow.lib import (null, bool_, DataType, DictionaryType, ListType, StructType, UnionType, TimestampType, Time32Type, Time64Type, FixedSizeBinaryType, Decimal128Type, + BaseExtensionType, ExtensionType, + UnknownExtensionType, DictionaryMemo, Field, Schema, @@ -78,7 +80,7 @@ from pyarrow.lib import (null, bool_, DictionaryArray, Date32Array, Date64Array, TimestampArray, Time32Array, Time64Array, - Decimal128Array, StructArray, + Decimal128Array, StructArray, ExtensionArray, ArrayValue, Scalar, NA, _NULL as NULL, BooleanValue, Int8Value, Int16Value, Int32Value, Int64Value, diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index cce967e..607d7ae 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -415,7 +415,7 @@ cdef class Array(_PandasConvertible): "the `pyarrow.Array.from_*` functions instead." .format(self.__class__.__name__)) - cdef void init(self, const shared_ptr[CArray]& sp_array): + cdef void init(self, const shared_ptr[CArray]& sp_array) except *: self.sp_array = sp_array self.ap = sp_array.get() self.type = pyarrow_wrap_data_type(self.sp_array.get().type()) @@ -1458,6 +1458,45 @@ cdef class StructArray(Array): return pyarrow_wrap_array(c_result) +cdef class ExtensionArray(Array): + """ + Concrete class for Arrow extension arrays. + """ + + @property + def storage(self): + cdef: + CExtensionArray* ext_array = <CExtensionArray*>(self.ap) + + return pyarrow_wrap_array(ext_array.storage()) + + @staticmethod + def from_storage(BaseExtensionType typ, Array storage): + """ + Construct ExtensionArray from type and storage array. + + Parameters + ---------- + typ: DataType + The extension type for the result array. + storage: Array + The underlying storage for the result array. + + Returns + ------- + ext_array : ExtensionArray + """ + cdef: + shared_ptr[CExtensionArray] ext_array + + if storage.type != typ.storage_type: + raise TypeError("Incompatible storage type {0} " + "for extension type {1}".format(storage.type, typ)) + + ext_array = make_shared[CExtensionArray](typ.sp_type, storage.sp_array) + return pyarrow_wrap_array(<shared_ptr[CArray]> ext_array) + + cdef dict _array_classes = { _Type_NA: NullArray, _Type_BOOL: BooleanArray, @@ -1485,6 +1524,7 @@ cdef dict _array_classes = { _Type_FIXED_SIZE_BINARY: FixedSizeBinaryArray, _Type_DECIMAL: Decimal128Array, _Type_STRUCT: StructArray, + _Type_EXTENSION: ExtensionArray, } diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index f979cd6..178a250 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -73,6 +73,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: _Type_DICTIONARY" arrow::Type::DICTIONARY" _Type_MAP" arrow::Type::MAP" + _Type_EXTENSION" arrow::Type::EXTENSION" + enum UnionMode" arrow::UnionMode::type": _UnionMode_SPARSE" arrow::UnionMode::SPARSE" _UnionMode_DENSE" arrow::UnionMode::DENSE" @@ -1272,6 +1274,36 @@ cdef extern from 'arrow/python/inference.h' namespace 'arrow::py': c_bool IsPyFloat(object o) +cdef extern from 'arrow/extension_type.h' namespace 'arrow': + cdef cppclass CExtensionType" arrow::ExtensionType"(CDataType): + c_string extension_name() + shared_ptr[CDataType] storage_type() + + cdef cppclass CExtensionArray" arrow::ExtensionArray"(CArray): + CExtensionArray(shared_ptr[CDataType], shared_ptr[CArray] storage) + + shared_ptr[CArray] storage() + + +cdef extern from 'arrow/python/extension_type.h' namespace 'arrow::py': + cdef cppclass CPyExtensionType \ + " arrow::py::PyExtensionType"(CExtensionType): + @staticmethod + CStatus FromClass(shared_ptr[CDataType] storage_type, + object typ, shared_ptr[CExtensionType]* out) + + @staticmethod + CStatus FromInstance(shared_ptr[CDataType] storage_type, + object inst, shared_ptr[CExtensionType]* out) + + object GetInstance() + CStatus SetInstance(object) + + c_string PyExtensionName() + CStatus RegisterPyExtensionType(shared_ptr[CDataType]) + CStatus UnregisterPyExtensionType() + + cdef extern from 'arrow/python/benchmark.h' namespace 'arrow::py::benchmark': void Benchmark_PandasObjectIsNull(object lst) except * diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 998848d..79ab947 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -53,8 +53,9 @@ cdef class DataType: shared_ptr[CDataType] sp_type CDataType* type bytes pep3118_format + object __weakref__ - cdef void init(self, const shared_ptr[CDataType]& type) + cdef void init(self, const shared_ptr[CDataType]& type) except * cdef Field child(self, int i) @@ -106,6 +107,16 @@ cdef class Decimal128Type(FixedSizeBinaryType): const CDecimal128Type* decimal128_type +cdef class BaseExtensionType(DataType): + cdef: + const CExtensionType* ext_type + + +cdef class ExtensionType(BaseExtensionType): + cdef: + const CPyExtensionType* cpy_ext_type + + cdef class Field: cdef: shared_ptr[CField] sp_field @@ -199,11 +210,12 @@ cdef class Array(_PandasConvertible): cdef: shared_ptr[CArray] sp_array CArray* ap + object __weakref__ cdef readonly: DataType type - cdef void init(self, const shared_ptr[CArray]& sp_array) + cdef void init(self, const shared_ptr[CArray]& sp_array) except * cdef getitem(self, int64_t i) cdef int64_t length(self) @@ -316,6 +328,10 @@ cdef class DictionaryArray(Array): object _indices, _dictionary +cdef class ExtensionArray(Array): + pass + + cdef wrap_array_output(PyObject* output) cdef object box_scalar(DataType type, const shared_ptr[CArray]& sp_array, diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi index 9392259..33bc803 100644 --- a/python/pyarrow/public-api.pxi +++ b/python/pyarrow/public-api.pxi @@ -66,7 +66,10 @@ cdef api shared_ptr[CDataType] pyarrow_unwrap_data_type( cdef api object pyarrow_wrap_data_type( const shared_ptr[CDataType]& type): - cdef DataType out + cdef: + const CExtensionType* ext_type + const CPyExtensionType* cpy_ext_type + DataType out if type.get() == NULL: return None @@ -85,6 +88,13 @@ cdef api object pyarrow_wrap_data_type( out = FixedSizeBinaryType.__new__(FixedSizeBinaryType) elif type.get().id() == _Type_DECIMAL: out = Decimal128Type.__new__(Decimal128Type) + elif type.get().id() == _Type_EXTENSION: + ext_type = <const CExtensionType*> type.get() + if ext_type.extension_name() == PyExtensionName(): + cpy_ext_type = <const CPyExtensionType*> ext_type + return cpy_ext_type.GetInstance() + else: + out = BaseExtensionType.__new__(BaseExtensionType) else: out = DataType.__new__(DataType) diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py new file mode 100644 index 0000000..d688d3c --- /dev/null +++ b/python/pyarrow/tests/test_extension_type.py @@ -0,0 +1,219 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pickle +import weakref + +import pyarrow as pa + +import pytest + + +class UuidType(pa.ExtensionType): + + def __init__(self): + pa.ExtensionType.__init__(self, pa.binary(16)) + + def __reduce__(self): + return UuidType, () + + +class ParamExtType(pa.ExtensionType): + + def __init__(self, width): + self.width = width + pa.ExtensionType.__init__(self, pa.binary(width)) + + def __reduce__(self): + return ParamExtType, (self.width,) + + +def ipc_write_batch(batch): + stream = pa.BufferOutputStream() + writer = pa.RecordBatchStreamWriter(stream, batch.schema) + writer.write_batch(batch) + writer.close() + return stream.getvalue() + + +def ipc_read_batch(buf): + reader = pa.RecordBatchStreamReader(buf) + return reader.read_next_batch() + + +def test_ext_type_basics(): + ty = UuidType() + assert ty.extension_name == "arrow.py_extension_type" + + +def test_ext_type__lifetime(): + ty = UuidType() + wr = weakref.ref(ty) + del ty + assert wr() is None + + +def test_ext_type__storage_type(): + ty = UuidType() + assert ty.storage_type == pa.binary(16) + assert ty.__class__ is UuidType + ty = ParamExtType(5) + assert ty.storage_type == pa.binary(5) + assert ty.__class__ is ParamExtType + + +def test_uuid_type_pickle(): + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + ty = UuidType() + ser = pickle.dumps(ty, protocol=proto) + del ty + ty = pickle.loads(ser) + wr = weakref.ref(ty) + assert ty.extension_name == "arrow.py_extension_type" + del ty + assert wr() is None + + +def test_ext_type_equality(): + a = ParamExtType(5) + b = ParamExtType(6) + c = ParamExtType(6) + assert a != b + assert b == c + d = UuidType() + e = UuidType() + assert a != d + assert d == e + + +def test_ext_array_basics(): + ty = ParamExtType(3) + storage = pa.array([b"foo", b"bar"], type=pa.binary(3)) + arr = pa.ExtensionArray.from_storage(ty, storage) + arr.validate() + assert arr.type is ty + assert arr.storage.equals(storage) + + +def test_ext_array_lifetime(): + ty = ParamExtType(3) + storage = pa.array([b"foo", b"bar"], type=pa.binary(3)) + arr = pa.ExtensionArray.from_storage(ty, storage) + + refs = [weakref.ref(obj) for obj in (ty, arr, storage)] + del ty, storage, arr + for ref in refs: + assert ref() is None + + +def test_ext_array_errors(): + ty = ParamExtType(4) + storage = pa.array([b"foo", b"bar"], type=pa.binary(3)) + with pytest.raises(TypeError, match="Incompatible storage type"): + pa.ExtensionArray.from_storage(ty, storage) + + +def test_ext_array_equality(): + storage1 = pa.array([b"0123456789abcdef"], type=pa.binary(16)) + storage2 = pa.array([b"0123456789abcdef"], type=pa.binary(16)) + storage3 = pa.array([], type=pa.binary(16)) + ty1 = UuidType() + ty2 = ParamExtType(16) + + a = pa.ExtensionArray.from_storage(ty1, storage1) + b = pa.ExtensionArray.from_storage(ty1, storage2) + assert a.equals(b) + c = pa.ExtensionArray.from_storage(ty1, storage3) + assert not a.equals(c) + d = pa.ExtensionArray.from_storage(ty2, storage1) + assert not a.equals(d) + e = pa.ExtensionArray.from_storage(ty2, storage2) + assert d.equals(e) + f = pa.ExtensionArray.from_storage(ty2, storage3) + assert not d.equals(f) + + +def test_ext_array_pickling(): + for proto in range(0, pickle.HIGHEST_PROTOCOL + 1): + ty = ParamExtType(3) + storage = pa.array([b"foo", b"bar"], type=pa.binary(3)) + arr = pa.ExtensionArray.from_storage(ty, storage) + ser = pickle.dumps(arr, protocol=proto) + del ty, storage, arr + arr = pickle.loads(ser) + arr.validate() + assert isinstance(arr, pa.ExtensionArray) + assert arr.type == ParamExtType(3) + assert arr.type.storage_type == pa.binary(3) + assert arr.storage.type == pa.binary(3) + assert arr.storage.to_pylist() == [b"foo", b"bar"] + + +def example_batch(): + ty = ParamExtType(3) + storage = pa.array([b"foo", b"bar"], type=pa.binary(3)) + arr = pa.ExtensionArray.from_storage(ty, storage) + return pa.RecordBatch.from_arrays([arr], ["exts"]) + + +def check_example_batch(batch): + arr = batch.column(0) + assert isinstance(arr, pa.ExtensionArray) + assert arr.type.storage_type == pa.binary(3) + assert arr.storage.to_pylist() == [b"foo", b"bar"] + return arr + + +def test_ipc(): + batch = example_batch() + buf = ipc_write_batch(batch) + del batch + + batch = ipc_read_batch(buf) + arr = check_example_batch(batch) + assert arr.type == ParamExtType(3) + + +def test_ipc_unknown_type(): + batch = example_batch() + buf = ipc_write_batch(batch) + del batch + + orig_type = ParamExtType + try: + # Simulate the original Python type being unavailable. + # Deserialization should not fail but return a placeholder type. + del globals()['ParamExtType'] + + batch = ipc_read_batch(buf) + arr = check_example_batch(batch) + assert isinstance(arr.type, pa.UnknownExtensionType) + + # Can be serialized again + buf2 = ipc_write_batch(batch) + del batch, arr + + batch = ipc_read_batch(buf2) + arr = check_example_batch(batch) + assert isinstance(arr.type, pa.UnknownExtensionType) + finally: + globals()['ParamExtType'] = orig_type + + # Deserialize again with the type restored + batch = ipc_read_batch(buf2) + arr = check_example_batch(batch) + assert arr.type == ParamExtType(3) diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 9a92761..1f0db4c 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -19,6 +19,7 @@ import re import warnings from pyarrow import compat +from pyarrow.compat import builtin_pickle # These are imprecise because the type (in pandas 0.x) depends on the presence @@ -103,7 +104,7 @@ cdef class DataType: "functions like pyarrow.int64, pyarrow.list_, etc. " "instead.".format(self.__class__.__name__)) - cdef void init(self, const shared_ptr[CDataType]& type): + cdef void init(self, const shared_ptr[CDataType]& type) except *: self.sp_type = type self.type = type.get() self.pep3118_format = _datatype_to_pep3118(self.type) @@ -203,7 +204,7 @@ cdef class DictionaryType(DataType): Concrete class for dictionary data types. """ - cdef void init(self, const shared_ptr[CDataType]& type): + cdef void init(self, const shared_ptr[CDataType]& type) except *: DataType.init(self, type) self.dict_type = <const CDictionaryType*> type.get() @@ -239,7 +240,7 @@ cdef class ListType(DataType): Concrete class for list data types. """ - cdef void init(self, const shared_ptr[CDataType]& type): + cdef void init(self, const shared_ptr[CDataType]& type) except *: DataType.init(self, type) self.list_type = <const CListType*> type.get() @@ -259,7 +260,7 @@ cdef class StructType(DataType): Concrete class for struct data types. """ - cdef void init(self, const shared_ptr[CDataType]& type): + cdef void init(self, const shared_ptr[CDataType]& type) except *: DataType.init(self, type) self.struct_type = <const CStructType*> type.get() @@ -318,7 +319,7 @@ cdef class UnionType(DataType): Concrete class for struct data types. """ - cdef void init(self, const shared_ptr[CDataType]& type): + cdef void init(self, const shared_ptr[CDataType]& type) except *: DataType.init(self, type) @property @@ -370,7 +371,7 @@ cdef class TimestampType(DataType): Concrete class for timestamp data types. """ - cdef void init(self, const shared_ptr[CDataType]& type): + cdef void init(self, const shared_ptr[CDataType]& type) except *: DataType.init(self, type) self.ts_type = <const CTimestampType*> type.get() @@ -411,7 +412,7 @@ cdef class Time32Type(DataType): Concrete class for time32 data types. """ - cdef void init(self, const shared_ptr[CDataType]& type): + cdef void init(self, const shared_ptr[CDataType]& type) except *: DataType.init(self, type) self.time_type = <const CTime32Type*> type.get() @@ -428,7 +429,7 @@ cdef class Time64Type(DataType): Concrete class for time64 data types. """ - cdef void init(self, const shared_ptr[CDataType]& type): + cdef void init(self, const shared_ptr[CDataType]& type) except *: DataType.init(self, type) self.time_type = <const CTime64Type*> type.get() @@ -445,7 +446,7 @@ cdef class FixedSizeBinaryType(DataType): Concrete class for fixed-size binary data types. """ - cdef void init(self, const shared_ptr[CDataType]& type): + cdef void init(self, const shared_ptr[CDataType]& type) except *: DataType.init(self, type) self.fixed_size_binary_type = ( <const CFixedSizeBinaryType*> type.get()) @@ -466,7 +467,7 @@ cdef class Decimal128Type(FixedSizeBinaryType): Concrete class for decimal128 data types. """ - cdef void init(self, const shared_ptr[CDataType]& type): + cdef void init(self, const shared_ptr[CDataType]& type) except *: FixedSizeBinaryType.init(self, type) self.decimal128_type = <const CDecimal128Type*> type.get() @@ -488,6 +489,132 @@ cdef class Decimal128Type(FixedSizeBinaryType): return self.decimal128_type.scale() +cdef class BaseExtensionType(DataType): + """ + Concrete base class for extension types. + """ + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + DataType.init(self, type) + self.ext_type = <const CExtensionType*> type.get() + + @property + def extension_name(self): + """ + The extension type name. + """ + return frombytes(self.ext_type.extension_name()) + + @property + def storage_type(self): + """ + The underlying storage type. + """ + return pyarrow_wrap_data_type(self.ext_type.storage_type()) + + +cdef class ExtensionType(BaseExtensionType): + """ + Concrete base class for Python-defined extension types. + """ + + def __cinit__(self): + if type(self) is ExtensionType: + raise TypeError("Can only instantiate subclasses of " + "ExtensionType") + + def __init__(self, DataType storage_type): + cdef: + shared_ptr[CExtensionType] cpy_ext_type + + assert storage_type is not None + check_status(CPyExtensionType.FromClass(storage_type.sp_type, + type(self), &cpy_ext_type)) + self.init(<shared_ptr[CDataType]> cpy_ext_type) + + cdef void init(self, const shared_ptr[CDataType]& type) except *: + BaseExtensionType.init(self, type) + self.cpy_ext_type = <const CPyExtensionType*> type.get() + # Store weakref and serialized version of self on C++ type instance + check_status(self.cpy_ext_type.SetInstance(self)) + + def __eq__(self, other): + # Default implementation to avoid infinite recursion through + # DataType.__eq__ -> ExtensionType::ExtensionEquals -> DataType.__eq__ + if isinstance(other, ExtensionType): + return (type(self) == type(other) and + self.storage_type == other.storage_type) + else: + return NotImplemented + + def __reduce__(self): + raise NotImplementedError("Please implement {0}.__reduce__" + .format(type(self).__name__)) + + def __arrow_ext_serialize__(self): + return builtin_pickle.dumps(self) + + @classmethod + def __arrow_ext_deserialize__(cls, storage_type, serialized): + try: + ty = builtin_pickle.loads(serialized) + except Exception: + # For some reason, it's impossible to deserialize the + # ExtensionType instance. Perhaps the serialized data is + # corrupt, or more likely the type is being deserialized + # in an environment where the original Python class or module + # is not available. Fall back on a generic BaseExtensionType. + return UnknownExtensionType(storage_type, serialized) + + if ty.storage_type != storage_type: + raise TypeError("Expected storage type {0} but got {1}" + .format(ty.storage_type, storage_type)) + return ty + + +cdef class UnknownExtensionType(ExtensionType): + """ + A concrete class for Python-defined extension types that refer to + an unknown Python implementation. + """ + + cdef: + bytes serialized + + def __init__(self, DataType storage_type, serialized): + self.serialized = serialized + ExtensionType.__init__(self, storage_type) + + def __arrow_ext_serialize__(self): + return self.serialized + + +cdef class _ExtensionTypesInitializer: + # + # A private object that handles process-wide registration of the Python + # ExtensionType. + # + + def __cinit__(self): + cdef: + DataType storage_type + shared_ptr[CExtensionType] cpy_ext_type + + # Make a dummy C++ ExtensionType + storage_type = null() + check_status(CPyExtensionType.FromClass(storage_type.sp_type, + ExtensionType, &cpy_ext_type)) + check_status( + RegisterPyExtensionType(<shared_ptr[CDataType]> cpy_ext_type)) + + def __dealloc__(self): + # This needs to be done explicitly before the Python interpreter is + # finalized. If the C++ type is destroyed later in the process + # teardown stage, it will invoke CPython APIs such as Py_DECREF + # with a destroyed interpreter. + check_status(UnregisterPyExtensionType()) + + cdef class Field: """ A named field, with a data type, nullability, and optional metadata. @@ -1726,3 +1853,6 @@ def is_integer_value(object obj): def is_float_value(object obj): return IsPyFloat(obj) + + +_extension_types_initializer = _ExtensionTypesInitializer()