This is an automated email from the ASF dual-hosted git repository.
wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new 56e72ba ARROW-3278: [Python] Retrieve StructType's and StructArray's field by name
56e72ba is described below
commit 56e72ba09c3d886c6a5aeb11fb1642af13819f93
Author: Krisztián Szűcs <[email protected]>
AuthorDate: Wed Oct 17 14:35:21 2018 -0400
ARROW-3278: [Python] Retrieve StructType's and StructArray's field by name
Author: Krisztián Szűcs <[email protected]>
Closes #2754 from kszucs/ARROW-3278 and squashes the following commits:
737ca2989 <Krisztián Szűcs> int cast for _normalize_index
1f224d561 <Krisztián Szűcs> remove StructType.field_by_name
9f7b9c04c <Krisztián Szűcs> overload field
7523eefdd <Krisztián Szűcs> StructArray.field_by_name
4a6f2d5c2 <Krisztián Szűcs> StructType.field_by_name
---
cpp/src/arrow/array-test.cc | 4 ++++
cpp/src/arrow/array.cc | 9 +++++++++
cpp/src/arrow/array.h | 5 +++++
python/pyarrow/array.pxi | 21 ++++++++++++++++-----
python/pyarrow/includes/libarrow.pxd | 1 +
python/pyarrow/lib.pxd | 8 ++++++++
python/pyarrow/tests/test_array.py | 35 +++++++++++++++++++++--------------
python/pyarrow/tests/test_types.py | 12 ++++++++++++
python/pyarrow/types.pxi | 28 ++++++++++++++++++++++++----
9 files changed, 100 insertions(+), 23 deletions(-)
diff --git a/cpp/src/arrow/array-test.cc b/cpp/src/arrow/array-test.cc
index 4a50232..78ff4bc 100644
--- a/cpp/src/arrow/array-test.cc
+++ b/cpp/src/arrow/array-test.cc
@@ -3239,6 +3239,10 @@ void ValidateBasicStructArray(const StructArray* result,
auto char_arr = std::dynamic_pointer_cast<Int8Array>(list_char_arr->values());
auto int32_arr = std::dynamic_pointer_cast<Int32Array>(result->field(1));
+ ASSERT_EQ(nullptr, result->GetFieldByName("non-existing"));
+ ASSERT_TRUE(list_char_arr->Equals(result->GetFieldByName("list")));
+ ASSERT_TRUE(int32_arr->Equals(result->GetFieldByName("int")));
+
ASSERT_EQ(0, result->null_count());
ASSERT_EQ(1, list_char_arr->null_count());
ASSERT_EQ(0, int32_arr->null_count());
diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc
index 3d88618..05d66d5 100644
--- a/cpp/src/arrow/array.cc
+++ b/cpp/src/arrow/array.cc
@@ -374,6 +374,10 @@ StructArray::StructArray(const std::shared_ptr<DataType>& type, int64_t length,
boxed_fields_.resize(children.size());
}
+const StructType* StructArray::struct_type() const {
+ return checked_cast<const StructType*>(data_->type.get());
+}
+
std::shared_ptr<Array> StructArray::field(int i) const {
if (!boxed_fields_[i]) {
std::shared_ptr<ArrayData> field_data;
@@ -388,6 +392,11 @@ std::shared_ptr<Array> StructArray::field(int i) const {
return boxed_fields_[i];
}
+std::shared_ptr<Array> StructArray::GetFieldByName(const std::string& name) const {
+ int i = struct_type()->GetChildIndex(name);
+ return i == -1 ? nullptr : field(i);
+}
+
Status StructArray::Flatten(MemoryPool* pool, ArrayVector* out) const {
ArrayVector flattened;
std::shared_ptr<Buffer> null_bitmap = data_->buffers[0];
diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h
index 8c5e0a5..be64ebc 100644
--- a/cpp/src/arrow/array.h
+++ b/cpp/src/arrow/array.h
@@ -631,11 +631,16 @@ class ARROW_EXPORT StructArray : public Array {
std::shared_ptr<Buffer> null_bitmap = NULLPTR, int64_t null_count = 0,
int64_t offset = 0);
+ const StructType* struct_type() const;
+
// Return a shared pointer in case the requestor desires to share ownership
// with this array. The returned array has its offset, length and null
// count adjusted.
std::shared_ptr<Array> field(int pos) const;
+ /// Returns null if name not found
+ std::shared_ptr<Array> GetFieldByName(const std::string& name) const;
+
/// \brief Flatten this array as a vector of arrays, one for each field
///
/// \param[in] pool The pool to allocate null bitmaps from, if necessary
diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi
index 2d0f56d..a62752c 100644
--- a/python/pyarrow/array.pxi
+++ b/python/pyarrow/array.pxi
@@ -1145,17 +1145,28 @@ cdef class StructArray(Array):
Parameters
----------
- index : int
- Index / position of the field
+ index : Union[int, str]
+ Index / position or name of the field
Returns
-------
result : Array
"""
cdef:
- int ix = <int> _normalize_index(index, self.ap.num_fields())
- CStructArray* sarr = <CStructArray*> self.ap
- return pyarrow_wrap_array(sarr.field(ix))
+ CStructArray* arr = <CStructArray*> self.ap
+ shared_ptr[CArray] child
+
+ if isinstance(index, six.string_types):
+ child = arr.GetFieldByName(tobytes(index))
+ if child == nullptr:
+ raise KeyError(index)
+ elif isinstance(index, six.integer_types):
+ child = arr.field(
+ <int>_normalize_index(index, self.ap.num_fields()))
+ else:
+ raise TypeError('Expected integer or string index')
+
+ return pyarrow_wrap_array(child)
def flatten(self, MemoryPool memory_pool=None):
"""
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index e4090b2..2a69446 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -429,6 +429,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
int64_t offset=0)
shared_ptr[CArray] field(int pos)
+ shared_ptr[CArray] GetFieldByName(const c_string& name) const
CStatus Flatten(CMemoryPool* pool, vector[shared_ptr[CArray]]* out)
diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd
index 7749854..3c2935a 100644
--- a/python/pyarrow/lib.pxd
+++ b/python/pyarrow/lib.pxd
@@ -51,6 +51,7 @@ cdef class DataType:
bytes pep3118_format
cdef void init(self, const shared_ptr[CDataType]& type)
+ cdef Field child(self, int i)
cdef class ListType(DataType):
@@ -58,6 +59,13 @@ cdef class ListType(DataType):
const CListType* list_type
+cdef class StructType(DataType):
+ cdef:
+ const CStructType* struct_type
+
+ cdef Field child_by_name(self, name)
+
+
cdef class DictionaryType(DataType):
cdef:
const CDictionaryType* dict_type
diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py
index c340228..1350ad6 100644
--- a/python/pyarrow/tests/test_array.py
+++ b/python/pyarrow/tests/test_array.py
@@ -1184,20 +1184,23 @@ def test_struct_array_field():
pa.field('y', pa.float32())])
a = pa.array([(1, 2.5), (3, 4.5), (5, 6.5)], type=ty)
- x = a.field(0)
- y = a.field(1)
- x_ = a.field(-2)
- y_ = a.field(-1)
-
- assert isinstance(x, pa.lib.Int16Array)
- assert isinstance(y, pa.lib.FloatArray)
-
- assert x.equals(pa.array([1, 3, 5], type=pa.int16()))
- assert y.equals(pa.array([2.5, 4.5, 6.5], type=pa.float32()))
- assert x.equals(x_)
- assert y.equals(y_)
-
- for invalid_index in [None, 'x']:
+ x0 = a.field(0)
+ y0 = a.field(1)
+ x1 = a.field(-2)
+ y1 = a.field(-1)
+ x2 = a.field('x')
+ y2 = a.field('y')
+
+ assert isinstance(x0, pa.lib.Int16Array)
+ assert isinstance(y1, pa.lib.FloatArray)
+ assert x0.equals(pa.array([1, 3, 5], type=pa.int16()))
+ assert y0.equals(pa.array([2.5, 4.5, 6.5], type=pa.float32()))
+ assert x0.equals(x1)
+ assert x0.equals(x2)
+ assert y0.equals(y1)
+ assert y0.equals(y2)
+
+ for invalid_index in [None, pa.int16()]:
with pytest.raises(TypeError):
a.field(invalid_index)
@@ -1205,6 +1208,10 @@ def test_struct_array_field():
with pytest.raises(IndexError):
a.field(invalid_index)
+ for invalid_name in ['z', '']:
+ with pytest.raises(KeyError):
+ a.field(invalid_name)
+
def test_nested_dictionary_array():
dict_arr = pa.DictionaryArray.from_arrays([0, 1, 0], ['a', 'b'])
diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py
index bd73fe0..b574713 100644
--- a/python/pyarrow/tests/test_types.py
+++ b/python/pyarrow/tests/test_types.py
@@ -231,6 +231,18 @@ def test_struct_type():
assert len(ty) == ty.num_children == 3
assert list(ty) == fields
+ assert ty[0].name == 'a'
+ assert ty[2].type == pa.int32()
+ with pytest.raises(IndexError):
+ assert ty[3]
+
+ assert ty['a'] == ty[1]
+ assert ty['b'] == ty[2]
+ with pytest.raises(KeyError):
+ ty['c']
+
+ with pytest.raises(TypeError):
+ ty[None]
for a, b in zip(ty, fields):
a == b
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index ee7f890..92ef0f3 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -103,6 +103,10 @@ cdef class DataType:
self.type = type.get()
self.pep3118_format = _datatype_to_pep3118(self.type)
+ cdef Field child(self, int i):
+ cdef int index = <int> _normalize_index(i, self.type.num_children())
+ return pyarrow_wrap_field(self.type.child(index))
+
@property
def id(self):
return self.type.id()
@@ -207,6 +211,19 @@ cdef class StructType(DataType):
cdef void init(self, const shared_ptr[CDataType]& type):
DataType.init(self, type)
+ self.struct_type = <const CStructType*> type.get()
+
+ cdef Field child_by_name(self, name):
+ """
+ Access a child field by its name rather than the column index.
+ """
+ cdef shared_ptr[CField] field
+
+ field = self.struct_type.GetChildByName(tobytes(name))
+ if field == nullptr:
+ raise KeyError(name)
+
+ return pyarrow_wrap_field(field)
def __len__(self):
return self.type.num_children()
@@ -216,8 +233,12 @@ cdef class StructType(DataType):
yield self[i]
def __getitem__(self, i):
- cdef int index = <int> _normalize_index(i, self.num_children)
- return pyarrow_wrap_field(self.type.child(index))
+ if isinstance(i, six.string_types):
+ return self.child_by_name(i)
+ elif isinstance(i, six.integer_types):
+ return self.child(i)
+ else:
+ raise TypeError('Expected integer or string index')
def __reduce__(self):
return struct, (list(self),)
@@ -254,8 +275,7 @@ cdef class UnionType(DataType):
yield self[i]
def __getitem__(self, i):
- cdef int index = <int> _normalize_index(i, self.num_children)
- return pyarrow_wrap_field(self.type.child(index))
+ return self.child(i)
def __reduce__(self):
return union, (list(self), self.mode)