This is an automated email from the ASF dual-hosted git repository.
jorisvandenbossche pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new aff876a572 GH-34796: [C++] Add FromTensor, ToTensor and strides
methods to FixedShapeTensorArray (#34797)
aff876a572 is described below
commit aff876a572db7a732fbafb0dbbf53f078bc79403
Author: Rok Mihevc <[email protected]>
AuthorDate: Tue Apr 11 18:15:20 2023 +0200
GH-34796: [C++] Add FromTensor, ToTensor and strides methods to
FixedShapeTensorArray (#34797)
### Rationale for this change
We want to enable converting Tensors to FixedShapeTensorArrays and the
other way around.
### What changes are included in this PR?
This adds FromTensor and ToTensor methods to FixedShapeTensorArray, and a strides
method to FixedShapeTensorType.
### Are these changes tested?
Yes.
### Are there any user-facing changes?
Yes. This adds FromTensor, ToTensor and strides as user-facing methods.
* Closes: #34796
Authored-by: Rok Mihevc <[email protected]>
Signed-off-by: Joris Van den Bossche <[email protected]>
---
cpp/src/arrow/extension/fixed_shape_tensor.cc | 182 +++++++++++++++++
cpp/src/arrow/extension/fixed_shape_tensor.h | 26 +++
cpp/src/arrow/extension/fixed_shape_tensor_test.cc | 222 +++++++++++++++++++++
3 files changed, 430 insertions(+)
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.cc
b/cpp/src/arrow/extension/fixed_shape_tensor.cc
index 8b0ed43df5..1debac0e70 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor.cc
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.cc
@@ -23,6 +23,7 @@
#include "arrow/array/array_nested.h"
#include "arrow/array/array_primitive.h"
#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
+#include "arrow/tensor.h"
#include "arrow/util/int_util_overflow.h"
#include "arrow/util/logging.h"
#include "arrow/util/sort.h"
@@ -33,8 +34,52 @@
namespace rj = arrow::rapidjson;
namespace arrow {
+
namespace extension {
+namespace {
+
+Status ComputeStrides(const FixedWidthType& type, const std::vector<int64_t>&
shape,
+ const std::vector<int64_t>& permutation,
+ std::vector<int64_t>* strides) {
+ if (permutation.empty()) {
+ return internal::ComputeRowMajorStrides(type, shape, strides);
+ }
+
+ const int byte_width = type.byte_width();
+
+ int64_t remaining = 0;
+ if (!shape.empty() && shape.front() > 0) {
+ remaining = byte_width;
+ for (auto i : permutation) {
+ if (i > 0) {
+ if (internal::MultiplyWithOverflow(remaining, shape[i], &remaining)) {
+ return Status::Invalid(
+ "Strides computed from shape would not fit in 64-bit integer");
+ }
+ }
+ }
+ }
+
+ if (remaining == 0) {
+ strides->assign(shape.size(), byte_width);
+ return Status::OK();
+ }
+
+ strides->push_back(remaining);
+ for (auto i : permutation) {
+ if (i > 0) {
+ remaining /= shape[i];
+ strides->push_back(remaining);
+ }
+ }
+ internal::Permute(permutation, strides);
+
+ return Status::OK();
+}
+
+} // namespace
+
bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
if (extension_name() != other.extension_name()) {
return false;
@@ -140,6 +185,132 @@ std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
return std::make_shared<ExtensionArray>(data);
}
+Result<std::shared_ptr<FixedShapeTensorArray>>
FixedShapeTensorArray::FromTensor(
+ const std::shared_ptr<Tensor>& tensor) {
+ auto permutation = internal::ArgSort(tensor->strides(), std::greater<>());
+ if (permutation[0] != 0) {
+ return Status::Invalid(
+ "Only first-major tensors can be zero-copy converted to arrays");
+ }
+ permutation.erase(permutation.begin());
+
+ std::vector<int64_t> cell_shape;
+ for (auto i : permutation) {
+ cell_shape.emplace_back(tensor->shape()[i]);
+ }
+
+ std::vector<std::string> dim_names;
+ if (!tensor->dim_names().empty()) {
+ for (auto i : permutation) {
+ dim_names.emplace_back(tensor->dim_names()[i]);
+ }
+ }
+
+ for (int64_t& i : permutation) {
+ --i;
+ }
+
+ auto ext_type = internal::checked_pointer_cast<ExtensionType>(
+ fixed_shape_tensor(tensor->type(), cell_shape, permutation, dim_names));
+
+ std::shared_ptr<Array> value_array;
+ switch (tensor->type_id()) {
+ case Type::UINT8: {
+ value_array = std::make_shared<UInt8Array>(tensor->size(),
tensor->data());
+ break;
+ }
+ case Type::INT8: {
+ value_array = std::make_shared<Int8Array>(tensor->size(),
tensor->data());
+ break;
+ }
+ case Type::UINT16: {
+ value_array = std::make_shared<UInt16Array>(tensor->size(),
tensor->data());
+ break;
+ }
+ case Type::INT16: {
+ value_array = std::make_shared<Int16Array>(tensor->size(),
tensor->data());
+ break;
+ }
+ case Type::UINT32: {
+ value_array = std::make_shared<UInt32Array>(tensor->size(),
tensor->data());
+ break;
+ }
+ case Type::INT32: {
+ value_array = std::make_shared<Int32Array>(tensor->size(),
tensor->data());
+ break;
+ }
+ case Type::UINT64: {
+ value_array = std::make_shared<Int64Array>(tensor->size(),
tensor->data());
+ break;
+ }
+ case Type::INT64: {
+ value_array = std::make_shared<Int64Array>(tensor->size(),
tensor->data());
+ break;
+ }
+ case Type::HALF_FLOAT: {
+ value_array = std::make_shared<HalfFloatArray>(tensor->size(),
tensor->data());
+ break;
+ }
+ case Type::FLOAT: {
+ value_array = std::make_shared<FloatArray>(tensor->size(),
tensor->data());
+ break;
+ }
+ case Type::DOUBLE: {
+ value_array = std::make_shared<DoubleArray>(tensor->size(),
tensor->data());
+ break;
+ }
+ default: {
+ return Status::NotImplemented("Unsupported tensor type: ",
+ tensor->type()->ToString());
+ }
+ }
+ auto cell_size = static_cast<int32_t>(tensor->size() / tensor->shape()[0]);
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> arr,
+ FixedSizeListArray::FromArrays(value_array,
cell_size));
+ std::shared_ptr<Array> ext_arr = ExtensionType::WrapArray(ext_type, arr);
+ return std::reinterpret_pointer_cast<FixedShapeTensorArray>(ext_arr);
+}
+
+const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
+ // To convert an array of n dimensional tensors to a n+1 dimensional tensor
we
+ // interpret the array's length as the first dimension the new tensor.
+
+ auto ext_arr =
internal::checked_pointer_cast<FixedSizeListArray>(this->storage());
+ auto ext_type =
internal::checked_pointer_cast<FixedShapeTensorType>(this->type());
+ ARROW_RETURN_IF(!is_fixed_width(*ext_arr->value_type()),
+ Status::Invalid(ext_arr->value_type()->ToString(),
+ " is not valid data type for a tensor"));
+ auto permutation = ext_type->permutation();
+
+ std::vector<std::string> dim_names;
+ if (!ext_type->dim_names().empty()) {
+ for (auto i : permutation) {
+ dim_names.emplace_back(ext_type->dim_names()[i]);
+ }
+ dim_names.insert(dim_names.begin(), 1, "");
+ } else {
+ dim_names = {};
+ }
+
+ std::vector<int64_t> shape;
+ for (int64_t& i : permutation) {
+ shape.emplace_back(ext_type->shape()[i]);
+ ++i;
+ }
+ shape.insert(shape.begin(), 1, this->length());
+ permutation.insert(permutation.begin(), 1, 0);
+
+ std::vector<int64_t> tensor_strides;
+ auto value_type =
internal::checked_pointer_cast<FixedWidthType>(ext_arr->value_type());
+ ARROW_RETURN_NOT_OK(
+ ComputeStrides(*value_type.get(), shape, permutation, &tensor_strides));
+ ARROW_ASSIGN_OR_RAISE(auto buffers, ext_arr->Flatten());
+ ARROW_ASSIGN_OR_RAISE(
+ auto tensor, Tensor::Make(ext_arr->value_type(),
buffers->data()->buffers[1], shape,
+ tensor_strides, dim_names));
+ return tensor;
+}
+
Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>&
shape,
const std::vector<int64_t>& permutation, const std::vector<std::string>&
dim_names) {
@@ -157,6 +328,17 @@ Result<std::shared_ptr<DataType>>
FixedShapeTensorType::Make(
shape, permutation, dim_names);
}
+const std::vector<int64_t>& FixedShapeTensorType::strides() {
+ if (strides_.empty()) {
+ auto value_type =
internal::checked_pointer_cast<FixedWidthType>(this->value_type_);
+ std::vector<int64_t> tensor_strides;
+ ARROW_CHECK_OK(ComputeStrides(*value_type.get(), this->shape(),
this->permutation(),
+ &tensor_strides));
+ strides_ = tensor_strides;
+ }
+ return strides_;
+}
+
std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>&
value_type,
const std::vector<int64_t>& shape,
const std::vector<int64_t>&
permutation,
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.h
b/cpp/src/arrow/extension/fixed_shape_tensor.h
index 4ee2b894ee..93837f1300 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor.h
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.h
@@ -23,6 +23,26 @@ namespace extension {
class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
public:
using ExtensionArray::ExtensionArray;
+
+ /// \brief Create a FixedShapeTensorArray from a Tensor
+ ///
+ /// This method will create a FixedShapeTensorArray from a Tensor, taking
its first
+ /// dimension as the number of elements in the resulting array and the
remaining
+ /// dimensions as the shape of the individual tensors. If Tensor provides
strides,
+ /// they will be used to determine dimension permutation. Otherwise,
row-major layout
+ /// (i.e. no permutation) will be assumed.
+ ///
+ /// \param[in] tensor The Tensor to convert to a FixedShapeTensorArray
+ static Result<std::shared_ptr<FixedShapeTensorArray>> FromTensor(
+ const std::shared_ptr<Tensor>& tensor);
+
+ /// \brief Create a Tensor from FixedShapeTensorArray
+ ///
+ /// This method will create a Tensor from a FixedShapeTensorArray, setting
its first
+ /// dimension as length equal to the FixedShapeTensorArray's length and the
remaining
+ /// dimensions as the FixedShapeTensorType's shape. Shape and dim_names will
be
+ /// permuted according to permutation stored in the FixedShapeTensorType
metadata.
+ const Result<std::shared_ptr<Tensor>> ToTensor() const;
};
/// \brief Concrete type class for constant-size Tensor data.
@@ -51,6 +71,11 @@ class ARROW_EXPORT FixedShapeTensorType : public
ExtensionType {
/// Value type of tensor elements
const std::shared_ptr<DataType> value_type() const { return value_type_; }
+ /// Strides of tensor elements. Strides state offset in bytes between
adjacent
+ /// elements along each dimension. In case permutation is non-empty strides
are
+ /// computed from permuted tensor element's shape.
+ const std::vector<int64_t>& strides();
+
/// Permutation mapping from logical to physical memory layout of tensor
elements
const std::vector<int64_t>& permutation() const { return permutation_; }
@@ -78,6 +103,7 @@ class ARROW_EXPORT FixedShapeTensorType : public
ExtensionType {
std::shared_ptr<DataType> storage_type_;
std::shared_ptr<DataType> value_type_;
std::vector<int64_t> shape_;
+ std::vector<int64_t> strides_;
std::vector<int64_t> permutation_;
std::vector<std::string> dim_names_;
};
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
index 16ba9d2014..50132e25fb 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
+++ b/cpp/src/arrow/extension/fixed_shape_tensor_test.cc
@@ -47,17 +47,26 @@ class TestExtensionType : public ::testing::Test {
fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_));
values_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
34, 35};
+ values_partial_ = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23};
+ shape_partial_ = {2, 3, 4};
+ tensor_strides_ = {96, 32, 8};
+ cell_strides_ = {32, 8};
serialized_ = R"({"shape":[3,4],"dim_names":["x","y"]})";
}
protected:
std::vector<int64_t> shape_;
+ std::vector<int64_t> shape_partial_;
std::vector<int64_t> cell_shape_;
std::shared_ptr<DataType> value_type_;
std::shared_ptr<DataType> cell_type_;
std::vector<std::string> dim_names_;
std::shared_ptr<ExtensionType> ext_type_;
std::vector<int64_t> values_;
+ std::vector<int64_t> values_partial_;
+ std::vector<int64_t> tensor_strides_;
+ std::vector<int64_t> cell_strides_;
std::string serialized_;
};
@@ -100,6 +109,7 @@ TEST_F(TestExtensionType, CreateExtensionType) {
ASSERT_EQ(exact_ext_type->ndim(), cell_shape_.size());
ASSERT_EQ(exact_ext_type->shape(), cell_shape_);
ASSERT_EQ(exact_ext_type->value_type(), value_type_);
+ ASSERT_EQ(exact_ext_type->strides(), cell_strides_);
ASSERT_EQ(exact_ext_type->dim_names(), dim_names_);
EXPECT_RAISES_WITH_MESSAGE_THAT(
@@ -212,4 +222,216 @@ TEST_F(TestExtensionType, RoudtripBatch) {
CompareBatch(*batch, *read_batch2, /*compare_metadata=*/true);
}
+TEST_F(TestExtensionType, CreateFromTensor) {
+ std::vector<int64_t> column_major_strides = {8, 24, 72};
+ std::vector<int64_t> neither_major_strides = {96, 8, 32};
+
+ ASSERT_OK_AND_ASSIGN(auto tensor,
+ Tensor::Make(value_type_, Buffer::Wrap(values_),
shape_));
+
+ auto exact_ext_type =
internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
+ ASSERT_OK_AND_ASSIGN(auto ext_arr,
FixedShapeTensorArray::FromTensor(tensor));
+
+ ASSERT_OK(ext_arr->ValidateFull());
+ ASSERT_TRUE(tensor->is_row_major());
+ ASSERT_EQ(tensor->strides(), tensor_strides_);
+ ASSERT_EQ(ext_arr->length(), shape_[0]);
+
+ auto ext_type_2 = internal::checked_pointer_cast<FixedShapeTensorType>(
+ fixed_shape_tensor(int64(), {3, 4}, {0, 1}));
+ ASSERT_OK_AND_ASSIGN(auto ext_arr_2,
FixedShapeTensorArray::FromTensor(tensor));
+
+ ASSERT_OK_AND_ASSIGN(
+ auto column_major_tensor,
+ Tensor::Make(value_type_, Buffer::Wrap(values_), shape_,
column_major_strides));
+ auto ext_type_3 = internal::checked_pointer_cast<FixedShapeTensorType>(
+ fixed_shape_tensor(int64(), {3, 4}, {0, 1}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid,
+ testing::HasSubstr(
+ "Invalid: Only first-major tensors can be zero-copy converted to
arrays"),
+ FixedShapeTensorArray::FromTensor(column_major_tensor));
+ ASSERT_THAT(FixedShapeTensorArray::FromTensor(column_major_tensor),
+ Raises(StatusCode::Invalid));
+
+ auto neither_major_tensor = std::make_shared<Tensor>(value_type_,
Buffer::Wrap(values_),
+ shape_,
neither_major_strides);
+ auto ext_type_4 = internal::checked_pointer_cast<FixedShapeTensorType>(
+ fixed_shape_tensor(int64(), {3, 4}, {1, 0}));
+ ASSERT_OK_AND_ASSIGN(auto ext_arr_4,
+
FixedShapeTensorArray::FromTensor(neither_major_tensor));
+
+ auto ext_type_5 = internal::checked_pointer_cast<FixedShapeTensorType>(
+ fixed_shape_tensor(binary(), {1, 3}));
+ auto arr = ArrayFromJSON(binary(), R"(["abc", "def"])");
+
+ ASSERT_OK_AND_ASSIGN(auto fsla_arr,
+ FixedSizeListArray::FromArrays(arr,
fixed_size_list(binary(), 2)));
+ auto ext_arr_5 = std::reinterpret_pointer_cast<FixedShapeTensorArray>(
+ ExtensionType::WrapArray(ext_type_5, fsla_arr));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, testing::HasSubstr("binary is not valid data type for a
tensor"),
+ ext_arr_5->ToTensor());
+
+ auto ext_type_6 = internal::checked_pointer_cast<FixedShapeTensorType>(
+ fixed_shape_tensor(int64(), {1, 2}));
+ auto arr_with_null = ArrayFromJSON(int64(), "[1, 0, null, null, 1, 2]");
+ ASSERT_OK_AND_ASSIGN(auto fsla_arr_6, FixedSizeListArray::FromArrays(
+ arr_with_null,
fixed_size_list(int64(), 2)));
+}
+
+void CheckFromTensorType(const std::shared_ptr<Tensor>& tensor,
+ std::shared_ptr<DataType> expected_ext_type) {
+ auto ext_type =
internal::checked_pointer_cast<FixedShapeTensorType>(expected_ext_type);
+ ASSERT_OK_AND_ASSIGN(auto ext_arr,
FixedShapeTensorArray::FromTensor(tensor));
+ auto generated_ext_type =
+ internal::checked_cast<const
FixedShapeTensorType*>(ext_arr->extension_type());
+
+ // Check that generated type is equal to the expected type
+ ASSERT_EQ(generated_ext_type->type_name(), ext_type->type_name());
+ ASSERT_EQ(generated_ext_type->shape(), ext_type->shape());
+ ASSERT_EQ(generated_ext_type->dim_names(), ext_type->dim_names());
+ ASSERT_EQ(generated_ext_type->permutation(), ext_type->permutation());
+
ASSERT_TRUE(generated_ext_type->storage_type()->Equals(*ext_type->storage_type()));
+ ASSERT_TRUE(generated_ext_type->Equals(ext_type));
+}
+
+TEST_F(TestExtensionType, TestFromTensorType) {
+ auto values = Buffer::Wrap(values_);
+ auto shapes =
+ std::vector<std::vector<int64_t>>{{3, 3, 4}, {3, 3, 4}, {3, 4, 3}, {3,
4, 3}};
+ auto strides = std::vector<std::vector<int64_t>>{
+ {96, 32, 8}, {96, 8, 24}, {96, 24, 8}, {96, 8, 32}};
+ auto tensor_dim_names = std::vector<std::vector<std::string>>{
+ {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"},
+ {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"}};
+ auto dim_names = std::vector<std::vector<std::string>>{
+ {"y", "z"}, {"z", "y"}, {"y", "z"}, {"z", "y"},
+ {"y", "z"}, {"y", "z"}, {"y", "z"}, {"y", "z"}};
+ auto cell_shapes = std::vector<std::vector<int64_t>>{{3, 4}, {4, 3}, {4, 3},
{3, 4}};
+ auto permutations = std::vector<std::vector<int64_t>>{{0, 1}, {1, 0}, {0,
1}, {1, 0}};
+
+ for (size_t i = 0; i < shapes.size(); i++) {
+ ASSERT_OK_AND_ASSIGN(auto tensor, Tensor::Make(value_type_, values,
shapes[i],
+ strides[i],
tensor_dim_names[i]));
+ ASSERT_OK_AND_ASSIGN(auto ext_arr,
FixedShapeTensorArray::FromTensor(tensor));
+ auto ext_type =
+ fixed_shape_tensor(value_type_, cell_shapes[i], permutations[i],
dim_names[i]);
+ CheckFromTensorType(tensor, ext_type);
+ }
+}
+
+void CheckTensorRoundtrip(const std::shared_ptr<Tensor>& tensor) {
+ ASSERT_OK_AND_ASSIGN(auto ext_arr,
FixedShapeTensorArray::FromTensor(tensor));
+ ASSERT_OK_AND_ASSIGN(auto tensor_from_array, ext_arr->ToTensor());
+
+ ASSERT_EQ(tensor->type(), tensor_from_array->type());
+ ASSERT_EQ(tensor->shape(), tensor_from_array->shape());
+ for (size_t i = 1; i < tensor->dim_names().size(); i++) {
+ ASSERT_EQ(tensor->dim_names()[i], tensor_from_array->dim_names()[i]);
+ }
+ ASSERT_EQ(tensor->strides(), tensor_from_array->strides());
+ ASSERT_TRUE(tensor->data()->Equals(*tensor_from_array->data()));
+ ASSERT_TRUE(tensor->Equals(*tensor_from_array));
+}
+
+TEST_F(TestExtensionType, RoundtripTensor) {
+ auto values = Buffer::Wrap(values_);
+
+ auto shapes = std::vector<std::vector<int64_t>>{
+ {3, 3, 4}, {3, 4, 3}, {3, 4, 3}, {3, 3, 4}, {6, 2, 3},
+ {6, 3, 2}, {2, 3, 6}, {2, 6, 3}, {2, 3, 2, 3}, {2, 3, 2, 3}};
+ auto strides = std::vector<std::vector<int64_t>>{
+ {96, 32, 8}, {96, 8, 32}, {96, 24, 8}, {96, 8, 24}, {48, 24, 8},
+ {48, 8, 24}, {144, 48, 8}, {144, 8, 48}, {144, 48, 24, 8}, {144, 8, 24,
48}};
+ auto tensor_dim_names = std::vector<std::vector<std::string>>{
+ {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y",
"z"},
+ {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y", "z"}, {"x", "y",
"z"},
+ {"N", "H", "W", "C"}, {"N", "H", "W", "C"}};
+
+ for (size_t i = 0; i < shapes.size(); i++) {
+ ASSERT_OK_AND_ASSIGN(auto tensor, Tensor::Make(value_type_, values,
shapes[i],
+ strides[i],
tensor_dim_names[i]));
+ CheckTensorRoundtrip(tensor);
+ }
+}
+
+TEST_F(TestExtensionType, SliceTensor) {
+ ASSERT_OK_AND_ASSIGN(auto tensor,
+ Tensor::Make(value_type_, Buffer::Wrap(values_),
shape_));
+ ASSERT_OK_AND_ASSIGN(
+ auto tensor_partial,
+ Tensor::Make(value_type_, Buffer::Wrap(values_partial_),
shape_partial_));
+ ASSERT_EQ(tensor->strides(), tensor_strides_);
+ ASSERT_EQ(tensor_partial->strides(), tensor_strides_);
+ auto ext_type = fixed_shape_tensor(value_type_, cell_shape_, {}, dim_names_);
+ auto exact_ext_type =
internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
+
+ ASSERT_OK_AND_ASSIGN(auto ext_arr,
FixedShapeTensorArray::FromTensor(tensor));
+ ASSERT_OK_AND_ASSIGN(auto ext_arr_partial,
+ FixedShapeTensorArray::FromTensor(tensor_partial));
+ ASSERT_OK(ext_arr->ValidateFull());
+ ASSERT_OK(ext_arr_partial->ValidateFull());
+
+ auto sliced =
internal::checked_pointer_cast<ExtensionArray>(ext_arr->Slice(0, 2));
+ auto partial =
internal::checked_pointer_cast<ExtensionArray>(ext_arr_partial);
+
+ ASSERT_TRUE(sliced->Equals(*partial));
+ ASSERT_OK(sliced->ValidateFull());
+ ASSERT_OK(partial->ValidateFull());
+ ASSERT_TRUE(sliced->storage()->Equals(*partial->storage()));
+ ASSERT_EQ(sliced->length(), partial->length());
+}
+
+TEST_F(TestExtensionType, RoudtripBatchFromTensor) {
+ auto exact_ext_type =
internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
+ ASSERT_OK_AND_ASSIGN(auto tensor, Tensor::Make(value_type_,
Buffer::Wrap(values_),
+ shape_, {}, {"n", "x", "y"}));
+ ASSERT_OK_AND_ASSIGN(auto ext_arr,
FixedShapeTensorArray::FromTensor(tensor));
+ ext_arr->data()->type = exact_ext_type;
+
+ auto ext_metadata =
+ key_value_metadata({{"ARROW:extension:name",
ext_type_->extension_name()},
+ {"ARROW:extension:metadata", serialized_}});
+ auto ext_field = field("f0", ext_type_, true, ext_metadata);
+ auto batch = RecordBatch::Make(schema({ext_field}), ext_arr->length(),
{ext_arr});
+ std::shared_ptr<RecordBatch> read_batch;
+ RoundtripBatch(batch, &read_batch);
+ CompareBatch(*batch, *read_batch, /*compare_metadata=*/true);
+}
+
+TEST_F(TestExtensionType, ComputeStrides) {
+ auto exact_ext_type =
internal::checked_pointer_cast<FixedShapeTensorType>(ext_type_);
+
+ auto ext_type_1 = internal::checked_pointer_cast<FixedShapeTensorType>(
+ fixed_shape_tensor(int64(), cell_shape_, {}, dim_names_));
+ auto ext_type_2 = internal::checked_pointer_cast<FixedShapeTensorType>(
+ fixed_shape_tensor(int64(), cell_shape_, {}, dim_names_));
+ auto ext_type_3 = internal::checked_pointer_cast<FixedShapeTensorType>(
+ fixed_shape_tensor(int32(), cell_shape_, {}, dim_names_));
+ ASSERT_TRUE(ext_type_1->Equals(*ext_type_2));
+ ASSERT_FALSE(ext_type_1->Equals(*ext_type_3));
+
+ auto ext_type_4 = internal::checked_pointer_cast<FixedShapeTensorType>(
+ fixed_shape_tensor(int64(), {3, 4, 7}, {}, {"x", "y", "z"}));
+ ASSERT_EQ(ext_type_4->strides(), (std::vector<int64_t>{224, 56, 8}));
+ ext_type_4 = internal::checked_pointer_cast<FixedShapeTensorType>(
+ fixed_shape_tensor(int64(), {3, 4, 7}, {0, 1, 2}, {"x", "y", "z"}));
+ ASSERT_EQ(ext_type_4->strides(), (std::vector<int64_t>{224, 56, 8}));
+
+ auto ext_type_5 = internal::checked_pointer_cast<FixedShapeTensorType>(
+ fixed_shape_tensor(int64(), {3, 4, 7}, {1, 0, 2}));
+ ASSERT_EQ(ext_type_5->strides(), (std::vector<int64_t>{56, 224, 8}));
+ ASSERT_EQ(ext_type_5->Serialize(),
R"({"shape":[3,4,7],"permutation":[1,0,2]})");
+
+ auto ext_type_6 = internal::checked_pointer_cast<FixedShapeTensorType>(
+ fixed_shape_tensor(int64(), {3, 4, 7}, {1, 2, 0}, {}));
+ ASSERT_EQ(ext_type_6->strides(), (std::vector<int64_t>{56, 8, 224}));
+ ASSERT_EQ(ext_type_6->Serialize(),
R"({"shape":[3,4,7],"permutation":[1,2,0]})");
+ auto ext_type_7 = internal::checked_pointer_cast<FixedShapeTensorType>(
+ fixed_shape_tensor(int32(), {3, 4, 7}, {2, 0, 1}, {}));
+ ASSERT_EQ(ext_type_7->strides(), (std::vector<int64_t>{4, 112, 16}));
+ ASSERT_EQ(ext_type_7->Serialize(),
R"({"shape":[3,4,7],"permutation":[2,0,1]})");
+}
+
} // namespace arrow