rok commented on code in PR #37533:
URL: https://github.com/apache/arrow/pull/37533#discussion_r1482266550


##########
cpp/src/arrow/extension/fixed_shape_tensor_test.cc:
##########
@@ -462,4 +543,121 @@ TEST_F(TestExtensionType, ToString) {
   ASSERT_EQ(expected_3, result_3);
 }
 
+TEST_F(TestExtensionType, GetScalar) {
+  auto ext_type = fixed_shape_tensor(value_type_, element_shape_, {}, 
dim_names_);
+
+  auto expected_data =
+      ArrayFromJSON(element_type_, "[[12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 
22, 23]]");
+  auto storage_array = ArrayFromJSON(element_type_,
+                                     "[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],"
+                                     "[12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 
22, 23]]");
+
+  auto sub_array = ExtensionType::WrapArray(ext_type, expected_data);
+  auto array = ExtensionType::WrapArray(ext_type, storage_array);
+
+  ASSERT_OK_AND_ASSIGN(auto expected_scalar, sub_array->GetScalar(0));
+  ASSERT_OK_AND_ASSIGN(auto actual_scalar, array->GetScalar(1));
+
+  ASSERT_OK(actual_scalar->ValidateFull());
+  ASSERT_TRUE(actual_scalar->type->Equals(*ext_type));
+  ASSERT_TRUE(actual_scalar->is_valid);
+
+  ASSERT_OK(expected_scalar->ValidateFull());
+  ASSERT_TRUE(expected_scalar->type->Equals(*ext_type));
+  ASSERT_TRUE(expected_scalar->is_valid);
+
+  AssertTypeEqual(actual_scalar->type, ext_type);
+  ASSERT_TRUE(actual_scalar->Equals(*expected_scalar));
+}
+
+TEST_F(TestExtensionType, GetTensor) {
+  auto arr = ArrayFromJSON(element_type_,
+                           "[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],"
+                           "[12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 
23]]");
+  auto element_values =
+      std::vector<std::vector<int64_t>>{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11},
+                                        {12, 13, 14, 15, 16, 17, 18, 19, 20, 
21, 22, 23}};
+
+  auto ext_type = fixed_shape_tensor(value_type_, element_shape_, {}, 
dim_names_);
+  auto permuted_ext_type = fixed_shape_tensor(value_type_, {3, 4}, {1, 0}, 
{"x", "y"});
+  auto exact_ext_type = 
internal::checked_pointer_cast<FixedShapeTensorType>(ext_type);
+  auto exact_permuted_ext_type =
+      internal::checked_pointer_cast<FixedShapeTensorType>(permuted_ext_type);
+
+  auto array = std::static_pointer_cast<FixedShapeTensorArray>(
+      ExtensionType::WrapArray(ext_type, arr));
+  auto permuted_array = std::static_pointer_cast<FixedShapeTensorArray>(
+      ExtensionType::WrapArray(permuted_ext_type, arr));
+
+  for (size_t i = 0; i < element_values.size(); i++) {
+    // Get tensor from extension array with trivial permutation
+    ASSERT_OK_AND_ASSIGN(auto scalar, array->GetScalar(i));
+    auto actual_ext_scalar = 
internal::checked_pointer_cast<ExtensionScalar>(scalar);
+    ASSERT_OK_AND_ASSIGN(auto actual_tensor,
+                         exact_ext_type->MakeTensor(actual_ext_scalar));
+    ASSERT_OK_AND_ASSIGN(auto expected_tensor,
+                         Tensor::Make(value_type_, 
Buffer::Wrap(element_values[i]),
+                                      {3, 4}, {}, {"x", "y"}));
+    ASSERT_EQ(expected_tensor->shape(), actual_tensor->shape());
+    ASSERT_EQ(expected_tensor->dim_names(), actual_tensor->dim_names());
+    ASSERT_EQ(expected_tensor->strides(), actual_tensor->strides());
+    ASSERT_EQ(actual_tensor->strides(), std::vector<int64_t>({32, 8}));
+    ASSERT_EQ(expected_tensor->type(), actual_tensor->type());
+    ASSERT_TRUE(expected_tensor->Equals(*actual_tensor));
+
+    // Get tensor from extension array with non-trivial permutation
+    ASSERT_OK_AND_ASSIGN(auto expected_permuted_tensor,
+                         Tensor::Make(value_type_, 
Buffer::Wrap(element_values[i]),
+                                      {4, 3}, {8, 24}, {"y", "x"}));
+    ASSERT_OK_AND_ASSIGN(scalar, permuted_array->GetScalar(i));
+    ASSERT_OK_AND_ASSIGN(auto actual_permuted_tensor,
+                         exact_permuted_ext_type->MakeTensor(
+                             
internal::checked_pointer_cast<ExtensionScalar>(scalar)));
+    ASSERT_EQ(expected_permuted_tensor->strides(), 
actual_permuted_tensor->strides());
+    ASSERT_EQ(expected_permuted_tensor->shape(), 
actual_permuted_tensor->shape());
+    ASSERT_EQ(expected_permuted_tensor->dim_names(), 
actual_permuted_tensor->dim_names());
+    ASSERT_EQ(expected_permuted_tensor->type(), 
actual_permuted_tensor->type());
+    ASSERT_EQ(expected_permuted_tensor->is_contiguous(),
+              actual_permuted_tensor->is_contiguous());
+    ASSERT_EQ(expected_permuted_tensor->is_column_major(),
+              actual_permuted_tensor->is_column_major());
+    ASSERT_TRUE(expected_permuted_tensor->Equals(*actual_permuted_tensor));
+  }
+
+  // Test null values fail
+  auto element_type = fixed_size_list(int64(), 1);
+  auto fsla_arr = ArrayFromJSON(element_type, "[[1], [null], null]");
+  ext_type = fixed_shape_tensor(int64(), {1});
+  exact_ext_type = 
internal::checked_pointer_cast<FixedShapeTensorType>(ext_type);
+  auto ext_arr = ExtensionType::WrapArray(ext_type, fsla_arr);
+  auto tensor_array = std::static_pointer_cast<FixedShapeTensorArray>(ext_arr);

Review Comment:
   The error happened when calling `GetScalar(0)` on a `FixedShapeTensorArray`.
   Changing the cast to:
   ```cpp
   auto tensor_array = internal::checked_pointer_cast<ExtensionArray>(ext_arr);
   ```
   resolved the `SIGSEGV` I was getting with:
   ```cpp
   auto tensor_array = internal::checked_pointer_cast<FixedShapeTensorArray>(ext_arr);
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to