This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new eb5dd50  ARROW-840: [Python] Expose extension types
eb5dd50 is described below

commit eb5dd508ee3f592bf1c2a04cce09ee95e137e89b
Author: Antoine Pitrou <anto...@python.org>
AuthorDate: Fri Jun 14 07:53:40 2019 -0500

    ARROW-840: [Python] Expose extension types
    
    Add infrastructure to consume C++ extension types and extension arrays from Python.
    
    Also allow creating Python-specific extension types by subclassing `ExtensionType`, and creating extension arrays by passing the type and storage array to `ExtensionArray.from_storage`.
    
    Author: Antoine Pitrou <anto...@python.org>
    
    Closes #4532 from pitrou/ARROW-840-py-ext-types and squashes the following commits:
    
    95ca6148e <Antoine Pitrou> Add IPC tests
    44ac0a156 <Antoine Pitrou> ARROW-840:  Expose extension types
---
 cpp/src/arrow/array.cc                      |  11 +-
 cpp/src/arrow/extension_type.cc             |  18 +++
 cpp/src/arrow/extension_type.h              |   9 +-
 cpp/src/arrow/python/CMakeLists.txt         |   1 +
 cpp/src/arrow/python/extension_type.cc      | 196 +++++++++++++++++++++++++
 cpp/src/arrow/python/extension_type.h       |  77 ++++++++++
 cpp/src/arrow/python/pyarrow.h              |   1 +
 python/pyarrow/__init__.py                  |   4 +-
 python/pyarrow/array.pxi                    |  42 +++++-
 python/pyarrow/includes/libarrow.pxd        |  32 ++++
 python/pyarrow/lib.pxd                      |  20 ++-
 python/pyarrow/public-api.pxi               |  12 +-
 python/pyarrow/tests/test_extension_type.py | 219 ++++++++++++++++++++++++++++
 python/pyarrow/types.pxi                    | 150 +++++++++++++++++--
 14 files changed, 775 insertions(+), 17 deletions(-)

diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc
index 7a3d36e..9d37b45 100644
--- a/cpp/src/arrow/array.cc
+++ b/cpp/src/arrow/array.cc
@@ -1259,7 +1259,16 @@ struct ValidateVisitor {
     return Status::OK();
   }
 
-  Status Visit(const ExtensionArray& array) { return ValidateArray(*array.storage()); }
+  Status Visit(const ExtensionArray& array) {
+    const auto& ext_type = checked_cast<const ExtensionType&>(*array.type());
+
+    if (!array.storage()->type()->Equals(*ext_type.storage_type())) {
+      return Status::Invalid("Extension array of type '", array.type()->ToString(),
+                             "' has storage array of incompatible type '",
+                             array.storage()->type()->ToString(), "'");
+    }
+    return ValidateArray(*array.storage());
+  }
 
  protected:
   template <typename ArrayType>
diff --git a/cpp/src/arrow/extension_type.cc b/cpp/src/arrow/extension_type.cc
index e104c03..25945f3 100644
--- a/cpp/src/arrow/extension_type.cc
+++ b/cpp/src/arrow/extension_type.cc
@@ -27,10 +27,14 @@
 #include "arrow/array.h"
 #include "arrow/status.h"
 #include "arrow/type.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
 
