This is an automated email from the ASF dual-hosted git repository.
raulcd pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 41035d4be1 GH-49716: [C++] FixedShapeTensorType::Deserialize should
strictly validate serialized metadata (#49718)
41035d4be1 is described below
commit 41035d4be11d47e8e49eea7fde5f9fb74ab6d74d
Author: Rok Mihevc <[email protected]>
AuthorDate: Tue Apr 14 14:30:42 2026 +0200
GH-49716: [C++] FixedShapeTensorType::Deserialize should strictly validate
serialized metadata (#49718)
### Rationale for this change
FixedShapeTensorType::Deserialize parses metadata that may come from untrusted
sources, so it should validate that metadata strictly instead of assuming it is
well-formed.
### What changes are included in this PR?
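Adds stricter deserialization validation: JSON member and element types are
checked (with the offending JSON type named in the error), permutations must be
valid, shape dimensions must be non-negative, and the product of the shape must
match the size of the storage FixedSizeList.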
### Are these changes tested?
Yes. New tests are added.
### Are there any user-facing changes?
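Users should not notice the stricter validation as long as the metadata
conforms to the fixed_shape_tensor specification. Note that
FixedShapeTensorType::Make now also rejects negative shape dimensions and
shapes whose product exceeds the maximum FixedSizeList size (INT32_MAX).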
* GitHub Issue: #49716
Authored-by: Rok Mihevc <[email protected]>
Signed-off-by: Raúl Cumplido <[email protected]>
---
cpp/src/arrow/extension/fixed_shape_tensor.cc | 66 ++++++++++++---
.../arrow/extension/tensor_extension_array_test.cc | 93 ++++++++++++++++++++++
cpp/src/arrow/extension/tensor_internal.cc | 32 ++++++--
cpp/src/arrow/extension/tensor_internal.h | 14 ++++
cpp/src/arrow/extension/variable_shape_tensor.cc | 19 +++--
5 files changed, 202 insertions(+), 22 deletions(-)
diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.cc
b/cpp/src/arrow/extension/fixed_shape_tensor.cc
index 5be855ffcb..5446169887 100644
--- a/cpp/src/arrow/extension/fixed_shape_tensor.cc
+++ b/cpp/src/arrow/extension/fixed_shape_tensor.cc
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+#include <limits>
#include <numeric>
#include <sstream>
@@ -109,8 +110,8 @@ Result<std::shared_ptr<DataType>>
FixedShapeTensorType::Deserialize(
return Status::Invalid("Expected FixedSizeList storage type, got ",
storage_type->ToString());
}
- auto value_type =
-
internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type();
+ auto fsl_type =
internal::checked_pointer_cast<FixedSizeListType>(storage_type);
+ auto value_type = fsl_type->value_type();
rj::Document document;
if (document.Parse(serialized_data.data(),
serialized_data.length()).HasParseError() ||
!document.IsObject() || !document.HasMember("shape") ||
@@ -119,21 +120,45 @@ Result<std::shared_ptr<DataType>>
FixedShapeTensorType::Deserialize(
}
std::vector<int64_t> shape;
- for (auto& x : document["shape"].GetArray()) {
+ for (const auto& x : document["shape"].GetArray()) {
+ if (!x.IsInt64()) {
+ return Status::Invalid("shape must contain integers, got ",
+ internal::JsonTypeName(x));
+ }
shape.emplace_back(x.GetInt64());
}
+
std::vector<int64_t> permutation;
if (document.HasMember("permutation")) {
- for (auto& x : document["permutation"].GetArray()) {
+ const auto& json_permutation = document["permutation"];
+ if (!json_permutation.IsArray()) {
+ return Status::Invalid("permutation must be an array, got ",
+ internal::JsonTypeName(json_permutation));
+ }
+ for (const auto& x : json_permutation.GetArray()) {
+ if (!x.IsInt64()) {
+ return Status::Invalid("permutation must contain integers, got ",
+ internal::JsonTypeName(x));
+ }
permutation.emplace_back(x.GetInt64());
}
if (shape.size() != permutation.size()) {
return Status::Invalid("Invalid permutation");
}
+ RETURN_NOT_OK(internal::IsPermutationValid(permutation));
}
std::vector<std::string> dim_names;
if (document.HasMember("dim_names")) {
- for (auto& x : document["dim_names"].GetArray()) {
+ const auto& json_dim_names = document["dim_names"];
+ if (!json_dim_names.IsArray()) {
+ return Status::Invalid("dim_names must be an array, got ",
+ internal::JsonTypeName(json_dim_names));
+ }
+ for (const auto& x : json_dim_names.GetArray()) {
+ if (!x.IsString()) {
+ return Status::Invalid("dim_names must contain strings, got ",
+ internal::JsonTypeName(x));
+ }
dim_names.emplace_back(x.GetString());
}
if (shape.size() != dim_names.size()) {
@@ -141,7 +166,20 @@ Result<std::shared_ptr<DataType>>
FixedShapeTensorType::Deserialize(
}
}
- return fixed_shape_tensor(value_type, shape, permutation, dim_names);
+ // Validate product of shape dimensions matches storage type list_size.
+ // This check is intentionally after field parsing so that metadata-level
errors
+ // (type mismatches, size mismatches) are reported first.
+ ARROW_ASSIGN_OR_RAISE(auto ext_type, FixedShapeTensorType::Make(
+ value_type, shape, permutation,
dim_names));
+ const auto& fst_type = internal::checked_cast<const
FixedShapeTensorType&>(*ext_type);
+ ARROW_ASSIGN_OR_RAISE(const int64_t expected_size,
+ internal::ComputeShapeProduct(fst_type.shape()));
+ if (expected_size != fsl_type->list_size()) {
+ return Status::Invalid("Product of shape dimensions (", expected_size,
+ ") does not match FixedSizeList size (",
fsl_type->list_size(),
+ ")");
+ }
+ return ext_type;
}
std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
@@ -310,8 +348,7 @@ const Result<std::shared_ptr<Tensor>>
FixedShapeTensorArray::ToTensor() const {
}
std::vector<int64_t> shape = ext_type.shape();
- auto cell_size = std::accumulate(shape.begin(), shape.end(),
static_cast<int64_t>(1),
- std::multiplies<>());
+ ARROW_ASSIGN_OR_RAISE(const int64_t cell_size,
internal::ComputeShapeProduct(shape));
shape.insert(shape.begin(), 1, this->length());
internal::Permute<int64_t>(permutation, &shape);
@@ -330,6 +367,11 @@ Result<std::shared_ptr<DataType>>
FixedShapeTensorType::Make(
const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>&
shape,
const std::vector<int64_t>& permutation, const std::vector<std::string>&
dim_names) {
const size_t ndim = shape.size();
+ for (auto dim : shape) {
+ if (dim < 0) {
+ return Status::Invalid("shape must have non-negative values, got ", dim);
+ }
+ }
if (!permutation.empty() && ndim != permutation.size()) {
return Status::Invalid("permutation size must match shape size. Expected:
", ndim,
" Got: ", permutation.size());
@@ -342,8 +384,12 @@ Result<std::shared_ptr<DataType>>
FixedShapeTensorType::Make(
RETURN_NOT_OK(internal::IsPermutationValid(permutation));
}
- const int64_t size = std::accumulate(shape.begin(), shape.end(),
- static_cast<int64_t>(1),
std::multiplies<>());
+ ARROW_ASSIGN_OR_RAISE(const int64_t size,
internal::ComputeShapeProduct(shape));
+ if (size > std::numeric_limits<int32_t>::max()) {
+ return Status::Invalid("Product of shape dimensions (", size,
+ ") exceeds maximum FixedSizeList size (",
+ std::numeric_limits<int32_t>::max(), ")");
+ }
return std::make_shared<FixedShapeTensorType>(value_type,
static_cast<int32_t>(size),
shape, permutation, dim_names);
}
diff --git a/cpp/src/arrow/extension/tensor_extension_array_test.cc
b/cpp/src/arrow/extension/tensor_extension_array_test.cc
index 5c6dbe2162..531fc3c01c 100644
--- a/cpp/src/arrow/extension/tensor_extension_array_test.cc
+++ b/cpp/src/arrow/extension/tensor_extension_array_test.cc
@@ -219,6 +219,73 @@ TEST_F(TestFixedShapeTensorType,
MetadataSerializationRoundtrip) {
CheckDeserializationRaises(ext_type_, storage_type,
R"({"shape":[3],"dim_names":["x","y"]})",
"Invalid dim_names");
+
+ // Validate shape values must be integers. Error message should include the
+ // JSON type name of the offending value.
+ CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[3.5,4]})",
+ "shape must contain integers, got Number");
+ CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":["3","4"]})",
+ "shape must contain integers, got String");
+ CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[null]})",
+ "shape must contain integers, got Null");
+ CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[true]})",
+ "shape must contain integers, got True");
+ CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[false]})",
+ "shape must contain integers, got False");
+
+ // Validate shape values must be non-negative
+ CheckDeserializationRaises(ext_type_, fixed_size_list(int64(), 1),
R"({"shape":[-1]})",
+ "shape must have non-negative values");
+
+ // Validate product of shape matches storage list_size
+ CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[3,3]})",
+ "Product of shape dimensions");
+
+ // Validate permutation member must be an array with integer values
+ CheckDeserializationRaises(ext_type_, storage_type,
+ R"({"shape":[3,4],"permutation":"invalid"})",
+ "permutation must be an array, got String");
+ CheckDeserializationRaises(ext_type_, storage_type,
+ R"({"shape":[3,4],"permutation":{"a":1}})",
+ "permutation must be an array, got Object");
+ CheckDeserializationRaises(ext_type_, storage_type,
+ R"({"shape":[3,4],"permutation":[1.5,0.5]})",
+ "permutation must contain integers, got Number");
+ CheckDeserializationRaises(ext_type_, storage_type,
+ R"({"shape":[3,4],"permutation":["a","b"]})",
+ "permutation must contain integers, got String");
+
+ // Validate permutation values must be unique integers in [0, N-1]
+ CheckDeserializationRaises(ext_type_, storage_type,
+ R"({"shape":[3,4],"permutation":[0,0]})",
+ "Permutation indices");
+ CheckDeserializationRaises(ext_type_, storage_type,
+ R"({"shape":[3,4],"permutation":[0,5]})",
+ "Permutation indices");
+ CheckDeserializationRaises(ext_type_, storage_type,
+ R"({"shape":[3,4],"permutation":[-1,0]})",
+ "Permutation indices");
+
+ // Validate dim_names member must be an array with string values
+ CheckDeserializationRaises(ext_type_, storage_type,
+ R"({"shape":[3,4],"dim_names":"invalid"})",
+ "dim_names must be an array, got String");
+ CheckDeserializationRaises(ext_type_, storage_type,
+ R"({"shape":[3,4],"dim_names":[1,2]})",
+ "dim_names must contain strings, got Number");
+ CheckDeserializationRaises(ext_type_, storage_type,
+ R"({"shape":[3,4],"dim_names":[null,null]})",
+ "dim_names must contain strings, got Null");
+}
+
+TEST_F(TestFixedShapeTensorType, MakeValidatesShape) {
+ // Negative shape values should be rejected
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, testing::HasSubstr("shape must have non-negative values"),
+ FixedShapeTensorType::Make(value_type_, {-1}));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(
+ Invalid, testing::HasSubstr("shape must have non-negative values"),
+ FixedShapeTensorType::Make(value_type_, {3, -1, 4}));
}
TEST_F(TestFixedShapeTensorType, RoundtripBatch) {
@@ -794,6 +861,32 @@ TEST_F(TestVariableShapeTensorType,
MetadataSerializationRoundtrip) {
"Invalid: permutation");
CheckDeserializationRaises(ext_type_, storage_type,
R"({"dim_names":["x","y"]})",
"Invalid: dim_names");
+
+ // Validate permutation member must be an array with integer values. Error
+ // message should include the JSON type name of the offending value.
+ CheckDeserializationRaises(ext_type_, storage_type,
R"({"permutation":"invalid"})",
+ "permutation must be an array, got String");
+ CheckDeserializationRaises(ext_type_, storage_type,
R"({"permutation":[1.5,0.5,2.5]})",
+ "permutation must contain integers, got Number");
+ CheckDeserializationRaises(ext_type_, storage_type,
+ R"({"permutation":[null,null,null]})",
+ "permutation must contain integers, got Null");
+
+ // Validate dim_names member must be an array with string values
+ CheckDeserializationRaises(ext_type_, storage_type,
R"({"dim_names":"invalid"})",
+ "dim_names must be an array, got String");
+ CheckDeserializationRaises(ext_type_, storage_type,
R"({"dim_names":[1,2,3]})",
+ "dim_names must contain strings, got Number");
+
+ // Validate uniform_shape member must be an array with integer-or-null values
+ CheckDeserializationRaises(ext_type_, storage_type,
R"({"uniform_shape":"invalid"})",
+ "uniform_shape must be an array, got String");
+ CheckDeserializationRaises(ext_type_, storage_type,
+ R"({"uniform_shape":[1.5,null,null]})",
+ "uniform_shape must contain integers or nulls,
got Number");
+ CheckDeserializationRaises(ext_type_, storage_type,
+ R"({"uniform_shape":["x",null,null]})",
+ "uniform_shape must contain integers or nulls,
got String");
}
TEST_F(TestVariableShapeTensorType, RoundtripBatch) {
diff --git a/cpp/src/arrow/extension/tensor_internal.cc
b/cpp/src/arrow/extension/tensor_internal.cc
index 37862b7689..e94ea9a1d1 100644
--- a/cpp/src/arrow/extension/tensor_internal.cc
+++ b/cpp/src/arrow/extension/tensor_internal.cc
@@ -30,6 +30,31 @@
namespace arrow::internal {
+namespace {
+
+// Names indexed by rapidjson::Type enum value:
+// kNullType=0, kFalseType=1, kTrueType=2, kObjectType=3,
+// kArrayType=4, kStringType=5, kNumberType=6.
+constexpr const char* kJsonTypeNames[] = {"Null", "False", "True", "Object",
+ "Array", "String", "Number"};
+
+} // namespace
+
+const char* JsonTypeName(const ::arrow::rapidjson::Value& v) {
+ return kJsonTypeNames[v.GetType()];
+}
+
+Result<int64_t> ComputeShapeProduct(std::span<const int64_t> shape) {
+ int64_t product = 1;
+ for (const auto dim : shape) {
+ if (MultiplyWithOverflow(product, dim, &product)) {
+ return Status::Invalid(
+ "Product of tensor shape dimensions would not fit in 64-bit
integer");
+ }
+ }
+ return product;
+}
+
bool IsPermutationTrivial(std::span<const int64_t> permutation) {
for (size_t i = 1; i < permutation.size(); ++i) {
if (permutation[i - 1] + 1 != permutation[i]) {
@@ -105,12 +130,7 @@ Result<std::shared_ptr<Buffer>> SliceTensorBuffer(const
Array& data_array,
const DataType& value_type,
std::span<const int64_t>
shape) {
const int64_t byte_width = value_type.byte_width();
- int64_t size = 1;
- for (const auto dim : shape) {
- if (MultiplyWithOverflow(size, dim, &size)) {
- return Status::Invalid("Tensor size would not fit in 64-bit integer");
- }
- }
+ ARROW_ASSIGN_OR_RAISE(const int64_t size, ComputeShapeProduct(shape));
if (size != data_array.length()) {
return Status::Invalid("Expected data array of length ", size, ", got ",
data_array.length());
diff --git a/cpp/src/arrow/extension/tensor_internal.h
b/cpp/src/arrow/extension/tensor_internal.h
index b5ed5ebe11..19665bf2cd 100644
--- a/cpp/src/arrow/extension/tensor_internal.h
+++ b/cpp/src/arrow/extension/tensor_internal.h
@@ -21,11 +21,25 @@
#include <span>
#include <vector>
+#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
#include "arrow/result.h"
#include "arrow/type_fwd.h"
+#include <rapidjson/document.h>
+
namespace arrow::internal {
+/// \brief Return the name of a RapidJSON value's type (e.g., "Null", "Array",
"Number").
+ARROW_EXPORT
+const char* JsonTypeName(const ::arrow::rapidjson::Value& v);
+
+/// \brief Compute the product of the given shape dimensions.
+///
+/// Returns Status::Invalid if the product would overflow int64_t.
+/// An empty shape returns 1 (the multiplicative identity).
+ARROW_EXPORT
+Result<int64_t> ComputeShapeProduct(std::span<const int64_t> shape);
+
ARROW_EXPORT
bool IsPermutationTrivial(std::span<const int64_t> permutation);
diff --git a/cpp/src/arrow/extension/variable_shape_tensor.cc
b/cpp/src/arrow/extension/variable_shape_tensor.cc
index 7e27bbdb74..b1b12583d7 100644
--- a/cpp/src/arrow/extension/variable_shape_tensor.cc
+++ b/cpp/src/arrow/extension/variable_shape_tensor.cc
@@ -159,26 +159,31 @@ Result<std::shared_ptr<DataType>>
VariableShapeTensorType::Deserialize(
if (document.HasMember("permutation")) {
const auto& json_permutation = document["permutation"];
if (!json_permutation.IsArray()) {
- return Status::Invalid("permutation must be an array");
+ return Status::Invalid("permutation must be an array, got ",
+ internal::JsonTypeName(json_permutation));
}
permutation.reserve(ndim);
for (const auto& x : json_permutation.GetArray()) {
if (!x.IsInt64()) {
- return Status::Invalid("permutation must contain integers");
+ return Status::Invalid("permutation must contain integers, got ",
+ internal::JsonTypeName(x));
}
permutation.emplace_back(x.GetInt64());
}
+ RETURN_NOT_OK(internal::IsPermutationValid(permutation));
}
std::vector<std::string> dim_names;
if (document.HasMember("dim_names")) {
const auto& json_dim_names = document["dim_names"];
if (!json_dim_names.IsArray()) {
- return Status::Invalid("dim_names must be an array");
+ return Status::Invalid("dim_names must be an array, got ",
+ internal::JsonTypeName(json_dim_names));
}
dim_names.reserve(ndim);
for (const auto& x : json_dim_names.GetArray()) {
if (!x.IsString()) {
- return Status::Invalid("dim_names must contain strings");
+ return Status::Invalid("dim_names must contain strings, got ",
+ internal::JsonTypeName(x));
}
dim_names.emplace_back(x.GetString());
}
@@ -188,7 +193,8 @@ Result<std::shared_ptr<DataType>>
VariableShapeTensorType::Deserialize(
if (document.HasMember("uniform_shape")) {
const auto& json_uniform_shape = document["uniform_shape"];
if (!json_uniform_shape.IsArray()) {
- return Status::Invalid("uniform_shape must be an array");
+ return Status::Invalid("uniform_shape must be an array, got ",
+ internal::JsonTypeName(json_uniform_shape));
}
uniform_shape.reserve(ndim);
for (const auto& x : json_uniform_shape.GetArray()) {
@@ -197,7 +203,8 @@ Result<std::shared_ptr<DataType>>
VariableShapeTensorType::Deserialize(
} else if (x.IsInt64()) {
uniform_shape.emplace_back(x.GetInt64());
} else {
- return Status::Invalid("uniform_shape must contain integers or nulls");
+ return Status::Invalid("uniform_shape must contain integers or nulls,
got ",
+ internal::JsonTypeName(x));
}
}
}