rok commented on code in PR #37533: URL: https://github.com/apache/arrow/pull/37533#discussion_r1435402596
##########
cpp/src/arrow/extension/fixed_shape_tensor.cc:
##########
@@ -293,40 +335,49 @@ const Result<std::shared_ptr<Tensor>> FixedShapeTensorArray::ToTensor() const {
   // To convert an array of n dimensional tensors to a n+1 dimensional tensor we
   // interpret the array's length as the first dimension the new tensor.
 
-  auto ext_arr = std::static_pointer_cast<FixedSizeListArray>(this->storage());
-  auto ext_type = internal::checked_pointer_cast<FixedShapeTensorType>(this->type());
-  ARROW_RETURN_IF(!is_fixed_width(*ext_arr->value_type()),
-                  Status::Invalid(ext_arr->value_type()->ToString(),
-                                  " is not valid data type for a tensor"));
-  auto permutation = ext_type->permutation();
+  const auto ext_type =
+      internal::checked_pointer_cast<FixedShapeTensorType>(this->type());
+  const auto value_type = ext_type->value_type();
+  ARROW_RETURN_IF(
+      !is_fixed_width(*value_type),
+      Status::Invalid(value_type->ToString(), " is not valid data type for a tensor"));
 
-  std::vector<std::string> dim_names;
-  if (!ext_type->dim_names().empty()) {
-    for (auto i : permutation) {
-      dim_names.emplace_back(ext_type->dim_names()[i]);
+  std::vector<int64_t> permutation = ext_type->permutation();
+  if (permutation.empty()) {
+    for (int64_t i = 0; i < static_cast<int64_t>(ext_type->ndim()); i++) {
+      permutation.emplace_back(i);
     }
-    dim_names.insert(dim_names.begin(), 1, "");
-  } else {
-    dim_names = {};
   }
+  for (int64_t i = 0; i < static_cast<int64_t>(ext_type->ndim()); i++) {
+    permutation[i] += 1;
+  }
+  permutation.insert(permutation.begin(), 1, 0);
 
-  std::vector<int64_t> shape;
-  for (int64_t& i : permutation) {
-    shape.emplace_back(ext_type->shape()[i]);
-    ++i;
+  std::vector<std::string> dim_names = ext_type->dim_names();
+  if (!dim_names.empty()) {
+    dim_names.insert(dim_names.begin(), 1, "");
+    internal::Permute<std::string>(permutation, &dim_names);
   }
+
+  std::vector<int64_t> shape = ext_type->shape();
+  auto cell_size = std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
+                                   std::multiplies<>());
   shape.insert(shape.begin(), 1, this->length());
-  permutation.insert(permutation.begin(), 1, 0);
+  internal::Permute<int64_t>(permutation, &shape);
 
   std::vector<int64_t> tensor_strides;
-  auto value_type = internal::checked_pointer_cast<FixedWidthType>(ext_arr->value_type());
+  const auto fw_value_type = internal::checked_pointer_cast<FixedWidthType>(value_type);
   ARROW_RETURN_NOT_OK(
-      ComputeStrides(*value_type.get(), shape, permutation, &tensor_strides));
-  ARROW_ASSIGN_OR_RAISE(auto buffers, ext_arr->Flatten());
+      ComputeStrides(*fw_value_type.get(), shape, permutation, &tensor_strides));
+
+
   ARROW_ASSIGN_OR_RAISE(
-      auto tensor, Tensor::Make(ext_arr->value_type(), buffers->data()->buffers[1], shape,
-                                tensor_strides, dim_names));
-  return tensor;
+      const auto flattened_storage_array,
+      internal::checked_pointer_cast<FixedSizeListArray>(this->storage())->Flatten());

Review Comment:
   Done.

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org