+using internal::checked_cast;
+
 DataTypeLayout ExtensionType::layout() const { return storage_type_->layout(); }
 
 std::string ExtensionType::ToString() const {
@@ -41,7 +45,21 @@ std::string ExtensionType::ToString() const {
 
 std::string ExtensionType::name() const { return "extension"; }
 
+ExtensionArray::ExtensionArray(const std::shared_ptr<ArrayData>& data) { SetData(data); }
+
+ExtensionArray::ExtensionArray(const std::shared_ptr<DataType>& type,
+                               const std::shared_ptr<Array>& storage) {
+  DCHECK_EQ(type->id(), Type::EXTENSION);
+  DCHECK(
+      storage->type()->Equals(*checked_cast<const ExtensionType&>(*type).storage_type()));
+  auto data = storage->data()->Copy();
+  // XXX This pointer is reverted below in SetData()...
+  data->type = type;
+  SetData(data);
+}
+
 void ExtensionArray::SetData(const std::shared_ptr<ArrayData>& data) {
+  DCHECK_EQ(data->type->id(), Type::EXTENSION);
   this->Array::SetData(data);
 
   auto storage_data = data->Copy();
diff --git a/cpp/src/arrow/extension_type.h b/cpp/src/arrow/extension_type.h
index b3df2b3..6a1ca0b 100644
--- a/cpp/src/arrow/extension_type.h
+++ b/cpp/src/arrow/extension_type.h
@@ -84,7 +84,14 @@ class ARROW_EXPORT ExtensionType : public DataType {
 /// \brief Base array class for user-defined extension types
 class ARROW_EXPORT ExtensionArray : public Array {
  public:
-  explicit ExtensionArray(const std::shared_ptr<ArrayData>& data) { SetData(data); }
+  /// \brief Construct an ExtensionArray from an ArrayData.
+  ///
+  /// The ArrayData must have the right ExtensionType.
+  explicit ExtensionArray(const std::shared_ptr<ArrayData>& data);
+
+  /// \brief Construct an ExtensionArray from a type and the underlying storage.
+  ExtensionArray(const std::shared_ptr<DataType>& type,
+                 const std::shared_ptr<Array>& storage);
 
   /// \brief The physical storage for the extension array
   std::shared_ptr<Array> storage() const { return storage_; }
diff --git a/cpp/src/arrow/python/CMakeLists.txt b/cpp/src/arrow/python/CMakeLists.txt
index d6376f5..0d17a9f 100644
--- a/cpp/src/arrow/python/CMakeLists.txt
+++ b/cpp/src/arrow/python/CMakeLists.txt
@@ -34,6 +34,7 @@ set(ARROW_PYTHON_SRCS
     config.cc
     decimal.cc
     deserialize.cc
+    extension_type.cc
     helpers.cc
     inference.cc
     init.cc
diff --git a/cpp/src/arrow/python/extension_type.cc b/cpp/src/arrow/python/extension_type.cc
new file mode 100644
index 0000000..b130030
--- /dev/null
+++ b/cpp/src/arrow/python/extension_type.cc
@@ -0,0 +1,196 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <memory>
+
+#include "arrow/python/extension_type.h"
+#include "arrow/python/helpers.h"
+#include "arrow/python/pyarrow.h"
+#include "arrow/util/checked_cast.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+
+using internal::checked_cast;
+
+namespace py {
+
+namespace {
+
+// Serialize a Python ExtensionType instance
+Status SerializeExtInstance(PyObject* type_instance, std::string* out) {
+  OwnedRef res(PyObject_CallMethod(type_instance, "__arrow_ext_serialize__", nullptr));
+  if (!res) {
+    return ConvertPyError();
+  }
+  if (!PyBytes_Check(res.obj())) {
+    return Status::TypeError(
+        "__arrow_ext_serialize__ should return bytes object, "
+        "got ",
+        internal::PyObject_StdStringRepr(res.obj()));
+  }
+  *out = internal::PyBytes_AsStdString(res.obj());
+  return Status::OK();
+}
+
+// Deserialize a Python ExtensionType instance
+PyObject* DeserializeExtInstance(PyObject* type_class,
+                                 std::shared_ptr<DataType> storage_type,
+                                 const std::string& serialized_data) {
+  OwnedRef storage_ref(wrap_data_type(storage_type));
+  if (!storage_ref) {
+    return nullptr;
+  }
+  OwnedRef data_ref(PyBytes_FromStringAndSize(
+      serialized_data.data(), static_cast<Py_ssize_t>(serialized_data.size())));
+  if (!data_ref) {
+    return nullptr;
+  }
+
+  return PyObject_CallMethod(type_class, "__arrow_ext_deserialize__", "OO",
+                             storage_ref.obj(), data_ref.obj());
+}
+
+}  // namespace
+
+static const char* kExtensionName = "arrow.py_extension_type";
+
+PyExtensionType::PyExtensionType(std::shared_ptr<DataType> storage_type, PyObject* typ,
+                                 PyObject* inst)
+    : ExtensionType(storage_type), type_class_(typ), type_instance_(inst) {}
+
+std::string PyExtensionType::extension_name() const { return kExtensionName; }
+
+bool PyExtensionType::ExtensionEquals(const ExtensionType& other) const {
+  PyAcquireGIL lock;
+
+  if (other.extension_name() != extension_name()) {
+    return false;
+  }
+  const auto& other_ext = checked_cast<const PyExtensionType&>(other);
+  int res = -1;
+  if (!type_instance_) {
+    if (other_ext.type_instance_) {
+      return false;
+    }
+    // Compare Python types
+    res = PyObject_RichCompareBool(type_class_.obj(), other_ext.type_class_.obj(), Py_EQ);
+  } else {
+    if (!other_ext.type_instance_) {
+      return false;
+    }
+    // Compare Python instances
+    OwnedRef left(GetInstance());
+    OwnedRef right(other_ext.GetInstance());
+    if (!left || !right) {
+      goto error;
+    }
+    res = PyObject_RichCompareBool(left.obj(), right.obj(), Py_EQ);
+  }
+  if (res == -1) {
+    goto error;
+  }
+  return res == 1;
+
+error:
+  // Cannot propagate error
+  PyErr_WriteUnraisable(nullptr);
+  return false;
+}
+
+std::shared_ptr<Array> PyExtensionType::MakeArray(std::shared_ptr<ArrayData> data) const {
+  DCHECK_EQ(data->type->id(), Type::EXTENSION);
+  DCHECK_EQ(kExtensionName,
+            checked_cast<const ExtensionType&>(*data->type).extension_name());
+  return std::make_shared<ExtensionArray>(data);
+}
+
+std::string PyExtensionType::Serialize() const {
+  DCHECK(type_instance_);
+  return serialized_;
+}
+
+Status PyExtensionType::Deserialize(std::shared_ptr<DataType> storage_type,
+                                    const std::string& serialized_data,
+                                    std::shared_ptr<DataType>* out) const {
+  PyAcquireGIL lock;
+
+  if (import_pyarrow()) {
+    return ConvertPyError();
+  }
+  OwnedRef res(DeserializeExtInstance(type_class_.obj(), storage_type, serialized_data));
+  if (!res) {
+    return ConvertPyError();
+  }
+  return unwrap_data_type(res.obj(), out);
+}
+
+PyObject* PyExtensionType::GetInstance() const {
+  if (!type_instance_) {
+    PyErr_SetString(PyExc_TypeError, "Not an instance");
+    return nullptr;
+  }
+  DCHECK(PyWeakref_CheckRef(type_instance_.obj()));
+  PyObject* inst = PyWeakref_GET_OBJECT(type_instance_.obj());
+  if (inst != Py_None) {
+    // Cached instance still alive
+    Py_INCREF(inst);
+    return inst;
+  } else {
+    // Must reconstruct from serialized form
+    // XXX cache again?
+    return DeserializeExtInstance(type_class_.obj(), storage_type_, serialized_);
+  }
+}
+
+Status PyExtensionType::SetInstance(PyObject* inst) const {
+  // Check we have the right type
+  PyObject* typ = reinterpret_cast<PyObject*>(Py_TYPE(inst));
+  if (typ != type_class_.obj()) {
+    return Status::TypeError("Unexpected Python ExtensionType class ",
+                             internal::PyObject_StdStringRepr(typ), " expected ",
+                             internal::PyObject_StdStringRepr(type_class_.obj()));
+  }
+
+  PyObject* wr = PyWeakref_NewRef(inst, nullptr);
+  if (wr == NULL) {
+    return ConvertPyError();
+  }
+  type_instance_.reset(wr);
+  return SerializeExtInstance(inst, &serialized_);
+}
+
+Status PyExtensionType::FromClass(std::shared_ptr<DataType> storage_type, PyObject* typ,
+                                  std::shared_ptr<ExtensionType>* out) {
+  Py_INCREF(typ);
+  out->reset(new PyExtensionType(storage_type, typ));
+  return Status::OK();
+}
+
+Status RegisterPyExtensionType(const std::shared_ptr<DataType>& type) {
+  DCHECK_EQ(type->id(), Type::EXTENSION);
+  auto ext_type = std::dynamic_pointer_cast<ExtensionType>(type);
+  DCHECK_EQ(ext_type->extension_name(), kExtensionName);
+  return RegisterExtensionType(ext_type);
+}
+
+Status UnregisterPyExtensionType() { return UnregisterExtensionType(kExtensionName); }
+
+std::string PyExtensionName() { return kExtensionName; }
+
+}  // namespace py
+}  // namespace arrow
diff --git a/cpp/src/arrow/python/extension_type.h b/cpp/src/arrow/python/extension_type.h
new file mode 100644
index 0000000..12f9108
--- /dev/null
+++ b/cpp/src/arrow/python/extension_type.h
@@ -0,0 +1,77 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+
+#include "arrow/extension_type.h"
+#include "arrow/python/common.h"
+#include "arrow/python/visibility.h"
+#include "arrow/util/macros.h"
+
+namespace arrow {
+namespace py {
+
+class ARROW_PYTHON_EXPORT PyExtensionType : public ExtensionType {
+ public:
+  // Implement extensionType API
+  std::string extension_name() const override;
+
+  bool ExtensionEquals(const ExtensionType& other) const override;
+
+  std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;
+
+  Status Deserialize(std::shared_ptr<DataType> storage_type,
+                     const std::string& serialized_data,
+                     std::shared_ptr<DataType>* out) const override;
+
+  std::string Serialize() const override;
+
+  // For use from Cython
+  static Status FromClass(std::shared_ptr<DataType> storage_type, PyObject* typ,
+                          std::shared_ptr<ExtensionType>* out);
+
+  // Return new ref
+  PyObject* GetInstance() const;
+  Status SetInstance(PyObject*) const;
+
+ protected:
+  PyExtensionType(std::shared_ptr<DataType> storage_type, PyObject* typ,
+                  PyObject* inst = NULLPTR);
+
+  // These fields are mutable because of two-step initialization.
+  mutable OwnedRefNoGIL type_class_;
+  // A weakref or null.  Storing a strong reference to the Python extension type
+  // instance would create an unreclaimable reference cycle between Python and C++
+  // (the Python instance has to keep a strong reference to the C++ ExtensionType
+  //  in other direction).  Instead, we store a weakref to the instance.
+  // If the weakref is dead, we reconstruct the instance from its serialized form.
+  mutable OwnedRefNoGIL type_instance_;
+  // Empty if type_instance_ is null
+  mutable std::string serialized_;
+};
+
+ARROW_PYTHON_EXPORT std::string PyExtensionName();
+
+ARROW_PYTHON_EXPORT Status RegisterPyExtensionType(const std::shared_ptr<DataType>&);
+
+ARROW_PYTHON_EXPORT Status UnregisterPyExtensionType();
+
+}  // namespace py
+}  // namespace arrow
diff --git a/cpp/src/arrow/python/pyarrow.h b/cpp/src/arrow/python/pyarrow.h
index a5a3910..5e42333 100644
--- a/cpp/src/arrow/python/pyarrow.h
+++ b/cpp/src/arrow/python/pyarrow.h
@@ -39,6 +39,7 @@ class Tensor;
 
 namespace py {
 
+// Returns 0 on success, -1 on error.
 ARROW_PYTHON_EXPORT int import_pyarrow();
 
 ARROW_PYTHON_EXPORT bool is_buffer(PyObject* buffer);
diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py
index f9ba819..556b87d 100644
--- a/python/pyarrow/__init__.py
+++ b/python/pyarrow/__init__.py
@@ -58,6 +58,8 @@ from pyarrow.lib import (null, bool_,
                          DataType, DictionaryType, ListType, StructType,
                          UnionType, TimestampType, Time32Type, Time64Type,
                          FixedSizeBinaryType, Decimal128Type,
+                         BaseExtensionType, ExtensionType,
+                         UnknownExtensionType,
                          DictionaryMemo,
                          Field,
                          Schema,
@@ -78,7 +80,7 @@ from pyarrow.lib import (null, bool_,
                          DictionaryArray,
                          Date32Array, Date64Array,
                          TimestampArray, Time32Array, Time64Array,
-                         Decimal128Array, StructArray,
+                         Decimal128Array, StructArray, ExtensionArray,
                          ArrayValue, Scalar, NA, _NULL as NULL,
                          BooleanValue,
                          Int8Value, Int16Value, Int32Value, Int64Value,
diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index cce967e..607d7ae 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -415,7 +415,7 @@ cdef class Array(_PandasConvertible):
                         "the `pyarrow.Array.from_*` functions instead."
                         .format(self.__class__.__name__))
 
-    cdef void init(self, const shared_ptr[CArray]& sp_array):
+    cdef void init(self, const shared_ptr[CArray]& sp_array) except *:
         self.sp_array = sp_array
         self.ap = sp_array.get()
         self.type = pyarrow_wrap_data_type(self.sp_array.get().type())
@@ -1458,6 +1458,45 @@ cdef class StructArray(Array):
         return pyarrow_wrap_array(c_result)
 
 
+cdef class ExtensionArray(Array):
+    """
+    Concrete class for Arrow extension arrays.
+    """
+
+    @property
+    def storage(self):
+        cdef:
+            CExtensionArray* ext_array = <CExtensionArray*>(self.ap)
+
+        return pyarrow_wrap_array(ext_array.storage())
+
+    @staticmethod
+    def from_storage(BaseExtensionType typ, Array storage):
+        """
+        Construct ExtensionArray from type and storage array.
+
+        Parameters
+        ----------
+        typ: DataType
+            The extension type for the result array.
+        storage: Array
+            The underlying storage for the result array.
+
+        Returns
+        -------
+        ext_array : ExtensionArray
+        """
+        cdef:
+            shared_ptr[CExtensionArray] ext_array
+
+        if storage.type != typ.storage_type:
+            raise TypeError("Incompatible storage type {0} "
+                            "for extension type {1}".format(storage.type, typ))
+
+        ext_array = make_shared[CExtensionArray](typ.sp_type, storage.sp_array)
+        return pyarrow_wrap_array(<shared_ptr[CArray]> ext_array)
+
+
 cdef dict _array_classes = {
     _Type_NA: NullArray,
     _Type_BOOL: BooleanArray,
@@ -1485,6 +1524,7 @@ cdef dict _array_classes = {
     _Type_FIXED_SIZE_BINARY: FixedSizeBinaryArray,
     _Type_DECIMAL: Decimal128Array,
     _Type_STRUCT: StructArray,
+    _Type_EXTENSION: ExtensionArray,
 }
 
 
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index f979cd6..178a250 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -73,6 +73,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
         _Type_DICTIONARY" arrow::Type::DICTIONARY"
         _Type_MAP" arrow::Type::MAP"
 
+        _Type_EXTENSION" arrow::Type::EXTENSION"
+
     enum UnionMode" arrow::UnionMode::type":
         _UnionMode_SPARSE" arrow::UnionMode::SPARSE"
         _UnionMode_DENSE" arrow::UnionMode::DENSE"
@@ -1272,6 +1274,36 @@ cdef extern from 'arrow/python/inference.h' namespace 'arrow::py':
     c_bool IsPyFloat(object o)
 
 
+cdef extern from 'arrow/extension_type.h' namespace 'arrow':
+    cdef cppclass CExtensionType" arrow::ExtensionType"(CDataType):
+        c_string extension_name()
+        shared_ptr[CDataType] storage_type()
+
+    cdef cppclass CExtensionArray" arrow::ExtensionArray"(CArray):
+        CExtensionArray(shared_ptr[CDataType], shared_ptr[CArray] storage)
+
+        shared_ptr[CArray] storage()
+
+
+cdef extern from 'arrow/python/extension_type.h' namespace 'arrow::py':
+    cdef cppclass CPyExtensionType \
+            " arrow::py::PyExtensionType"(CExtensionType):
+        @staticmethod
+        CStatus FromClass(shared_ptr[CDataType] storage_type,
+                          object typ, shared_ptr[CExtensionType]* out)
+
+        @staticmethod
+        CStatus FromInstance(shared_ptr[CDataType] storage_type,
+                             object inst, shared_ptr[CExtensionType]* out)
+
+        object GetInstance()
+        CStatus SetInstance(object)
+
+    c_string PyExtensionName()
+    CStatus RegisterPyExtensionType(shared_ptr[CDataType])
+    CStatus UnregisterPyExtensionType()
+
+
 cdef extern from 'arrow/python/benchmark.h' namespace 'arrow::py::benchmark':
     void Benchmark_PandasObjectIsNull(object lst) except *
 
diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd
index 998848d..79ab947 100644
--- a/python/pyarrow/lib.pxd
+++ b/python/pyarrow/lib.pxd
@@ -53,8 +53,9 @@ cdef class DataType:
         shared_ptr[CDataType] sp_type
         CDataType* type
         bytes pep3118_format
+        object __weakref__
 
-    cdef void init(self, const shared_ptr[CDataType]& type)
+    cdef void init(self, const shared_ptr[CDataType]& type) except *
     cdef Field child(self, int i)
 
 
@@ -106,6 +107,16 @@ cdef class Decimal128Type(FixedSizeBinaryType):
         const CDecimal128Type* decimal128_type
 
 
+cdef class BaseExtensionType(DataType):
+    cdef:
+        const CExtensionType* ext_type
+
+
+cdef class ExtensionType(BaseExtensionType):
+    cdef:
+        const CPyExtensionType* cpy_ext_type
+
+
 cdef class Field:
     cdef:
         shared_ptr[CField] sp_field
@@ -199,11 +210,12 @@ cdef class Array(_PandasConvertible):
     cdef:
         shared_ptr[CArray] sp_array
         CArray* ap
+        object __weakref__
 
     cdef readonly:
         DataType type
 
-    cdef void init(self, const shared_ptr[CArray]& sp_array)
+    cdef void init(self, const shared_ptr[CArray]& sp_array) except *
     cdef getitem(self, int64_t i)
     cdef int64_t length(self)
 
@@ -316,6 +328,10 @@ cdef class DictionaryArray(Array):
         object _indices, _dictionary
 
 
+cdef class ExtensionArray(Array):
+    pass
+
+
 cdef wrap_array_output(PyObject* output)
 cdef object box_scalar(DataType type,
                        const shared_ptr[CArray]& sp_array,
diff --git a/python/pyarrow/public-api.pxi b/python/pyarrow/public-api.pxi
index 9392259..33bc803 100644
--- a/python/pyarrow/public-api.pxi
+++ b/python/pyarrow/public-api.pxi
@@ -66,7 +66,10 @@ cdef api shared_ptr[CDataType] pyarrow_unwrap_data_type(
 
 cdef api object pyarrow_wrap_data_type(
         const shared_ptr[CDataType]& type):
-    cdef DataType out
+    cdef:
+        const CExtensionType* ext_type
+        const CPyExtensionType* cpy_ext_type
+        DataType out
 
     if type.get() == NULL:
         return None
@@ -85,6 +88,13 @@ cdef api object pyarrow_wrap_data_type(
         out = FixedSizeBinaryType.__new__(FixedSizeBinaryType)
     elif type.get().id() == _Type_DECIMAL:
         out = Decimal128Type.__new__(Decimal128Type)
+    elif type.get().id() == _Type_EXTENSION:
+        ext_type = <const CExtensionType*> type.get()
+        if ext_type.extension_name() == PyExtensionName():
+            cpy_ext_type = <const CPyExtensionType*> ext_type
+            return cpy_ext_type.GetInstance()
+        else:
+            out = BaseExtensionType.__new__(BaseExtensionType)
     else:
         out = DataType.__new__(DataType)
 
diff --git a/python/pyarrow/tests/test_extension_type.py b/python/pyarrow/tests/test_extension_type.py
new file mode 100644
index 0000000..d688d3c
--- /dev/null
+++ b/python/pyarrow/tests/test_extension_type.py
@@ -0,0 +1,219 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pickle
+import weakref
+
+import pyarrow as pa
+
+import pytest
+
+
+class UuidType(pa.ExtensionType):
+
+    def __init__(self):
+        pa.ExtensionType.__init__(self, pa.binary(16))
+
+    def __reduce__(self):
+        return UuidType, ()
+
+
+class ParamExtType(pa.ExtensionType):
+
+    def __init__(self, width):
+        self.width = width
+        pa.ExtensionType.__init__(self, pa.binary(width))
+
+    def __reduce__(self):
+        return ParamExtType, (self.width,)
+
+
+def ipc_write_batch(batch):
+    stream = pa.BufferOutputStream()
+    writer = pa.RecordBatchStreamWriter(stream, batch.schema)
+    writer.write_batch(batch)
+    writer.close()
+    return stream.getvalue()
+
+
+def ipc_read_batch(buf):
+    reader = pa.RecordBatchStreamReader(buf)
+    return reader.read_next_batch()
+
+
+def test_ext_type_basics():
+    ty = UuidType()
+    assert ty.extension_name == "arrow.py_extension_type"
+
+
+def test_ext_type__lifetime():
+    ty = UuidType()
+    wr = weakref.ref(ty)
+    del ty
+    assert wr() is None
+
+
+def test_ext_type__storage_type():
+    ty = UuidType()
+    assert ty.storage_type == pa.binary(16)
+    assert ty.__class__ is UuidType
+    ty = ParamExtType(5)
+    assert ty.storage_type == pa.binary(5)
+    assert ty.__class__ is ParamExtType
+
+
+def test_uuid_type_pickle():
+    for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
+        ty = UuidType()
+        ser = pickle.dumps(ty, protocol=proto)
+        del ty
+        ty = pickle.loads(ser)
+        wr = weakref.ref(ty)
+        assert ty.extension_name == "arrow.py_extension_type"
+        del ty
+        assert wr() is None
+
+
+def test_ext_type_equality():
+    a = ParamExtType(5)
+    b = ParamExtType(6)
+    c = ParamExtType(6)
+    assert a != b
+    assert b == c
+    d = UuidType()
+    e = UuidType()
+    assert a != d
+    assert d == e
+
+
+def test_ext_array_basics():
+    ty = ParamExtType(3)
+    storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
+    arr = pa.ExtensionArray.from_storage(ty, storage)
+    arr.validate()
+    assert arr.type is ty
+    assert arr.storage.equals(storage)
+
+
+def test_ext_array_lifetime():
+    ty = ParamExtType(3)
+    storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
+    arr = pa.ExtensionArray.from_storage(ty, storage)
+
+    refs = [weakref.ref(obj) for obj in (ty, arr, storage)]
+    del ty, storage, arr
+    for ref in refs:
+        assert ref() is None
+
+
+def test_ext_array_errors():
+    ty = ParamExtType(4)
+    storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
+    with pytest.raises(TypeError, match="Incompatible storage type"):
+        pa.ExtensionArray.from_storage(ty, storage)
+
+
+def test_ext_array_equality():
+    storage1 = pa.array([b"0123456789abcdef"], type=pa.binary(16))
+    storage2 = pa.array([b"0123456789abcdef"], type=pa.binary(16))
+    storage3 = pa.array([], type=pa.binary(16))
+    ty1 = UuidType()
+    ty2 = ParamExtType(16)
+
+    a = pa.ExtensionArray.from_storage(ty1, storage1)
+    b = pa.ExtensionArray.from_storage(ty1, storage2)
+    assert a.equals(b)
+    c = pa.ExtensionArray.from_storage(ty1, storage3)
+    assert not a.equals(c)
+    d = pa.ExtensionArray.from_storage(ty2, storage1)
+    assert not a.equals(d)
+    e = pa.ExtensionArray.from_storage(ty2, storage2)
+    assert d.equals(e)
+    f = pa.ExtensionArray.from_storage(ty2, storage3)
+    assert not d.equals(f)
+
+
+def test_ext_array_pickling():
+    for proto in range(0, pickle.HIGHEST_PROTOCOL + 1):
+        ty = ParamExtType(3)
+        storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
+        arr = pa.ExtensionArray.from_storage(ty, storage)
+        ser = pickle.dumps(arr, protocol=proto)
+        del ty, storage, arr
+        arr = pickle.loads(ser)
+        arr.validate()
+        assert isinstance(arr, pa.ExtensionArray)
+        assert arr.type == ParamExtType(3)
+        assert arr.type.storage_type == pa.binary(3)
+        assert arr.storage.type == pa.binary(3)
+        assert arr.storage.to_pylist() == [b"foo", b"bar"]
+
+
+def example_batch():
+    ty = ParamExtType(3)
+    storage = pa.array([b"foo", b"bar"], type=pa.binary(3))
+    arr = pa.ExtensionArray.from_storage(ty, storage)
+    return pa.RecordBatch.from_arrays([arr], ["exts"])
+
+
+def check_example_batch(batch):
+    arr = batch.column(0)
+    assert isinstance(arr, pa.ExtensionArray)
+    assert arr.type.storage_type == pa.binary(3)
+    assert arr.storage.to_pylist() == [b"foo", b"bar"]
+    return arr
+
+
+def test_ipc():
+    batch = example_batch()
+    buf = ipc_write_batch(batch)
+    del batch
+
+    batch = ipc_read_batch(buf)
+    arr = check_example_batch(batch)
+    assert arr.type == ParamExtType(3)
+
+
+def test_ipc_unknown_type():
+    batch = example_batch()
+    buf = ipc_write_batch(batch)
+    del batch
+
+    orig_type = ParamExtType
+    try:
+        # Simulate the original Python type being unavailable.
+        # Deserialization should not fail but return a placeholder type.
+        del globals()['ParamExtType']
+
+        batch = ipc_read_batch(buf)
+        arr = check_example_batch(batch)
+        assert isinstance(arr.type, pa.UnknownExtensionType)
+
+        # Can be serialized again
+        buf2 = ipc_write_batch(batch)
+        del batch, arr
+
+        batch = ipc_read_batch(buf2)
+        arr = check_example_batch(batch)
+        assert isinstance(arr.type, pa.UnknownExtensionType)
+    finally:
+        globals()['ParamExtType'] = orig_type
+
+    # Deserialize again with the type restored
+    batch = ipc_read_batch(buf2)
+    arr = check_example_batch(batch)
+    assert arr.type == ParamExtType(3)
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 9a92761..1f0db4c 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -19,6 +19,7 @@ import re
 import warnings
 
 from pyarrow import compat
+from pyarrow.compat import builtin_pickle
 
 
 # These are imprecise because the type (in pandas 0.x) depends on the presence
@@ -103,7 +104,7 @@ cdef class DataType:
                         "functions like pyarrow.int64, pyarrow.list_, etc. "
                         "instead.".format(self.__class__.__name__))
 
-    cdef void init(self, const shared_ptr[CDataType]& type):
+    cdef void init(self, const shared_ptr[CDataType]& type) except *:
         self.sp_type = type
         self.type = type.get()
         self.pep3118_format = _datatype_to_pep3118(self.type)
@@ -203,7 +204,7 @@ cdef class DictionaryType(DataType):
     Concrete class for dictionary data types.
     """
 
-    cdef void init(self, const shared_ptr[CDataType]& type):
+    cdef void init(self, const shared_ptr[CDataType]& type) except *:
         DataType.init(self, type)
         self.dict_type = <const CDictionaryType*> type.get()
 
@@ -239,7 +240,7 @@ cdef class ListType(DataType):
     Concrete class for list data types.
     """
 
-    cdef void init(self, const shared_ptr[CDataType]& type):
+    cdef void init(self, const shared_ptr[CDataType]& type) except *:
         DataType.init(self, type)
         self.list_type = <const CListType*> type.get()
 
@@ -259,7 +260,7 @@ cdef class StructType(DataType):
     Concrete class for struct data types.
     """
 
-    cdef void init(self, const shared_ptr[CDataType]& type):
+    cdef void init(self, const shared_ptr[CDataType]& type) except *:
         DataType.init(self, type)
         self.struct_type = <const CStructType*> type.get()
 
@@ -318,7 +319,7 @@ cdef class UnionType(DataType):
     Concrete class for struct data types.
     """
 
-    cdef void init(self, const shared_ptr[CDataType]& type):
+    cdef void init(self, const shared_ptr[CDataType]& type) except *:
         DataType.init(self, type)
 
     @property
@@ -370,7 +371,7 @@ cdef class TimestampType(DataType):
     Concrete class for timestamp data types.
     """
 
-    cdef void init(self, const shared_ptr[CDataType]& type):
+    cdef void init(self, const shared_ptr[CDataType]& type) except *:
         DataType.init(self, type)
         self.ts_type = <const CTimestampType*> type.get()
 
@@ -411,7 +412,7 @@ cdef class Time32Type(DataType):
     Concrete class for time32 data types.
     """
 
-    cdef void init(self, const shared_ptr[CDataType]& type):
+    cdef void init(self, const shared_ptr[CDataType]& type) except *:
         DataType.init(self, type)
         self.time_type = <const CTime32Type*> type.get()
 
@@ -428,7 +429,7 @@ cdef class Time64Type(DataType):
     Concrete class for time64 data types.
     """
 
-    cdef void init(self, const shared_ptr[CDataType]& type):
+    cdef void init(self, const shared_ptr[CDataType]& type) except *:
         DataType.init(self, type)
         self.time_type = <const CTime64Type*> type.get()
 
@@ -445,7 +446,7 @@ cdef class FixedSizeBinaryType(DataType):
     Concrete class for fixed-size binary data types.
     """
 
-    cdef void init(self, const shared_ptr[CDataType]& type):
+    cdef void init(self, const shared_ptr[CDataType]& type) except *:
         DataType.init(self, type)
         self.fixed_size_binary_type = (
             <const CFixedSizeBinaryType*> type.get())
@@ -466,7 +467,7 @@ cdef class Decimal128Type(FixedSizeBinaryType):
     Concrete class for decimal128 data types.
     """
 
-    cdef void init(self, const shared_ptr[CDataType]& type):
+    cdef void init(self, const shared_ptr[CDataType]& type) except *:
         FixedSizeBinaryType.init(self, type)
         self.decimal128_type = <const CDecimal128Type*> type.get()
 
@@ -488,6 +489,132 @@ cdef class Decimal128Type(FixedSizeBinaryType):
         return self.decimal128_type.scale()
 
 
+cdef class BaseExtensionType(DataType):
+    """
+    Concrete base class wrapping a C++ ExtensionType as a DataType.
+    """
+
+    cdef void init(self, const shared_ptr[CDataType]& type) except *:
+        DataType.init(self, type)
+        self.ext_type = <const CExtensionType*> type.get()
+
+    @property
+    def extension_name(self):
+        """
+        The extension type name, as reported by the C++ extension type.
+        """
+        return frombytes(self.ext_type.extension_name())
+
+    @property
+    def storage_type(self):
+        """
+        The underlying storage type, as reported by the C++ extension type.
+        """
+        return pyarrow_wrap_data_type(self.ext_type.storage_type())
+
+
+cdef class ExtensionType(BaseExtensionType):
+    """
+    Base class for Python-defined extension types; must be subclassed.
+    """
+
+    def __cinit__(self):
+        if type(self) is ExtensionType:
+            raise TypeError("Can only instantiate subclasses of "
+                            "ExtensionType")
+
+    def __init__(self, DataType storage_type):
+        cdef:
+            shared_ptr[CExtensionType] cpy_ext_type
+
+        assert storage_type is not None
+        check_status(CPyExtensionType.FromClass(storage_type.sp_type,
+                                                type(self), &cpy_ext_type))
+        self.init(<shared_ptr[CDataType]> cpy_ext_type)
+
+    cdef void init(self, const shared_ptr[CDataType]& type) except *:
+        BaseExtensionType.init(self, type)
+        self.cpy_ext_type = <const CPyExtensionType*> type.get()
+        # Store weakref and serialized version of self on C++ type instance
+        check_status(self.cpy_ext_type.SetInstance(self))
+
+    def __eq__(self, other):
+        # Default implementation to avoid infinite recursion through
+        # DataType.__eq__ -> ExtensionType::ExtensionEquals -> DataType.__eq__
+        if isinstance(other, ExtensionType):
+            return (type(self) == type(other) and
+                    self.storage_type == other.storage_type)
+        else:
+            return NotImplemented
+
+    def __reduce__(self):
+        raise NotImplementedError("Please implement {0}.__reduce__"
+                                  .format(type(self).__name__))
+
+    def __arrow_ext_serialize__(self):
+        return builtin_pickle.dumps(self)
+
+    @classmethod
+    def __arrow_ext_deserialize__(cls, storage_type, serialized):
+        try:
+            ty = builtin_pickle.loads(serialized)
+        except Exception:
+            # The ExtensionType instance could not be unpickled: perhaps the
+            # data is corrupt, or more likely the original Python class or
+            # module is not available here.  Fall back on a generic
+            # UnknownExtensionType.  NOTE(review): pickle.loads can run
+            # arbitrary code; confirm `serialized` is from a trusted source.
+            return UnknownExtensionType(storage_type, serialized)
+
+        if ty.storage_type != storage_type:
+            raise TypeError("Expected storage type {0} but got {1}"
+                            .format(ty.storage_type, storage_type))
+        return ty
+
+
+cdef class UnknownExtensionType(ExtensionType):
+    """
+    A concrete class for Python-defined extension types whose Python
+    implementation is unknown; the serialized form is kept unchanged.
+    """
+
+    cdef:
+        bytes serialized
+
+    def __init__(self, DataType storage_type, serialized):
+        self.serialized = serialized
+        ExtensionType.__init__(self, storage_type)
+
+    def __arrow_ext_serialize__(self):
+        return self.serialized
+
+
+cdef class _ExtensionTypesInitializer:
+    #
+    # A private object that registers the Python ExtensionType process-wide
+    # when created, and unregisters it when deallocated.
+    #
+
+    def __cinit__(self):
+        cdef:
+            DataType storage_type
+            shared_ptr[CExtensionType] cpy_ext_type
+
+        # Make a dummy C++ ExtensionType (null storage type) to register
+        storage_type = null()
+        check_status(CPyExtensionType.FromClass(storage_type.sp_type,
+                                                ExtensionType, &cpy_ext_type))
+        check_status(
+            RegisterPyExtensionType(<shared_ptr[CDataType]> cpy_ext_type))
+
+    def __dealloc__(self):
+        # This needs to be done explicitly before the Python interpreter is
+        # finalized.  If the C++ type is destroyed later in the process
+        # teardown stage, it will invoke CPython APIs such as Py_DECREF
+        # with a destroyed interpreter.
+        check_status(UnregisterPyExtensionType())
+
+
 cdef class Field:
     """
     A named field, with a data type, nullability, and optional metadata.
@@ -1726,3 +1853,6 @@ def is_integer_value(object obj):
 
 def is_float_value(object obj):
     return IsPyFloat(obj)
+
+
+_extension_types_initializer = _ExtensionTypesInitializer()

Reply via email to