This is an automated email from the ASF dual-hosted git repository.

raulcd pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/main by this push:
     new 41035d4be1 GH-49716: [C++] FixedShapeTensorType::Deserialize should 
strictly validate serialized metadata (#49718)
41035d4be1 is described below

commit 41035d4be11d47e8e49eea7fde5f9fb74ab6d74d
Author: Rok Mihevc <[email protected]>
AuthorDate: Tue Apr 14 14:30:42 2026 +0200

    GH-49716: [C++] FixedShapeTensorType::Deserialize should strictly validate 
serialized metadata (#49718)
    
    ### Rationale for this change
    
    FixedShapeTensorType::Deserialize should validate input from unknown 
sources.
    
    ### What changes are included in this PR?
    
    Adds stricter deserialization validation.
    
    ### Are these changes tested?
    
    Yes. New tests are added.
    
    ### Are there any user-facing changes?
    
    Stricter validation should not be noticed if metadata is correct as per 
spec of fixed_shape_tensor.
    * GitHub Issue: #49716
    
    Authored-by: Rok Mihevc <[email protected]>
    Signed-off-by: Raúl Cumplido <[email protected]>
---
 cpp/src/arrow/extension/fixed_shape_tensor.cc      | 66 ++++++++++++---
 .../arrow/extension/tensor_extension_array_test.cc | 93 ++++++++++++++++++++++
 cpp/src/arrow/extension/tensor_internal.cc         | 32 ++++++--
 cpp/src/arrow/extension/tensor_internal.h          | 14 ++++
 cpp/src/arrow/extension/variable_shape_tensor.cc   | 19 +++--
 5 files changed, 202 insertions(+), 22 deletions(-)

diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.cc 
b/cpp/src/arrow/extension/fixed_shape_tensor.cc
index 5be855ffcb..5446169887 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor.cc
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.cc
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+#include <limits>
 #include <numeric>
 #include <sstream>
 
@@ -109,8 +110,8 @@ Result<std::shared_ptr<DataType>> 
FixedShapeTensorType::Deserialize(
     return Status::Invalid("Expected FixedSizeList storage type, got ",
                            storage_type->ToString());
   }
-  auto value_type =
-      
internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type();
+  auto fsl_type = 
internal::checked_pointer_cast<FixedSizeListType>(storage_type);
+  auto value_type = fsl_type->value_type();
   rj::Document document;
   if (document.Parse(serialized_data.data(), 
serialized_data.length()).HasParseError() ||
       !document.IsObject() || !document.HasMember("shape") ||
@@ -119,21 +120,45 @@ Result<std::shared_ptr<DataType>> 
FixedShapeTensorType::Deserialize(
   }
 
   std::vector<int64_t> shape;
-  for (auto& x : document["shape"].GetArray()) {
+  for (const auto& x : document["shape"].GetArray()) {
+    if (!x.IsInt64()) {
+      return Status::Invalid("shape must contain integers, got ",
+                             internal::JsonTypeName(x));
+    }
     shape.emplace_back(x.GetInt64());
   }
+
   std::vector<int64_t> permutation;
   if (document.HasMember("permutation")) {
-    for (auto& x : document["permutation"].GetArray()) {
+    const auto& json_permutation = document["permutation"];
+    if (!json_permutation.IsArray()) {
+      return Status::Invalid("permutation must be an array, got ",
+                             internal::JsonTypeName(json_permutation));
+    }
+    for (const auto& x : json_permutation.GetArray()) {
+      if (!x.IsInt64()) {
+        return Status::Invalid("permutation must contain integers, got ",
+                               internal::JsonTypeName(x));
+      }
       permutation.emplace_back(x.GetInt64());
     }
     if (shape.size() != permutation.size()) {
       return Status::Invalid("Invalid permutation");
     }
+    RETURN_NOT_OK(internal::IsPermutationValid(permutation));
   }
   std::vector<std::string> dim_names;
   if (document.HasMember("dim_names")) {
-    for (auto& x : document["dim_names"].GetArray()) {
+    const auto& json_dim_names = document["dim_names"];
+    if (!json_dim_names.IsArray()) {
+      return Status::Invalid("dim_names must be an array, got ",
+                             internal::JsonTypeName(json_dim_names));
+    }
+    for (const auto& x : json_dim_names.GetArray()) {
+      if (!x.IsString()) {
+        return Status::Invalid("dim_names must contain strings, got ",
+                               internal::JsonTypeName(x));
+      }
       dim_names.emplace_back(x.GetString());
     }
     if (shape.size() != dim_names.size()) {
@@ -141,7 +166,20 @@ Result<std::shared_ptr<DataType>> 
FixedShapeTensorType::Deserialize(
     }
   }
 
-  return fixed_shape_tensor(value_type, shape, permutation, dim_names);
+  // Validate product of shape dimensions matches storage type list_size.
+  // This check is intentionally after field parsing so that metadata-level 
errors
+  // (type mismatches, size mismatches) are reported first.
+  ARROW_ASSIGN_OR_RAISE(auto ext_type, FixedShapeTensorType::Make(
+                                           value_type, shape, permutation, 
dim_names));
+  const auto& fst_type = internal::checked_cast<const 
FixedShapeTensorType&>(*ext_type);
+  ARROW_ASSIGN_OR_RAISE(const int64_t expected_size,
+                        internal::ComputeShapeProduct(fst_type.shape()));
+  if (expected_size != fsl_type->list_size()) {
+    return Status::Invalid("Product of shape dimensions (", expected_size,
+                           ") does not match FixedSizeList size (", 
fsl_type->list_size(),
+                           ")");
+  }
+  return ext_type;
 }
 
 std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
@@ -310,8 +348,7 @@ const Result<std::shared_ptr<Tensor>> 
FixedShapeTensorArray::ToTensor() const {
   }
 
   std::vector<int64_t> shape = ext_type.shape();
-  auto cell_size = std::accumulate(shape.begin(), shape.end(), 
static_cast<int64_t>(1),
-                                   std::multiplies<>());
+  ARROW_ASSIGN_OR_RAISE(const int64_t cell_size, 
internal::ComputeShapeProduct(shape));
   shape.insert(shape.begin(), 1, this->length());
   internal::Permute<int64_t>(permutation, &shape);
 
@@ -330,6 +367,11 @@ Result<std::shared_ptr<DataType>> 
FixedShapeTensorType::Make(
     const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& 
shape,
     const std::vector<int64_t>& permutation, const std::vector<std::string>& 
dim_names) {
   const size_t ndim = shape.size();
+  for (auto dim : shape) {
+    if (dim < 0) {
+      return Status::Invalid("shape must have non-negative values, got ", dim);
+    }
+  }
   if (!permutation.empty() && ndim != permutation.size()) {
     return Status::Invalid("permutation size must match shape size. Expected: 
", ndim,
                            " Got: ", permutation.size());
@@ -342,8 +384,12 @@ Result<std::shared_ptr<DataType>> 
FixedShapeTensorType::Make(
     RETURN_NOT_OK(internal::IsPermutationValid(permutation));
   }
 
-  const int64_t size = std::accumulate(shape.begin(), shape.end(),
-                                       static_cast<int64_t>(1), 
std::multiplies<>());
+  ARROW_ASSIGN_OR_RAISE(const int64_t size, 
internal::ComputeShapeProduct(shape));
+  if (size > std::numeric_limits<int32_t>::max()) {
+    return Status::Invalid("Product of shape dimensions (", size,
+                           ") exceeds maximum FixedSizeList size (",
+                           std::numeric_limits<int32_t>::max(), ")");
+  }
   return std::make_shared<FixedShapeTensorType>(value_type, 
static_cast<int32_t>(size),
                                                 shape, permutation, dim_names);
 }
diff --git a/cpp/src/arrow/extension/tensor_extension_array_test.cc 
b/cpp/src/arrow/extension/tensor_extension_array_test.cc
index 5c6dbe2162..531fc3c01c 100644
--- a/cpp/src/arrow/extension/tensor_extension_array_test.cc
+++ b/cpp/src/arrow/extension/tensor_extension_array_test.cc
@@ -219,6 +219,73 @@ TEST_F(TestFixedShapeTensorType, 
MetadataSerializationRoundtrip) {
   CheckDeserializationRaises(ext_type_, storage_type,
                              R"({"shape":[3],"dim_names":["x","y"]})",
                              "Invalid dim_names");
+
+  // Validate shape values must be integers. Error message should include the
+  // JSON type name of the offending value.
+  CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[3.5,4]})",
+                             "shape must contain integers, got Number");
+  CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":["3","4"]})",
+                             "shape must contain integers, got String");
+  CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[null]})",
+                             "shape must contain integers, got Null");
+  CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[true]})",
+                             "shape must contain integers, got True");
+  CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[false]})",
+                             "shape must contain integers, got False");
+
+  // Validate shape values must be non-negative
+  CheckDeserializationRaises(ext_type_, fixed_size_list(int64(), 1), 
R"({"shape":[-1]})",
+                             "shape must have non-negative values");
+
+  // Validate product of shape matches storage list_size
+  CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[3,3]})",
+                             "Product of shape dimensions");
+
+  // Validate permutation member must be an array with integer values
+  CheckDeserializationRaises(ext_type_, storage_type,
+                             R"({"shape":[3,4],"permutation":"invalid"})",
+                             "permutation must be an array, got String");
+  CheckDeserializationRaises(ext_type_, storage_type,
+                             R"({"shape":[3,4],"permutation":{"a":1}})",
+                             "permutation must be an array, got Object");
+  CheckDeserializationRaises(ext_type_, storage_type,
+                             R"({"shape":[3,4],"permutation":[1.5,0.5]})",
+                             "permutation must contain integers, got Number");
+  CheckDeserializationRaises(ext_type_, storage_type,
+                             R"({"shape":[3,4],"permutation":["a","b"]})",
+                             "permutation must contain integers, got String");
+
+  // Validate permutation values must be unique integers in [0, N-1]
+  CheckDeserializationRaises(ext_type_, storage_type,
+                             R"({"shape":[3,4],"permutation":[0,0]})",
+                             "Permutation indices");
+  CheckDeserializationRaises(ext_type_, storage_type,
+                             R"({"shape":[3,4],"permutation":[0,5]})",
+                             "Permutation indices");
+  CheckDeserializationRaises(ext_type_, storage_type,
+                             R"({"shape":[3,4],"permutation":[-1,0]})",
+                             "Permutation indices");
+
+  // Validate dim_names member must be an array with string values
+  CheckDeserializationRaises(ext_type_, storage_type,
+                             R"({"shape":[3,4],"dim_names":"invalid"})",
+                             "dim_names must be an array, got String");
+  CheckDeserializationRaises(ext_type_, storage_type,
+                             R"({"shape":[3,4],"dim_names":[1,2]})",
+                             "dim_names must contain strings, got Number");
+  CheckDeserializationRaises(ext_type_, storage_type,
+                             R"({"shape":[3,4],"dim_names":[null,null]})",
+                             "dim_names must contain strings, got Null");
+}
+
+TEST_F(TestFixedShapeTensorType, MakeValidatesShape) {
+  // Negative shape values should be rejected
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, testing::HasSubstr("shape must have non-negative values"),
+      FixedShapeTensorType::Make(value_type_, {-1}));
+  EXPECT_RAISES_WITH_MESSAGE_THAT(
+      Invalid, testing::HasSubstr("shape must have non-negative values"),
+      FixedShapeTensorType::Make(value_type_, {3, -1, 4}));
 }
 
 TEST_F(TestFixedShapeTensorType, RoundtripBatch) {
@@ -794,6 +861,32 @@ TEST_F(TestVariableShapeTensorType, 
MetadataSerializationRoundtrip) {
                              "Invalid: permutation");
   CheckDeserializationRaises(ext_type_, storage_type, 
R"({"dim_names":["x","y"]})",
                              "Invalid: dim_names");
+
+  // Validate permutation member must be an array with integer values. Error
+  // message should include the JSON type name of the offending value.
+  CheckDeserializationRaises(ext_type_, storage_type, 
R"({"permutation":"invalid"})",
+                             "permutation must be an array, got String");
+  CheckDeserializationRaises(ext_type_, storage_type, 
R"({"permutation":[1.5,0.5,2.5]})",
+                             "permutation must contain integers, got Number");
+  CheckDeserializationRaises(ext_type_, storage_type,
+                             R"({"permutation":[null,null,null]})",
+                             "permutation must contain integers, got Null");
+
+  // Validate dim_names member must be an array with string values
+  CheckDeserializationRaises(ext_type_, storage_type, 
R"({"dim_names":"invalid"})",
+                             "dim_names must be an array, got String");
+  CheckDeserializationRaises(ext_type_, storage_type, 
R"({"dim_names":[1,2,3]})",
+                             "dim_names must contain strings, got Number");
+
+  // Validate uniform_shape member must be an array with integer-or-null values
+  CheckDeserializationRaises(ext_type_, storage_type, 
R"({"uniform_shape":"invalid"})",
+                             "uniform_shape must be an array, got String");
+  CheckDeserializationRaises(ext_type_, storage_type,
+                             R"({"uniform_shape":[1.5,null,null]})",
+                             "uniform_shape must contain integers or nulls, 
got Number");
+  CheckDeserializationRaises(ext_type_, storage_type,
+                             R"({"uniform_shape":["x",null,null]})",
+                             "uniform_shape must contain integers or nulls, 
got String");
 }
 
 TEST_F(TestVariableShapeTensorType, RoundtripBatch) {
diff --git a/cpp/src/arrow/extension/tensor_internal.cc 
b/cpp/src/arrow/extension/tensor_internal.cc
index 37862b7689..e94ea9a1d1 100644
--- a/cpp/src/arrow/extension/tensor_internal.cc
+++ b/cpp/src/arrow/extension/tensor_internal.cc
@@ -30,6 +30,31 @@
 
 namespace arrow::internal {
 
+namespace {
+
+// Names indexed by rapidjson::Type enum value:
+// kNullType=0, kFalseType=1, kTrueType=2, kObjectType=3,
+// kArrayType=4, kStringType=5, kNumberType=6.
+constexpr const char* kJsonTypeNames[] = {"Null",  "False",  "True",  "Object",
+                                          "Array", "String", "Number"};
+
+}  // namespace
+
+const char* JsonTypeName(const ::arrow::rapidjson::Value& v) {
+  return kJsonTypeNames[v.GetType()];
+}
+
+Result<int64_t> ComputeShapeProduct(std::span<const int64_t> shape) {
+  int64_t product = 1;
+  for (const auto dim : shape) {
+    if (MultiplyWithOverflow(product, dim, &product)) {
+      return Status::Invalid(
+          "Product of tensor shape dimensions would not fit in 64-bit 
integer");
+    }
+  }
+  return product;
+}
+
 bool IsPermutationTrivial(std::span<const int64_t> permutation) {
   for (size_t i = 1; i < permutation.size(); ++i) {
     if (permutation[i - 1] + 1 != permutation[i]) {
@@ -105,12 +130,7 @@ Result<std::shared_ptr<Buffer>> SliceTensorBuffer(const 
Array& data_array,
                                                   const DataType& value_type,
                                                   std::span<const int64_t> 
shape) {
   const int64_t byte_width = value_type.byte_width();
-  int64_t size = 1;
-  for (const auto dim : shape) {
-    if (MultiplyWithOverflow(size, dim, &size)) {
-      return Status::Invalid("Tensor size would not fit in 64-bit integer");
-    }
-  }
+  ARROW_ASSIGN_OR_RAISE(const int64_t size, ComputeShapeProduct(shape));
   if (size != data_array.length()) {
     return Status::Invalid("Expected data array of length ", size, ", got ",
                            data_array.length());
diff --git a/cpp/src/arrow/extension/tensor_internal.h 
b/cpp/src/arrow/extension/tensor_internal.h
index b5ed5ebe11..19665bf2cd 100644
--- a/cpp/src/arrow/extension/tensor_internal.h
+++ b/cpp/src/arrow/extension/tensor_internal.h
@@ -21,11 +21,25 @@
 #include <span>
 #include <vector>
 
+#include "arrow/json/rapidjson_defs.h"  // IWYU pragma: keep
 #include "arrow/result.h"
 #include "arrow/type_fwd.h"
 
+#include <rapidjson/document.h>
+
 namespace arrow::internal {
 
+/// \brief Return the name of a RapidJSON value's type (e.g., "Null", "Array", 
"Number").
+ARROW_EXPORT
+const char* JsonTypeName(const ::arrow::rapidjson::Value& v);
+
+/// \brief Compute the product of the given shape dimensions.
+///
+/// Returns Status::Invalid if the product would overflow int64_t.
+/// An empty shape returns 1 (the multiplicative identity).
+ARROW_EXPORT
+Result<int64_t> ComputeShapeProduct(std::span<const int64_t> shape);
+
 ARROW_EXPORT
 bool IsPermutationTrivial(std::span<const int64_t> permutation);
 
diff --git a/cpp/src/arrow/extension/variable_shape_tensor.cc 
b/cpp/src/arrow/extension/variable_shape_tensor.cc
index 7e27bbdb74..b1b12583d7 100644
--- a/cpp/src/arrow/extension/variable_shape_tensor.cc
+++ b/cpp/src/arrow/extension/variable_shape_tensor.cc
@@ -159,26 +159,31 @@ Result<std::shared_ptr<DataType>> 
VariableShapeTensorType::Deserialize(
   if (document.HasMember("permutation")) {
     const auto& json_permutation = document["permutation"];
     if (!json_permutation.IsArray()) {
-      return Status::Invalid("permutation must be an array");
+      return Status::Invalid("permutation must be an array, got ",
+                             internal::JsonTypeName(json_permutation));
     }
     permutation.reserve(ndim);
     for (const auto& x : json_permutation.GetArray()) {
       if (!x.IsInt64()) {
-        return Status::Invalid("permutation must contain integers");
+        return Status::Invalid("permutation must contain integers, got ",
+                               internal::JsonTypeName(x));
       }
       permutation.emplace_back(x.GetInt64());
     }
+    RETURN_NOT_OK(internal::IsPermutationValid(permutation));
   }
   std::vector<std::string> dim_names;
   if (document.HasMember("dim_names")) {
     const auto& json_dim_names = document["dim_names"];
     if (!json_dim_names.IsArray()) {
-      return Status::Invalid("dim_names must be an array");
+      return Status::Invalid("dim_names must be an array, got ",
+                             internal::JsonTypeName(json_dim_names));
     }
     dim_names.reserve(ndim);
     for (const auto& x : json_dim_names.GetArray()) {
       if (!x.IsString()) {
-        return Status::Invalid("dim_names must contain strings");
+        return Status::Invalid("dim_names must contain strings, got ",
+                               internal::JsonTypeName(x));
       }
       dim_names.emplace_back(x.GetString());
     }
@@ -188,7 +193,8 @@ Result<std::shared_ptr<DataType>> 
VariableShapeTensorType::Deserialize(
   if (document.HasMember("uniform_shape")) {
     const auto& json_uniform_shape = document["uniform_shape"];
     if (!json_uniform_shape.IsArray()) {
-      return Status::Invalid("uniform_shape must be an array");
+      return Status::Invalid("uniform_shape must be an array, got ",
+                             internal::JsonTypeName(json_uniform_shape));
     }
     uniform_shape.reserve(ndim);
     for (const auto& x : json_uniform_shape.GetArray()) {
@@ -197,7 +203,8 @@ Result<std::shared_ptr<DataType>> 
VariableShapeTensorType::Deserialize(
       } else if (x.IsInt64()) {
         uniform_shape.emplace_back(x.GetInt64());
       } else {
-        return Status::Invalid("uniform_shape must contain integers or nulls");
+        return Status::Invalid("uniform_shape must contain integers or nulls, 
got ",
+                               internal::JsonTypeName(x));
       }
     }
   }

Reply via email to