Repository: arrow Updated Branches: refs/heads/master 5bf6ae49e -> 74685f386
ARROW-427: [C++] Implement dictionary array type I thought some about this and thought that it made sense to store the reference to the dictionary values themselves in the data type object, similar to `CategoricalDtype` in pandas. This will be at least adequate for the Feather file format merge. In the IPC metadata, there is no explicit dictionary type -- an array can be dictionary encoded or not. On JIRA we've discussed adding a dictionary type flag indicating whether or not the dictionary values/categories are ordered (also called "ordinal") or unordered (also called "nominal"). That hasn't been done yet. Author: Wes McKinney <wes.mckin...@twosigma.com> Closes #268 from wesm/ARROW-427 and squashes the following commits: 5ce3701 [Wes McKinney] cpplint a6c2896 [Wes McKinney] Revert T::Equals(const T& other) to EqualsExact to appease clang 9a4edb5 [Wes McKinney] Implement rudimentary DictionaryArray::Validate 9efe46b [Wes McKinney] Add tests, implementation for DictionaryArray::Equals and RangeEquals b06eb86 [Wes McKinney] Implement PrettyPrint for DictionaryArray 17c70de [Wes McKinney] Refactor, compose shared_ptr<DataType> in DictionaryType b52b3a7 [Wes McKinney] Add rudimentary DictionaryType and DictionaryArray implementation for discussion Project: http://git-wip-us.apache.org/repos/asf/arrow/repo Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/74685f38 Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/74685f38 Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/74685f38 Branch: refs/heads/master Commit: 74685f386307171a90a9f97316e25b7f39cdd0a1 Parents: 5bf6ae4 Author: Wes McKinney <wes.mckin...@twosigma.com> Authored: Fri Jan 6 11:11:43 2017 -0500 Committer: Wes McKinney <wes.mckin...@twosigma.com> Committed: Fri Jan 6 11:11:43 2017 -0500 ---------------------------------------------------------------------- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/array-dictionary-test.cc | 128 ++++++++++++++++++++++ cpp/src/arrow/array-string-test.cc | 4 +- cpp/src/arrow/array.cc | 94 +++++++++++++--- cpp/src/arrow/array.h | 111 ++++++++++++++++--- cpp/src/arrow/ipc/adapter.cc | 11 ++ cpp/src/arrow/ipc/json-internal.cc | 13 +++ cpp/src/arrow/pretty_print-test.cc | 53 +++++---- cpp/src/arrow/pretty_print.cc | 12 ++ cpp/src/arrow/test-util.h | 36 +++--- cpp/src/arrow/type.cc | 69 ++++++++++-- cpp/src/arrow/type.h | 163 +++++++++++++++++++++------- cpp/src/arrow/type_fwd.h | 57 +--------- format/Message.fbs | 2 +- python/pyarrow/includes/libarrow.pxd | 3 +- python/pyarrow/includes/parquet.pxd | 2 +- python/pyarrow/parquet.pyx | 4 +- python/pyarrow/schema.pyx | 4 +- 18 files changed, 583 insertions(+), 184 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/cpp/src/arrow/CMakeLists.txt ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 16668db..e5e36ed 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -47,6 +47,7 @@ install( ADD_ARROW_TEST(array-test) ADD_ARROW_TEST(array-decimal-test) +ADD_ARROW_TEST(array-dictionary-test) ADD_ARROW_TEST(array-list-test) ADD_ARROW_TEST(array-primitive-test) ADD_ARROW_TEST(array-string-test) http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/cpp/src/arrow/array-dictionary-test.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/array-dictionary-test.cc b/cpp/src/arrow/array-dictionary-test.cc new file mode 100644 index 0000000..c290153 --- /dev/null +++ b/cpp/src/arrow/array-dictionary-test.cc @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <cstdint> +#include <cstdlib> +#include <memory> +#include <numeric> +#include <vector> + +#include "gtest/gtest.h" + +#include "arrow/array.h" +#include "arrow/buffer.h" +#include "arrow/memory_pool.h" +#include "arrow/test-util.h" +#include "arrow/type.h" + +namespace arrow { + +TEST(TestDictionary, Basics) { + std::vector<int32_t> values = {100, 1000, 10000, 100000}; + std::shared_ptr<Array> dict; + ArrayFromVector<Int32Type, int32_t>(int32(), values, &dict); + + std::shared_ptr<DictionaryType> type1 = + std::dynamic_pointer_cast<DictionaryType>(dictionary(int16(), dict)); + DictionaryType type2(int16(), dict); + + ASSERT_TRUE(int16()->Equals(type1->index_type())); + ASSERT_TRUE(type1->dictionary()->Equals(dict)); + + ASSERT_TRUE(int16()->Equals(type2.index_type())); + ASSERT_TRUE(type2.dictionary()->Equals(dict)); + + ASSERT_EQ("dictionary<int32, int16>", type1->ToString()); +} + +TEST(TestDictionary, Equals) { + std::vector<bool> is_valid = {true, true, false, true, true, true}; + + std::shared_ptr<Array> dict; + std::vector<std::string> dict_values = {"foo", "bar", "baz"}; + ArrayFromVector<StringType, std::string>(utf8(), dict_values, &dict); + std::shared_ptr<DataType> dict_type = dictionary(int16(), dict); + + std::shared_ptr<Array> dict2; + std::vector<std::string> dict2_values = {"foo", "bar", "baz", "qux"}; + ArrayFromVector<StringType, std::string>(utf8(), dict2_values, &dict2); + std::shared_ptr<DataType> dict2_type = dictionary(int16(), dict2); + + std::shared_ptr<Array> indices; + std::vector<int16_t> indices_values = {1, 2, -1, 0, 2, 0}; + ArrayFromVector<Int16Type, int16_t>(int16(), is_valid, indices_values, &indices); + + std::shared_ptr<Array> indices2; + std::vector<int16_t> indices2_values = {1, 2, 0, 0, 2, 0}; + ArrayFromVector<Int16Type, int16_t>(int16(), is_valid, indices2_values, &indices2); + + std::shared_ptr<Array> indices3; + std::vector<int16_t> indices3_values = {1, 1, 0, 0, 2, 0}; + ArrayFromVector<Int16Type, int16_t>(int16(), is_valid, indices3_values, &indices3); + + auto arr = std::make_shared<DictionaryArray>(dict_type, indices); + auto arr2 = std::make_shared<DictionaryArray>(dict_type, indices2); + auto arr3 = std::make_shared<DictionaryArray>(dict2_type, indices); + auto arr4 = std::make_shared<DictionaryArray>(dict_type, indices3); + + ASSERT_TRUE(arr->Equals(arr)); + + // Equal, because the unequal index is masked by null + ASSERT_TRUE(arr->Equals(arr2)); + + // Unequal dictionaries + ASSERT_FALSE(arr->Equals(arr3)); + + // Unequal indices + ASSERT_FALSE(arr->Equals(arr4)); + + // RangeEquals + ASSERT_TRUE(arr->RangeEquals(3, 6, 3, arr4)); + ASSERT_FALSE(arr->RangeEquals(1, 3, 1, arr4)); +} + +TEST(TestDictionary, Validate) { + std::vector<bool> is_valid = {true, true, false, true, true, true}; + + std::shared_ptr<Array> dict; + std::vector<std::string> dict_values = {"foo", "bar", "baz"}; + ArrayFromVector<StringType, std::string>(utf8(), dict_values, &dict); + std::shared_ptr<DataType> dict_type = dictionary(int16(), dict); + + std::shared_ptr<Array> indices; + std::vector<uint8_t> indices_values = {1, 2, 0, 0, 2, 0}; + ArrayFromVector<UInt8Type, uint8_t>(uint8(), is_valid, indices_values, &indices); + + std::shared_ptr<Array> indices2; + std::vector<float> indices2_values = {1., 2., 0., 0., 2., 0.}; + ArrayFromVector<FloatType, float>(float32(), is_valid, indices2_values, &indices2); + + std::shared_ptr<Array> indices3; + std::vector<int64_t> indices3_values = {1, 2, 0, 0, 2, 0}; + ArrayFromVector<Int64Type, int64_t>(int64(), is_valid, indices3_values, &indices3); + + std::shared_ptr<Array> arr = std::make_shared<DictionaryArray>(dict_type, indices); + std::shared_ptr<Array> arr2 = std::make_shared<DictionaryArray>(dict_type, indices2); + std::shared_ptr<Array> arr3 = std::make_shared<DictionaryArray>(dict_type, indices3); + + // Only checking index type for now + ASSERT_OK(arr->Validate()); + ASSERT_RAISES(Invalid, arr2->Validate()); + ASSERT_OK(arr3->Validate()); +} + +} // namespace arrow http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/cpp/src/arrow/array-string-test.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/array-string-test.cc b/cpp/src/arrow/array-string-test.cc index b144c63..024bfd5 100644 --- a/cpp/src/arrow/array-string-test.cc +++ b/cpp/src/arrow/array-string-test.cc @@ -36,8 +36,8 @@ TEST(TypesTest, BinaryType) { BinaryType t1; BinaryType e1; StringType t2; - EXPECT_TRUE(t1.Equals(&e1)); - EXPECT_FALSE(t1.Equals(&t2)); + EXPECT_TRUE(t1.Equals(e1)); + EXPECT_FALSE(t1.Equals(t2)); ASSERT_EQ(t1.type, Type::BINARY); ASSERT_EQ(t1.ToString(), std::string("binary")); } http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/cpp/src/arrow/array.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc index 3d309b8..7509520 100644 --- a/cpp/src/arrow/array.cc +++ b/cpp/src/arrow/array.cc @@ -42,7 +42,7 @@ Status GetEmptyBitmap( // ---------------------------------------------------------------------- // Base array class -Array::Array(const TypePtr& type, int32_t length, int32_t null_count, +Array::Array(const std::shared_ptr<DataType>& type, int32_t length, int32_t null_count, const std::shared_ptr<Buffer>& null_bitmap) { type_ = type; length_ = length; @@ -51,6 +51,12 @@ Array::Array(const TypePtr& type, int32_t length, int32_t null_count, if (null_bitmap_) { null_bitmap_data_ = null_bitmap_->data(); } } +bool Array::BaseEquals(const std::shared_ptr<Array>& other) const { + if (this == other.get()) { return true; } + if (!other) { return false; } + return EqualsExact(*other.get()); +} + bool Array::EqualsExact(const Array& other) const { if (this == &other) { return true; } if (length_ != other.length_ || null_count_ != other.null_count_ || @@ -91,7 +97,7 @@ Status NullArray::Accept(ArrayVisitor* visitor) const { // ---------------------------------------------------------------------- // Primitive array base -PrimitiveArray::PrimitiveArray(const TypePtr& type, int32_t length, +PrimitiveArray::PrimitiveArray(const std::shared_ptr<DataType>& type, int32_t length, const std::shared_ptr<Buffer>& data, int32_t null_count, const std::shared_ptr<Buffer>& null_bitmap) : Array(type, length, null_count, null_bitmap) { @@ -100,14 +106,9 @@ PrimitiveArray::PrimitiveArray(const TypePtr& type, int32_t length, } bool PrimitiveArray::EqualsExact(const PrimitiveArray& other) const { - if (this == &other) { return true; } - if (null_count_ != other.null_count_) { return false; } + if (!Array::EqualsExact(other)) { return false; } if (null_count_ > 0) { - bool equal_bitmap = - null_bitmap_->Equals(*other.null_bitmap_, BitUtil::CeilByte(length_) / 8); - if (!equal_bitmap) { return false; } - const uint8_t* this_data = raw_data_; const uint8_t* other_data = other.raw_data_; @@ -131,7 +132,7 @@ bool PrimitiveArray::Equals(const std::shared_ptr<Array>& arr) const { if (this == arr.get()) { return true; } if (!arr) { return false; } if (this->type_enum() != arr->type_enum()) { return false; } - return EqualsExact(*static_cast<const PrimitiveArray*>(arr.get())); + return EqualsExact(static_cast<const PrimitiveArray&>(*arr.get())); } template <typename T> @@ -161,7 +162,7 @@ BooleanArray::BooleanArray(int32_t length, const std::shared_ptr<Buffer>& data, : PrimitiveArray( std::make_shared<BooleanType>(), length, data, null_count, null_bitmap) {} -BooleanArray::BooleanArray(const TypePtr& type, int32_t length, +BooleanArray::BooleanArray(const std::shared_ptr<DataType>& type, int32_t length, const std::shared_ptr<Buffer>& data, int32_t null_count, const std::shared_ptr<Buffer>& null_bitmap) : PrimitiveArray(type, length, data, null_count, null_bitmap) {} @@ -192,7 +193,7 @@ bool BooleanArray::EqualsExact(const BooleanArray& other) const { bool BooleanArray::Equals(const std::shared_ptr<Array>& arr) const { if (this == arr.get()) return true; if (Type::BOOL != arr->type_enum()) { return false; } - return EqualsExact(*static_cast<const BooleanArray*>(arr.get())); + return EqualsExact(static_cast<const BooleanArray&>(*arr.get())); } bool BooleanArray::RangeEquals(int32_t start_idx, int32_t end_idx, @@ -238,7 +239,7 @@ bool ListArray::EqualsExact(const ListArray& other) const { bool ListArray::Equals(const std::shared_ptr<Array>& arr) const { if (this == arr.get()) { return true; } if (this->type_enum() != arr->type_enum()) { return false; } - return EqualsExact(*static_cast<const ListArray*>(arr.get())); + return EqualsExact(static_cast<const ListArray&>(*arr.get())); } bool ListArray::RangeEquals(int32_t start_idx, int32_t end_idx, int32_t other_start_idx, @@ -333,7 +334,7 @@ BinaryArray::BinaryArray(int32_t length, const std::shared_ptr<Buffer>& offsets, const std::shared_ptr<Buffer>& null_bitmap) : BinaryArray(kBinary, length, offsets, data, null_count, null_bitmap) {} -BinaryArray::BinaryArray(const TypePtr& type, int32_t length, +BinaryArray::BinaryArray(const std::shared_ptr<DataType>& type, int32_t length, const std::shared_ptr<Buffer>& offsets, const std::shared_ptr<Buffer>& data, int32_t null_count, const std::shared_ptr<Buffer>& null_bitmap) : Array(type, length, null_count, null_bitmap), @@ -364,7 +365,7 @@ bool BinaryArray::EqualsExact(const BinaryArray& other) const { bool BinaryArray::Equals(const std::shared_ptr<Array>& arr) const { if (this == arr.get()) { return true; } if (this->type_enum() != arr->type_enum()) { return false; } - return EqualsExact(*static_cast<const BinaryArray*>(arr.get())); + return EqualsExact(static_cast<const BinaryArray&>(*arr.get())); } bool BinaryArray::RangeEquals(int32_t start_idx, int32_t end_idx, int32_t other_start_idx, @@ -493,7 +494,7 @@ Status StructArray::Accept(ArrayVisitor* visitor) const { // ---------------------------------------------------------------------- // UnionArray -UnionArray::UnionArray(const TypePtr& type, int32_t length, +UnionArray::UnionArray(const std::shared_ptr<DataType>& type, int32_t length, const std::vector<std::shared_ptr<Array>>& children, const std::shared_ptr<Buffer>& type_ids, const std::shared_ptr<Buffer>& offsets, int32_t null_count, const std::shared_ptr<Buffer>& null_bitmap) @@ -587,13 +588,73 @@ Status UnionArray::Accept(ArrayVisitor* visitor) const { } // ---------------------------------------------------------------------- +// DictionaryArray + +Status DictionaryArray::FromBuffer(const std::shared_ptr<DataType>& type, int32_t length, + const std::shared_ptr<Buffer>& indices, int32_t null_count, + const std::shared_ptr<Buffer>& null_bitmap, std::shared_ptr<DictionaryArray>* out) { + DCHECK_EQ(type->type, Type::DICTIONARY); + const auto& dict_type = static_cast<const DictionaryType*>(type.get()); + + std::shared_ptr<Array> boxed_indices; + RETURN_NOT_OK(MakePrimitiveArray( + dict_type->index_type(), length, indices, null_count, null_bitmap, &boxed_indices)); + + *out = std::make_shared<DictionaryArray>(type, boxed_indices); + return Status::OK(); +} + +DictionaryArray::DictionaryArray( + const std::shared_ptr<DataType>& type, const std::shared_ptr<Array>& indices) + : Array(type, indices->length(), indices->null_count(), indices->null_bitmap()), + dict_type_(static_cast<const DictionaryType*>(type.get())), + indices_(indices) { + DCHECK_EQ(type->type, Type::DICTIONARY); +} + +Status DictionaryArray::Validate() const { + Type::type index_type_id = indices_->type()->type; + if (!is_integer(index_type_id)) { + return Status::Invalid("Dictionary indices must be integer type"); + } + return Status::OK(); +} + +std::shared_ptr<Array> DictionaryArray::dictionary() const { + return dict_type_->dictionary(); +} + +bool DictionaryArray::EqualsExact(const DictionaryArray& other) const { + if (!dictionary()->Equals(other.dictionary())) { return false; } + return indices_->Equals(other.indices()); +} + +bool DictionaryArray::Equals(const std::shared_ptr<Array>& arr) const { + if (this == arr.get()) { return true; } + if (Type::DICTIONARY != arr->type_enum()) { return false; } + return EqualsExact(static_cast<const DictionaryArray&>(*arr.get())); +} + +bool DictionaryArray::RangeEquals(int32_t start_idx, int32_t end_idx, + int32_t other_start_idx, const std::shared_ptr<Array>& arr) const { + if (Type::DICTIONARY != arr->type_enum()) { return false; } + const auto& dict_other = static_cast<const DictionaryArray&>(*arr.get()); + if (!dictionary()->Equals(dict_other.dictionary())) { return false; } + return indices_->RangeEquals(start_idx, end_idx, other_start_idx, dict_other.indices()); +} + +Status DictionaryArray::Accept(ArrayVisitor* visitor) const { + return visitor->Visit(*this); +} + +// ---------------------------------------------------------------------- #define MAKE_PRIMITIVE_ARRAY_CASE(ENUM, ArrayType) \ case Type::ENUM: \ out->reset(new ArrayType(type, length, data, null_count, null_bitmap)); \ break; -Status MakePrimitiveArray(const TypePtr& type, int32_t length, +Status MakePrimitiveArray(const std::shared_ptr<DataType>& type, int32_t length, const std::shared_ptr<Buffer>& data, int32_t null_count, const std::shared_ptr<Buffer>& null_bitmap, std::shared_ptr<Array>* out) { switch (type->type) { @@ -610,7 +671,6 @@ Status MakePrimitiveArray(const TypePtr& type, int32_t length, MAKE_PRIMITIVE_ARRAY_CASE(DOUBLE, DoubleArray); MAKE_PRIMITIVE_ARRAY_CASE(TIME, Int64Array); MAKE_PRIMITIVE_ARRAY_CASE(TIMESTAMP, TimestampArray); - MAKE_PRIMITIVE_ARRAY_CASE(TIMESTAMP_DOUBLE, DoubleArray); default: return Status::NotImplemented(type->ToString()); } http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/cpp/src/arrow/array.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h index cd42a28..57214c4 100644 --- a/cpp/src/arrow/array.h +++ b/cpp/src/arrow/array.h @@ -26,6 +26,7 @@ #include "arrow/buffer.h" #include "arrow/type.h" +#include "arrow/type_fwd.h" #include "arrow/util/bit-util.h" #include "arrow/util/macros.h" #include "arrow/util/visibility.h" @@ -36,6 +37,34 @@ class MemoryPool; class MutableBuffer; class Status; +class ArrayVisitor { + public: + virtual Status Visit(const NullArray& array) = 0; + virtual Status Visit(const BooleanArray& array) = 0; + virtual Status Visit(const Int8Array& array) = 0; + virtual Status Visit(const Int16Array& array) = 0; + virtual Status Visit(const Int32Array& array) = 0; + virtual Status Visit(const Int64Array& array) = 0; + virtual Status Visit(const UInt8Array& array) = 0; + virtual Status Visit(const UInt16Array& array) = 0; + virtual Status Visit(const UInt32Array& array) = 0; + virtual Status Visit(const UInt64Array& array) = 0; + virtual Status Visit(const HalfFloatArray& array) = 0; + virtual Status Visit(const FloatArray& array) = 0; + virtual Status Visit(const DoubleArray& array) = 0; + virtual Status Visit(const StringArray& array) = 0; + virtual Status Visit(const BinaryArray& array) = 0; + virtual Status Visit(const DateArray& array) = 0; + virtual Status Visit(const TimeArray& array) = 0; + virtual Status Visit(const TimestampArray& array) = 0; + virtual Status Visit(const IntervalArray& array) = 0; + virtual Status Visit(const DecimalArray& array) = 0; + virtual Status Visit(const ListArray& array) = 0; + virtual Status Visit(const StructArray& array) = 0; + virtual Status Visit(const UnionArray& array) = 0; + virtual Status Visit(const DictionaryArray& type) = 0; +}; + // Immutable data array with some logical type and some length. Any memory is // owned by the respective Buffer instance (or its parents). // @@ -63,6 +92,7 @@ class ARROW_EXPORT Array { const uint8_t* null_bitmap_data() const { return null_bitmap_data_; } + bool BaseEquals(const std::shared_ptr<Array>& arr) const; bool EqualsExact(const Array& arr) const; virtual bool Equals(const std::shared_ptr<Array>& arr) const = 0; virtual bool ApproxEquals(const std::shared_ptr<Array>& arr) const; @@ -122,8 +152,9 @@ class ARROW_EXPORT PrimitiveArray : public Array { bool Equals(const std::shared_ptr<Array>& arr) const override; protected: - PrimitiveArray(const TypePtr& type, int32_t length, const std::shared_ptr<Buffer>& data, - int32_t null_count = 0, const std::shared_ptr<Buffer>& null_bitmap = nullptr); + PrimitiveArray(const std::shared_ptr<DataType>& type, int32_t length, + const std::shared_ptr<Buffer>& data, int32_t null_count = 0, + const std::shared_ptr<Buffer>& null_bitmap = nullptr); std::shared_ptr<Buffer> data_; const uint8_t* raw_data_; }; @@ -137,8 +168,9 @@ class ARROW_EXPORT NumericArray : public PrimitiveArray { int32_t null_count = 0, const std::shared_ptr<Buffer>& null_bitmap = nullptr) : PrimitiveArray( std::make_shared<TypeClass>(), length, data, null_count, null_bitmap) {} - NumericArray(const TypePtr& type, int32_t length, const std::shared_ptr<Buffer>& data, - int32_t null_count = 0, const std::shared_ptr<Buffer>& null_bitmap = nullptr) + NumericArray(const std::shared_ptr<DataType>& type, int32_t length, + const std::shared_ptr<Buffer>& data, int32_t null_count = 0, + const std::shared_ptr<Buffer>& null_bitmap = nullptr) : PrimitiveArray(type, length, data, null_count, null_bitmap) {} bool EqualsExact(const NumericArray<TypeClass>& other) const { @@ -146,7 +178,7 @@ class ARROW_EXPORT NumericArray : public PrimitiveArray { } bool ApproxEquals(const std::shared_ptr<Array>& arr) const override { - return Equals(arr); + return PrimitiveArray::Equals(arr); } bool RangeEquals(int32_t start_idx, int32_t end_idx, int32_t other_start_idx, @@ -250,8 +282,9 @@ class ARROW_EXPORT BooleanArray : public PrimitiveArray { BooleanArray(int32_t length, const std::shared_ptr<Buffer>& data, int32_t null_count = 0, const std::shared_ptr<Buffer>& null_bitmap = nullptr); - BooleanArray(const TypePtr& type, int32_t length, const std::shared_ptr<Buffer>& data, - int32_t null_count = 0, const std::shared_ptr<Buffer>& null_bitmap = nullptr); + BooleanArray(const std::shared_ptr<DataType>& type, int32_t length, + const std::shared_ptr<Buffer>& data, int32_t null_count = 0, + const std::shared_ptr<Buffer>& null_bitmap = nullptr); bool EqualsExact(const BooleanArray& other) const; bool Equals(const std::shared_ptr<Array>& arr) const override; @@ -272,9 +305,9 @@ class ARROW_EXPORT ListArray : public Array { public: using TypeClass = ListType; - ListArray(const TypePtr& type, int32_t length, const std::shared_ptr<Buffer>& offsets, - const std::shared_ptr<Array>& values, int32_t null_count = 0, - const std::shared_ptr<Buffer>& null_bitmap = nullptr) + ListArray(const std::shared_ptr<DataType>& type, int32_t length, + const std::shared_ptr<Buffer>& offsets, const std::shared_ptr<Array>& values, + int32_t null_count = 0, const std::shared_ptr<Buffer>& null_bitmap = nullptr) : Array(type, length, null_count, null_bitmap) { offsets_buffer_ = offsets; offsets_ = offsets == nullptr ? nullptr : reinterpret_cast<const int32_t*>( @@ -328,9 +361,9 @@ class ARROW_EXPORT BinaryArray : public Array { // Constructor that allows sub-classes/builders to propagate there logical type up the // class hierarchy. - BinaryArray(const TypePtr& type, int32_t length, const std::shared_ptr<Buffer>& offsets, - const std::shared_ptr<Buffer>& data, int32_t null_count = 0, - const std::shared_ptr<Buffer>& null_bitmap = nullptr); + BinaryArray(const std::shared_ptr<DataType>& type, int32_t length, + const std::shared_ptr<Buffer>& offsets, const std::shared_ptr<Buffer>& data, + int32_t null_count = 0, const std::shared_ptr<Buffer>& null_bitmap = nullptr); // Return the pointer to the given elements bytes // TODO(emkornfield) introduce a StringPiece or something similar to capture zero-copy @@ -397,7 +430,7 @@ class ARROW_EXPORT StructArray : public Array { public: using TypeClass = StructType; - StructArray(const TypePtr& type, int32_t length, + StructArray(const std::shared_ptr<DataType>& type, int32_t length, const std::vector<std::shared_ptr<Array>>& field_arrays, int32_t null_count = 0, std::shared_ptr<Buffer> null_bitmap = nullptr) : Array(type, length, null_count, null_bitmap) { @@ -434,7 +467,7 @@ class ARROW_EXPORT UnionArray : public Array { public: using TypeClass = UnionType; - UnionArray(const TypePtr& type, int32_t length, + UnionArray(const std::shared_ptr<DataType>& type, int32_t length, const std::vector<std::shared_ptr<Array>>& children, const std::shared_ptr<Buffer>& type_ids, const std::shared_ptr<Buffer>& offsets = nullptr, int32_t null_count = 0, @@ -474,6 +507,54 @@ class ARROW_EXPORT UnionArray : public Array { }; // ---------------------------------------------------------------------- +// DictionaryArray (categorical and dictionary-encoded in memory) + +// A dictionary array contains an array of non-negative integers (the +// "dictionary indices") along with a data type containing a "dictionary" +// corresponding to the distinct values represented in the data. +// +// For example, the array +// +// ["foo", "bar", "foo", "bar", "foo", "bar"] +// +// with dictionary ["bar", "foo"], would have dictionary array representation +// +// indices: [1, 0, 1, 0, 1, 0] +// dictionary: ["bar", "foo"] +// +// The indices in principle may have any integer type (signed or unsigned), +// though presently data in IPC exchanges must be signed int32. +class ARROW_EXPORT DictionaryArray : public Array { + public: + using TypeClass = DictionaryType; + + DictionaryArray( + const std::shared_ptr<DataType>& type, const std::shared_ptr<Array>& indices); + + // Alternate ctor; other attributes (like null count) are inherited from the + // passed indices array + static Status FromBuffer(const std::shared_ptr<DataType>& type, int32_t length, + const std::shared_ptr<Buffer>& indices, int32_t null_count, + const std::shared_ptr<Buffer>& null_bitmap, std::shared_ptr<DictionaryArray>* out); + + Status Validate() const override; + + std::shared_ptr<Array> indices() const { return indices_; } + std::shared_ptr<Array> dictionary() const; + + bool EqualsExact(const DictionaryArray& other) const; + bool Equals(const std::shared_ptr<Array>& arr) const override; + bool RangeEquals(int32_t start_idx, int32_t end_idx, int32_t other_start_idx, + const std::shared_ptr<Array>& arr) const override; + + Status Accept(ArrayVisitor* visitor) const override; + + protected: + const DictionaryType* dict_type_; + std::shared_ptr<Array> indices_; +}; + +// ---------------------------------------------------------------------- // extern templates and other details // gcc and clang disagree about how to handle template visibility when you have http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/cpp/src/arrow/ipc/adapter.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/adapter.cc b/cpp/src/arrow/ipc/adapter.cc index 9bfd11f..2b5ef11 100644 --- a/cpp/src/arrow/ipc/adapter.cc +++ b/cpp/src/arrow/ipc/adapter.cc @@ -288,6 +288,13 @@ class RecordBatchWriter : public ArrayVisitor { return Status::OK(); } + Status Visit(const DictionaryArray& array) override { + // Dictionary written out separately + const auto& indices = static_cast<const PrimitiveArray&>(*array.indices().get()); + buffers_.push_back(indices.data()); + return Status::OK(); + } + // Do not copy this vector. Ownership must be retained elsewhere const std::vector<std::shared_ptr<Array>>& columns_; int32_t num_rows_; @@ -539,6 +546,10 @@ class ArrayLoader : public TypeVisitor { type_ids, offsets, field_meta.null_count, null_bitmap); return Status::OK(); } + + Status Visit(const DictionaryType& type) override { + return Status::NotImplemented("dictionary"); + }; }; class RecordBatchReader { http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/cpp/src/arrow/ipc/json-internal.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/ipc/json-internal.cc b/cpp/src/arrow/ipc/json-internal.cc index 4f980d3..43bd8a4 100644 --- a/cpp/src/arrow/ipc/json-internal.cc +++ b/cpp/src/arrow/ipc/json-internal.cc @@ -334,6 +334,14 @@ class JsonSchemaWriter : public TypeVisitor { return Status::OK(); } + Status Visit(const DictionaryType& type) override { + // WriteName("dictionary", type); + // WriteChildren(type.children()); + // WriteBufferLayout(type.GetBufferLayout()); + // return Status::OK(); + return Status::NotImplemented("dictionary type"); + } + private: const Schema& schema_; RjWriter* writer_; @@ -546,6 +554,10 @@ class JsonArrayWriter : public ArrayVisitor { return WriteChildren(type->children(), array.children()); } + Status Visit(const DictionaryArray& array) override { + return Status::NotImplemented("dictionary"); + } + private: const std::string& name_; const Array& array_; @@ -1043,6 +1055,7 @@ class JsonArrayReader { TYPE_CASE(ListType); TYPE_CASE(StructType); TYPE_CASE(UnionType); + NOT_IMPLEMENTED_CASE(DICTIONARY); default: std::stringstream ss; ss << type->ToString(); http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/cpp/src/arrow/pretty_print-test.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/pretty_print-test.cc b/cpp/src/arrow/pretty_print-test.cc index c22d3aa..4725d5d 100644 --- a/cpp/src/arrow/pretty_print-test.cc +++ b/cpp/src/arrow/pretty_print-test.cc @@ -34,7 +34,7 @@ namespace arrow { -class TestArrayPrinter : public ::testing::Test { +class TestPrettyPrint : public ::testing::Test { public: void SetUp() {} @@ -44,32 +44,22 @@ class TestArrayPrinter : public ::testing::Test { std::ostringstream sink_; }; +void CheckArray(const Array& arr, int indent, const char* expected) { + std::ostringstream sink; + ASSERT_OK(PrettyPrint(arr, indent, &sink)); + std::string result = sink.str(); + ASSERT_EQ(std::string(expected, strlen(expected)), result); +} + template <typename TYPE, typename C_TYPE> void CheckPrimitive(int indent, const std::vector<bool>& is_valid, const std::vector<C_TYPE>& values, const char* expected) { - std::ostringstream sink; - - MemoryPool* pool = default_memory_pool(); - typename TypeTraits<TYPE>::BuilderType builder(pool, std::make_shared<TYPE>()); - - for (size_t i = 0; i < values.size(); ++i) { - if (is_valid[i]) { - ASSERT_OK(builder.Append(values[i])); - } else { - ASSERT_OK(builder.AppendNull()); - } - } - std::shared_ptr<Array> array; - ASSERT_OK(builder.Finish(&array)); - - ASSERT_OK(PrettyPrint(*array.get(), indent, &sink)); - - std::string result = sink.str(); - ASSERT_EQ(std::string(expected, strlen(expected)), result); + ArrayFromVector<TYPE, C_TYPE>(std::make_shared<TYPE>(), is_valid, values, &array); + CheckArray(*array.get(), indent, expected); } -TEST_F(TestArrayPrinter, PrimitiveType) { +TEST_F(TestPrettyPrint, PrimitiveType) { std::vector<bool> is_valid = {true, true, false, true, false}; std::vector<int32_t> values = {0, 1, 2, 3, 4}; @@ -81,4 +71,25 @@ TEST_F(TestArrayPrinter, PrimitiveType) { CheckPrimitive<StringType, std::string>(0, is_valid, values2, ex2); } +TEST_F(TestPrettyPrint, DictionaryType) { + std::vector<bool> is_valid = {true, true, false, true, true, true}; + + std::shared_ptr<Array> dict; + std::vector<std::string> dict_values = {"foo", "bar", "baz"}; + ArrayFromVector<StringType, std::string>(utf8(), dict_values, &dict); + std::shared_ptr<DataType> dict_type = dictionary(int16(), dict); + + std::shared_ptr<Array> indices; + std::vector<int16_t> indices_values = {1, 2, -1, 0, 2, 0}; + ArrayFromVector<Int16Type, int16_t>(int16(), is_valid, indices_values, &indices); + auto arr = std::make_shared<DictionaryArray>(dict_type, indices); + + static const char* expected = R"expected( +-- is_valid: [true, true, false, true, true, true] +-- dictionary: ["foo", "bar", "baz"] +-- indices: [1, 2, null, 0, 2, 0])expected"; + + CheckArray(*arr.get(), 0, expected); +} + } // namespace arrow http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/cpp/src/arrow/pretty_print.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/pretty_print.cc b/cpp/src/arrow/pretty_print.cc index 324f81b..e30f4cc 100644 --- a/cpp/src/arrow/pretty_print.cc +++ b/cpp/src/arrow/pretty_print.cc @@ -217,6 +217,18 @@ class ArrayPrinter : public ArrayVisitor { return PrintChildren(array.children()); } + Status Visit(const DictionaryArray& array) override { + RETURN_NOT_OK(WriteValidityBitmap(array)); + + Newline(); + Write("-- dictionary: "); + RETURN_NOT_OK(PrettyPrint(*array.dictionary().get(), indent_ + 2, sink_)); + + Newline(); + Write("-- indices: "); + return PrettyPrint(*array.indices().get(), indent_ + 2, sink_); + } + void Write(const char* data) { (*sink_) << data; } void Write(const std::string& data) { (*sink_) << data; } http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/cpp/src/arrow/test-util.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/test-util.h b/cpp/src/arrow/test-util.h index 70e9333..e595749 100644 --- a/cpp/src/arrow/test-util.h +++ b/cpp/src/arrow/test-util.h @@ -257,33 +257,27 @@ template <typename TYPE, typename C_TYPE> void ArrayFromVector(const std::shared_ptr<DataType>& type, const std::vector<bool>& is_valid, const std::vector<C_TYPE>& values, std::shared_ptr<Array>* out) { - std::shared_ptr<Buffer> values_buffer; - std::shared_ptr<Buffer> values_bitmap; - - ASSERT_OK(test::CopyBufferFromVector(values, &values_buffer)); - ASSERT_OK(test::GetBitmapFromBoolVector(is_valid, &values_bitmap)); - - using ArrayType = typename TypeTraits<TYPE>::ArrayType; - - int32_t null_count = 0; - for (bool val : is_valid) { - if (!val) { ++null_count; } + MemoryPool* pool = default_memory_pool(); + typename TypeTraits<TYPE>::BuilderType builder(pool, std::make_shared<TYPE>()); + for (size_t i = 0; i < values.size(); ++i) { + if (is_valid[i]) { + ASSERT_OK(builder.Append(values[i])); + } else { + ASSERT_OK(builder.AppendNull()); + } } - - *out = std::make_shared<ArrayType>(type, static_cast<int32_t>(values.size()), - values_buffer, null_count, values_bitmap); + ASSERT_OK(builder.Finish(out)); } template <typename TYPE, typename C_TYPE> void ArrayFromVector(const std::shared_ptr<DataType>& type, const std::vector<C_TYPE>& values, std::shared_ptr<Array>* out) { - std::shared_ptr<Buffer> values_buffer; - - ASSERT_OK(test::CopyBufferFromVector(values, &values_buffer)); - - using ArrayType = typename TypeTraits<TYPE>::ArrayType; - *out = std::make_shared<ArrayType>( - type, static_cast<int32_t>(values.size()), values_buffer); + MemoryPool* pool = default_memory_pool(); + typename TypeTraits<TYPE>::BuilderType builder(pool, std::make_shared<TYPE>()); + for (size_t i = 0; i < values.size(); ++i) { + ASSERT_OK(builder.Append(values[i])); + } + ASSERT_OK(builder.Finish(out)); } class TestBuilder : public ::testing::Test { http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/cpp/src/arrow/type.cc ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 89faab6..954fba7 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -20,10 +20,22 @@ #include <sstream> #include <string> +#include "arrow/array.h" #include "arrow/status.h" +#include "arrow/util/logging.h" namespace arrow { +bool Field::Equals(const Field& other) const { + return (this == &other) || + (this->name == other.name && this->nullable == other.nullable && + this->dictionary == dictionary && this->type->Equals(*other.type.get())); +} + +bool Field::Equals(const std::shared_ptr<Field>& other) const { + return Equals(*other.get()); +} + std::string Field::ToString() const { std::stringstream ss; ss << this->name << ": " << this->type->ToString(); @@ -33,14 +45,14 @@ std::string Field::ToString() const { DataType::~DataType() {} -bool DataType::Equals(const DataType* other) const { - bool equals = other && ((this == other) || - ((this->type == other->type) && - ((this->num_children() == other->num_children())))); +bool DataType::Equals(const DataType& other) const { + bool equals = + ((this == &other) || ((this->type == other.type) && + ((this->num_children() == other.num_children())))); if (equals) { for (int i = 0; i < num_children(); ++i) { // TODO(emkornfield) limit recursion - if (!children_[i]->Equals(other->children_[i])) { return false; } + if (!children_[i]->Equals(other.children_[i])) { return false; } } } return equals; @@ -109,11 +121,47 @@ std::string UnionType::ToString() const { return s.str(); } +// ---------------------------------------------------------------------- +// DictionaryType + +DictionaryType::DictionaryType( + const std::shared_ptr<DataType>& index_type, const std::shared_ptr<Array>& dictionary) + : FixedWidthType(Type::DICTIONARY), + index_type_(index_type), + dictionary_(dictionary) {} + +int DictionaryType::bit_width() const { + return static_cast<const FixedWidthType*>(index_type_.get())->bit_width(); +} + +std::shared_ptr<Array> DictionaryType::dictionary() const { + return dictionary_; +} + +bool DictionaryType::Equals(const DataType& other) const { + if (other.type != Type::DICTIONARY) { return false; } + const auto& other_dict = static_cast<const DictionaryType&>(other); + + return index_type_->Equals(other_dict.index_type_) && + dictionary_->Equals(other_dict.dictionary_); +} + +std::string DictionaryType::ToString() const { + std::stringstream ss; + ss << "dictionary<" << dictionary_->type()->ToString() << ", " + << index_type_->ToString() << ">"; + return ss.str(); +} + +// ---------------------------------------------------------------------- +// Null type + std::string NullType::ToString() const { return name(); } -// Visitors and template instantiation +// ---------------------------------------------------------------------- +// Visitors and factory functions #define ACCEPT_VISITOR(TYPE) \ Status TYPE::Accept(TypeVisitor* visitor) const { return visitor->Visit(*this); } @@ -130,6 +178,7 @@ ACCEPT_VISITOR(DateType); ACCEPT_VISITOR(TimeType); ACCEPT_VISITOR(TimestampType); ACCEPT_VISITOR(IntervalType); +ACCEPT_VISITOR(DictionaryType); #define TYPE_FACTORY(NAME, KLASS) \ std::shared_ptr<DataType> NAME() { \ @@ -174,12 +223,16 @@ std::shared_ptr<DataType> struct_(const std::vector<std::shared_ptr<Field>>& fie return std::make_shared<StructType>(fields); } -std::shared_ptr<DataType> ARROW_EXPORT union_( - const std::vector<std::shared_ptr<Field>>& child_fields, +std::shared_ptr<DataType> union_(const std::vector<std::shared_ptr<Field>>& child_fields, const std::vector<uint8_t>& type_ids, UnionMode mode) { return std::make_shared<UnionType>(child_fields, type_ids, mode); } +std::shared_ptr<DataType> dictionary(const std::shared_ptr<DataType>& index_type, + const std::shared_ptr<Array>& dict_values) { + return std::make_shared<DictionaryType>(index_type, dict_values); +} + std::shared_ptr<Field> field( const std::string& name, const TypePtr& type, bool nullable, int64_t dictionary) { return std::make_shared<Field>(name, type, nullable, dictionary); http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/cpp/src/arrow/type.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 530c323..c2a762d 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -37,67 +37,64 @@ namespace arrow { struct Type { enum type { // A degenerate NULL type represented as 0 bytes/bits - NA = 0, + NA, // A boolean value represented as 1 bit - BOOL = 1, + BOOL, // Little-endian integer types - UINT8 = 2, - INT8 = 3, - UINT16 = 4, - INT16 = 5, - UINT32 = 6, - INT32 = 7, - UINT64 = 8, - INT64 = 9, + UINT8, + INT8, + UINT16, + INT16, + UINT32, + INT32, + UINT64, + INT64, // 2-byte floating point value - HALF_FLOAT = 10, + HALF_FLOAT, // 4-byte floating point value - FLOAT = 11, + FLOAT, // 8-byte floating point value - DOUBLE = 12, + DOUBLE, // UTF8 variable-length string as List<Char> - STRING = 13, + STRING, // Variable-length bytes (no guarantee of UTF8-ness) - BINARY = 14, + BINARY, // By default, int32 days since the UNIX epoch - DATE = 16, + DATE, // Exact timestamp encoded with int64 since UNIX epoch // Default unit millisecond - TIMESTAMP = 17, + TIMESTAMP, // Exact time encoded with int64, default unit millisecond - TIME = 18, + TIME, // YEAR_MONTH or DAY_TIME interval in SQL style - INTERVAL = 19, + INTERVAL, // Precision- and scale-based decimal type. Storage type depends on the // parameters. - DECIMAL = 20, + DECIMAL, // A list of some logical data type - LIST = 30, + LIST, // Struct of logical types - STRUCT = 31, + STRUCT, // Unions of logical types - UNION = 32, + UNION, - // Timestamp as double seconds since the UNIX epoch - TIMESTAMP_DOUBLE = 33, - - // Decimal value encoded as a text string - DECIMAL_TEXT = 34, + // Dictionary aka Category type + DICTIONARY }; }; @@ -115,6 +112,34 @@ class BufferDescr { int bit_width_; }; +class TypeVisitor { + public: + virtual Status Visit(const NullType& type) = 0; + virtual Status Visit(const BooleanType& type) = 0; + virtual Status Visit(const Int8Type& type) = 0; + virtual Status Visit(const Int16Type& type) = 0; + virtual Status Visit(const Int32Type& type) = 0; + virtual Status Visit(const Int64Type& type) = 0; + virtual Status Visit(const UInt8Type& type) = 0; + virtual Status Visit(const UInt16Type& type) = 0; + virtual Status Visit(const UInt32Type& type) = 0; + virtual Status Visit(const UInt64Type& type) = 0; + virtual Status Visit(const HalfFloatType& type) = 0; + virtual Status Visit(const FloatType& type) = 0; + virtual Status Visit(const DoubleType& type) = 0; + virtual Status Visit(const StringType& type) = 0; + virtual Status Visit(const BinaryType& type) = 0; + virtual Status Visit(const DateType& type) = 0; + virtual Status Visit(const TimeType& type) = 0; + virtual Status Visit(const TimestampType& type) = 0; + virtual Status Visit(const IntervalType& type) = 0; + virtual Status Visit(const DecimalType& type) = 0; + virtual Status Visit(const ListType& type) = 0; + virtual Status Visit(const StructType& type) = 0; + virtual Status Visit(const UnionType& type) = 0; + virtual Status Visit(const DictionaryType& type) = 0; +}; + struct ARROW_EXPORT DataType { Type::type type; @@ -128,10 +153,10 @@ struct ARROW_EXPORT DataType { // // Types that are logically convertable from one to another e.g. List<UInt8> // and Binary are NOT equal). - virtual bool Equals(const DataType* other) const; + virtual bool Equals(const DataType& other) const; bool Equals(const std::shared_ptr<DataType>& other) const { - return Equals(other.get()); + return Equals(*other.get()); } std::shared_ptr<Field> child(int i) const { return children_[i]; } @@ -189,16 +214,9 @@ struct ARROW_EXPORT Field { : name(name), type(type), nullable(nullable), dictionary(dictionary) {} bool operator==(const Field& other) const { return this->Equals(other); } - bool operator!=(const Field& other) const { return !this->Equals(other); } - - bool Equals(const Field& other) const { - return (this == &other) || - (this->name == other.name && this->nullable == other.nullable && - this->dictionary == dictionary && this->type->Equals(other.type.get())); - } - - bool Equals(const std::shared_ptr<Field>& other) const { return Equals(*other.get()); } + bool Equals(const Field& other) const; + bool Equals(const std::shared_ptr<Field>& other) const; std::string ToString() const; }; @@ -414,6 +432,9 @@ struct ARROW_EXPORT UnionType : public DataType { std::vector<uint8_t> type_ids; }; +// ---------------------------------------------------------------------- +// Date and time types + struct ARROW_EXPORT DateType : public FixedWidthType { static constexpr Type::type type_id = Type::DATE; @@ -488,6 +509,35 @@ struct ARROW_EXPORT IntervalType : public FixedWidthType { static std::string name() { return "date"; } }; +// ---------------------------------------------------------------------- +// DictionaryType (for categorical or dictionary-encoded data) + +class ARROW_EXPORT DictionaryType : public FixedWidthType { + public: + static constexpr Type::type type_id = Type::DICTIONARY; + + DictionaryType(const std::shared_ptr<DataType>& index_type, + const std::shared_ptr<Array>& dictionary); + + int bit_width() const override; + + std::shared_ptr<DataType> index_type() const { return index_type_; } + + std::shared_ptr<Array> dictionary() const; + + bool Equals(const DataType& other) const override; + + Status Accept(TypeVisitor* visitor) const override; + std::string ToString() const override; + + private: + // Must be an integer type (not currently checked) + std::shared_ptr<DataType> index_type_; + + std::shared_ptr<Array> dictionary_; +}; + +// ---------------------------------------------------------------------- // Factory functions std::shared_ptr<DataType> ARROW_EXPORT null(); @@ -520,9 +570,44 @@ std::shared_ptr<DataType> ARROW_EXPORT union_( const std::vector<std::shared_ptr<Field>>& child_fields, const std::vector<uint8_t>& type_ids, UnionMode mode = UnionMode::SPARSE); +std::shared_ptr<DataType> ARROW_EXPORT dictionary( + const std::shared_ptr<DataType>& index_type, const std::shared_ptr<Array>& values); + std::shared_ptr<Field> ARROW_EXPORT field(const std::string& name, const std::shared_ptr<DataType>& type, bool nullable = true, int64_t dictionary = 0); +// ---------------------------------------------------------------------- +// + +static inline bool is_integer(Type::type type_id) { + switch (type_id) { + case Type::UINT8: + case Type::INT8: + case Type::UINT16: + case Type::INT16: + case Type::UINT32: + case Type::INT32: + case Type::UINT64: + case Type::INT64: + return true; + default: + break; + } + return false; +} + +static inline bool is_floating(Type::type type_id) { + switch (type_id) { + case Type::HALF_FLOAT: + case Type::FLOAT: + case Type::DOUBLE: + return true; + default: + break; + } + return false; +} + } // namespace arrow #endif // ARROW_TYPE_H http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/cpp/src/arrow/type_fwd.h ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index a14c535..334abef 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -32,6 +32,9 @@ class MemoryPool; class RecordBatch; class Schema; +class DictionaryType; +class DictionaryArray; + struct NullType; class NullArray; @@ -101,60 +104,6 @@ using TimestampBuilder = NumericBuilder<TimestampType>; struct IntervalType; using IntervalArray = NumericArray<IntervalType>; -class TypeVisitor { - public: - virtual Status Visit(const NullType& type) = 0; - virtual Status Visit(const BooleanType& type) = 0; - virtual Status Visit(const Int8Type& type) = 0; - virtual Status Visit(const Int16Type& type) = 0; - virtual Status Visit(const Int32Type& type) = 0; - virtual Status Visit(const Int64Type& type) = 0; - virtual Status Visit(const UInt8Type& type) = 0; - virtual Status Visit(const UInt16Type& type) = 0; - virtual Status Visit(const UInt32Type& type) = 0; - virtual Status Visit(const UInt64Type& type) = 0; - virtual Status Visit(const HalfFloatType& type) = 0; - virtual Status Visit(const FloatType& type) = 0; - virtual Status Visit(const DoubleType& type) = 0; - virtual Status Visit(const StringType& type) = 0; - virtual Status Visit(const BinaryType& type) = 0; - virtual Status Visit(const DateType& type) = 0; - virtual Status Visit(const TimeType& type) = 0; - virtual Status Visit(const TimestampType& type) = 0; - virtual Status Visit(const IntervalType& type) = 0; - virtual Status Visit(const DecimalType& type) = 0; - virtual Status Visit(const ListType& type) = 0; - virtual Status Visit(const StructType& type) = 0; - virtual Status Visit(const UnionType& type) = 0; -}; - -class ArrayVisitor { - public: - virtual Status Visit(const NullArray& array) = 0; - virtual Status Visit(const BooleanArray& array) = 0; - virtual Status Visit(const Int8Array& array) = 0; - virtual Status Visit(const Int16Array& array) = 0; - virtual Status Visit(const Int32Array& array) = 0; - virtual Status Visit(const Int64Array& array) = 0; - virtual Status Visit(const UInt8Array& array) = 0; - virtual Status Visit(const UInt16Array& array) = 0; - virtual Status Visit(const UInt32Array& array) = 0; - virtual Status Visit(const UInt64Array& array) = 0; - virtual Status Visit(const HalfFloatArray& array) = 0; - virtual Status Visit(const FloatArray& array) = 0; - virtual Status Visit(const DoubleArray& array) = 0; - virtual Status Visit(const StringArray& array) = 0; - virtual Status Visit(const BinaryArray& array) = 0; - virtual Status Visit(const DateArray& array) = 0; - virtual Status Visit(const TimeArray& array) = 0; - virtual Status Visit(const TimestampArray& array) = 0; - virtual Status Visit(const IntervalArray& array) = 0; - virtual Status Visit(const DecimalArray& array) = 0; - virtual Status Visit(const ListArray& array) = 0; - virtual Status Visit(const StructArray& array) = 0; - virtual Status Visit(const UnionArray& array) = 0; -}; - } // namespace arrow #endif // ARROW_TYPE_FWD_H http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/format/Message.fbs ---------------------------------------------------------------------- diff --git a/format/Message.fbs b/format/Message.fbs index d07d066..b2c6464 100644 --- a/format/Message.fbs +++ b/format/Message.fbs @@ -256,7 +256,7 @@ table RecordBatch { /// For sending dictionary encoding information. Any Field can be /// dictionary-encoded, but in this case none of its children may be /// dictionary-encoded. -/// There is one dictionary batch per dictionary +/// There is one vector / column per dictionary /// table DictionaryBatch { http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/python/pyarrow/includes/libarrow.pxd ---------------------------------------------------------------------- diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 40fb60d..3cdfe49 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -55,7 +55,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef cppclass CDataType" arrow::DataType": Type type - c_bool Equals(const CDataType* other) + c_bool Equals(const shared_ptr[CDataType]& other) + c_bool Equals(const CDataType& other) c_string ToString() http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/python/pyarrow/includes/parquet.pxd ---------------------------------------------------------------------- diff --git a/python/pyarrow/includes/parquet.pxd b/python/pyarrow/includes/parquet.pxd index b4d127c..d9e121d 100644 --- a/python/pyarrow/includes/parquet.pxd +++ b/python/pyarrow/includes/parquet.pxd @@ -98,7 +98,7 @@ cdef extern from "parquet/api/reader.h" namespace "parquet" nogil: # TODO: Some default arguments are missing @staticmethod unique_ptr[ParquetFileReader] OpenFile(const c_string& path) - const FileMetaData* metadata(); + shared_ptr[FileMetaData] metadata(); cdef extern from "parquet/api/writer.h" namespace "parquet" nogil: http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/python/pyarrow/parquet.pyx ---------------------------------------------------------------------- diff --git a/python/pyarrow/parquet.pyx b/python/pyarrow/parquet.pyx index 7379456..c092185 100644 --- a/python/pyarrow/parquet.pyx +++ b/python/pyarrow/parquet.pyx @@ -98,8 +98,8 @@ cdef class ParquetReader: Integer index of the position of the column """ cdef: - const FileMetaData* metadata = (self.reader.get() - .parquet_reader().metadata()) + const FileMetaData* metadata = (self.reader.get().parquet_reader() + .metadata().get()) int i = 0 if self.column_idx_map is None: http://git-wip-us.apache.org/repos/asf/arrow/blob/74685f38/python/pyarrow/schema.pyx ---------------------------------------------------------------------- diff --git a/python/pyarrow/schema.pyx b/python/pyarrow/schema.pyx index 7a69b0f..d91ae7c 100644 --- a/python/pyarrow/schema.pyx +++ b/python/pyarrow/schema.pyx @@ -45,9 +45,9 @@ cdef class DataType: def __richcmp__(DataType self, DataType other, int op): if op == cpython.Py_EQ: - return self.type.Equals(other.type) + return self.type.Equals(other.sp_type) elif op == cpython.Py_NE: - return not self.type.Equals(other.type) + return not self.type.Equals(other.sp_type) else: raise TypeError('Invalid comparison')