This is an automated email from the ASF dual-hosted git repository.
apitrou pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new 9b77edd5c6 GH-35450: [C++] Return error when
`RecordBatch::ToStructArray` called with mismatched column lengths (#36654)
9b77edd5c6 is described below
commit 9b77edd5c66095d72ade0f71b120792fe2a8f99c
Author: Ben Harkins <[email protected]>
AuthorDate: Thu Aug 10 11:33:08 2023 -0400
GH-35450: [C++] Return error when `RecordBatch::ToStructArray` called with
mismatched column lengths (#36654)
### Rationale for this change
If a `RecordBatch` is created with column lengths that don't match the
provided `num_rows` (technically invalid), then there are some circumstances
where `ToStructArray` will successfully return an array whose length doesn't
match `num_rows`. Instead, we should return an error.
### What changes are included in this PR?
* Add a small validation check to `ToStructArray` before constructing the
output array
* Add a test
### Are these changes tested?
Yes (tests are included)
### Are there any user-facing changes?
No
* Closes: #35450
Authored-by: benibus <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
---
cpp/src/arrow/record_batch.cc | 23 ++++++++++++++++++-----
cpp/src/arrow/record_batch_test.cc | 25 +++++++++++++++++++++++++
2 files changed, 43 insertions(+), 5 deletions(-)
diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc
index 683e72878b..1c5c8912e5 100644
--- a/cpp/src/arrow/record_batch.cc
+++ b/cpp/src/arrow/record_batch.cc
@@ -218,8 +218,25 @@ Result<std::shared_ptr<RecordBatch>> RecordBatch::FromStructArray(
array->data()->child_data);
}
+namespace {
+
+Status ValidateColumnLength(const RecordBatch& batch, int i) {
+ const auto& array = *batch.column(i);
+ if (ARROW_PREDICT_FALSE(array.length() != batch.num_rows())) {
+ return Status::Invalid("Number of rows in column ", i,
+ " did not match batch: ", array.length(), " vs ",
+ batch.num_rows());
+ }
+ return Status::OK();
+}
+
+} // namespace
+
Result<std::shared_ptr<StructArray>> RecordBatch::ToStructArray() const {
if (num_columns() != 0) {
+ // Only check the first column because `StructArray::Make` already checks that the
+ // child lengths are equal.
+ RETURN_NOT_OK(ValidateColumnLength(*this, 0));
return StructArray::Make(columns(), schema()->fields());
}
return std::make_shared<StructArray>(arrow::struct_({}), num_rows_,
@@ -301,12 +318,8 @@ namespace {
Status ValidateBatch(const RecordBatch& batch, bool full_validation) {
for (int i = 0; i < batch.num_columns(); ++i) {
+ RETURN_NOT_OK(ValidateColumnLength(batch, i));
const auto& array = *batch.column(i);
- if (array.length() != batch.num_rows()) {
- return Status::Invalid("Number of rows in column ", i,
- " did not match batch: ", array.length(), " vs ",
- batch.num_rows());
- }
const auto& schema_type = batch.schema()->field(i)->type();
if (!array.type()->Equals(schema_type)) {
return Status::Invalid("Column ", i,
diff --git a/cpp/src/arrow/record_batch_test.cc b/cpp/src/arrow/record_batch_test.cc
index 4975e94325..e8180c6740 100644
--- a/cpp/src/arrow/record_batch_test.cc
+++ b/cpp/src/arrow/record_batch_test.cc
@@ -413,6 +413,31 @@ TEST_F(TestRecordBatch, MakeEmpty) {
ASSERT_EQ(empty->num_rows(), 0);
}
+// See: https://github.com/apache/arrow/issues/35450
+TEST_F(TestRecordBatch, ToStructArrayMismatchedColumnLengths) {
+ constexpr int kNumRows = 5;
+ FieldVector fields = {field("x", int64()), field("y", int64())};
+ ArrayVector columns = {
+ ArrayFromJSON(int64(), "[0, 1, 2, 3, 4]"),
+ ArrayFromJSON(int64(), "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]"),
+ };
+
+ // Sanity check
+ auto batch = RecordBatch::Make(schema({fields[0]}), kNumRows, {columns[0]});
+ ASSERT_OK_AND_ASSIGN(auto array, batch->ToStructArray());
+ ASSERT_EQ(array->length(), kNumRows);
+
+ // One column with a mismatched length
+ batch = RecordBatch::Make(schema({fields[1]}), kNumRows, {columns[1]});
+ ASSERT_RAISES(Invalid, batch->ToStructArray());
+ // Mix of columns with matching and non-matching lengths
+ batch = RecordBatch::Make(schema(fields), kNumRows, columns);
+ ASSERT_RAISES(Invalid, batch->ToStructArray());
+ std::swap(columns[0], columns[1]);
+ batch = RecordBatch::Make(schema(fields), kNumRows, columns);
+ ASSERT_RAISES(Invalid, batch->ToStructArray());
+}
+
class TestRecordBatchReader : public ::testing::Test {
public:
void SetUp() override { MakeBatchesAndReader(100); }