bkietz commented on code in PR #37896:
URL: https://github.com/apache/arrow/pull/37896#discussion_r1339037926
##########
cpp/src/arrow/record_batch.cc:
##########
@@ -432,4 +433,43 @@ RecordBatchReader::~RecordBatchReader() {
ARROW_WARN_NOT_OK(this->Close(), "Implicitly called RecordBatchReader::Close
failed");
}
+Result<std::shared_ptr<RecordBatch>> ConcatenateRecordBatches(
+ const RecordBatchVector& batches, MemoryPool* pool) {
+ int64_t length = 0;
+ size_t n = batches.size();
+ if (n == 0) {
+ return Status::Invalid("Must pass at least one recordbatch");
+ }
+ if (n == 1) {
Review Comment:
Instead, please make this deep-copy the batch, for consistency with
ConcatenateArrays (see https://github.com/apache/arrow/issues/37878).
##########
cpp/src/arrow/record_batch.h:
##########
@@ -350,4 +350,12 @@ class ARROW_EXPORT RecordBatchReader {
Iterator<std::shared_ptr<RecordBatch>> batches, std::shared_ptr<Schema>
schema);
};
+/// \brief Concatenate recordbatches
Review Comment:
Nit: use either "record batches" or "RecordBatch".
##########
cpp/src/arrow/record_batch.cc:
##########
@@ -432,4 +433,43 @@ RecordBatchReader::~RecordBatchReader() {
ARROW_WARN_NOT_OK(this->Close(), "Implicitly called RecordBatchReader::Close
failed");
}
+Result<std::shared_ptr<RecordBatch>> ConcatenateRecordBatches(
+ const RecordBatchVector& batches, MemoryPool* pool) {
+ int64_t length = 0;
+ size_t n = batches.size();
+ if (n == 0) {
+ return Status::Invalid("Must pass at least one recordbatch");
+ }
+ if (n == 1) {
+ return batches[0];
+ }
+ int cols = batches[0]->num_columns();
+ auto schema = batches[0]->schema();
+ std::vector<std::shared_ptr<Array>> columns;
+ if (cols == 0) {
+ // special case: null batch, no data, just length
+ for (size_t i = 0; i < batches.size(); ++i) {
+ length += batches[i]->num_rows();
+ }
+ } else {
+ for (int col = 0; col < cols; ++col) {
+ ArrayVector data;
+ for (size_t i = 0; i < batches.size(); ++i) {
+ auto cur_schema = batches[i]->schema();
+ if (!schema->Equals(cur_schema)) {
+ return Status::Invalid(
+ "RecordBatch index ", i, " schema is ", cur_schema->ToString(),
+ ", did not match index 0 recordbatch schema: ",
schema->ToString());
+ }
+ auto column_data = batches[i]->column(col);
+ data.push_back(column_data);
+ }
+ auto array = Concatenate(data, pool).ValueOrDie();
+ length = array->length();
+ columns.push_back(array);
+ }
+ }
+ return RecordBatch::Make(std::move(schema), length, columns);
Review Comment:
```suggestion
return RecordBatch::Make(std::move(schema), length, std::move(columns));
```
##########
cpp/src/arrow/record_batch.cc:
##########
@@ -432,4 +433,43 @@ RecordBatchReader::~RecordBatchReader() {
ARROW_WARN_NOT_OK(this->Close(), "Implicitly called RecordBatchReader::Close
failed");
}
+Result<std::shared_ptr<RecordBatch>> ConcatenateRecordBatches(
+ const RecordBatchVector& batches, MemoryPool* pool) {
+ int64_t length = 0;
+ size_t n = batches.size();
+ if (n == 0) {
+ return Status::Invalid("Must pass at least one recordbatch");
+ }
+ if (n == 1) {
+ return batches[0];
+ }
+ int cols = batches[0]->num_columns();
+ auto schema = batches[0]->schema();
+ std::vector<std::shared_ptr<Array>> columns;
+ if (cols == 0) {
+ // special case: null batch, no data, just length
+ for (size_t i = 0; i < batches.size(); ++i) {
+ length += batches[i]->num_rows();
+ }
+ } else {
+ for (int col = 0; col < cols; ++col) {
+ ArrayVector data;
+ for (size_t i = 0; i < batches.size(); ++i) {
+ auto cur_schema = batches[i]->schema();
+ if (!schema->Equals(cur_schema)) {
+ return Status::Invalid(
+ "RecordBatch index ", i, " schema is ", cur_schema->ToString(),
+ ", did not match index 0 recordbatch schema: ",
schema->ToString());
+ }
+ auto column_data = batches[i]->column(col);
+ data.push_back(column_data);
+ }
+ auto array = Concatenate(data, pool).ValueOrDie();
Review Comment:
Instead of aborting, please raise the error. We have a helper macro for this:
```suggestion
ARROW_ASSIGN_OR_RAISE(auto array, Concatenate(data, pool));
```
##########
cpp/src/arrow/record_batch_test.cc:
##########
@@ -555,4 +555,42 @@ TEST_F(TestRecordBatch, ReplaceSchema) {
ASSERT_RAISES(Invalid, b1->ReplaceSchema(schema));
}
+TEST_F(TestRecordBatch, ConcatenateRecordBatches) {
+ int length = 10;
+
+ auto f0 = field("f0", int32());
+ auto f1 = field("f1", uint8());
+
+ auto schema = ::arrow::schema({f0, f1});
+
+ random::RandomArrayGenerator gen(42);
+
+ auto a0 = gen.ArrayOf(int32(), length);
Review Comment:
I think you could use `gen.BatchOf` instead, to reduce boilerplate.
##########
cpp/src/arrow/record_batch.cc:
##########
@@ -432,4 +433,43 @@ RecordBatchReader::~RecordBatchReader() {
ARROW_WARN_NOT_OK(this->Close(), "Implicitly called RecordBatchReader::Close
failed");
}
+Result<std::shared_ptr<RecordBatch>> ConcatenateRecordBatches(
+ const RecordBatchVector& batches, MemoryPool* pool) {
+ int64_t length = 0;
+ size_t n = batches.size();
+ if (n == 0) {
+ return Status::Invalid("Must pass at least one recordbatch");
+ }
+ if (n == 1) {
+ return batches[0];
+ }
+ int cols = batches[0]->num_columns();
+ auto schema = batches[0]->schema();
+ std::vector<std::shared_ptr<Array>> columns;
+ if (cols == 0) {
+ // special case: null batch, no data, just length
+ for (size_t i = 0; i < batches.size(); ++i) {
+ length += batches[i]->num_rows();
+ }
+ } else {
Review Comment:
I don't think you need a special case for zero columns:
```suggestion
// special case: null batch, no data, just length
for (size_t i = 0; i < batches.size(); ++i) {
length += batches[i]->num_rows();
}
```
Then the loop `for (int col = 0; col < cols; ++col)` will simply never be
entered when `cols == 0`.
##########
cpp/src/arrow/record_batch.cc:
##########
@@ -432,4 +433,43 @@ RecordBatchReader::~RecordBatchReader() {
ARROW_WARN_NOT_OK(this->Close(), "Implicitly called RecordBatchReader::Close
failed");
}
+Result<std::shared_ptr<RecordBatch>> ConcatenateRecordBatches(
+ const RecordBatchVector& batches, MemoryPool* pool) {
+ int64_t length = 0;
+ size_t n = batches.size();
+ if (n == 0) {
+ return Status::Invalid("Must pass at least one recordbatch");
+ }
+ if (n == 1) {
+ return batches[0];
+ }
+ int cols = batches[0]->num_columns();
+ auto schema = batches[0]->schema();
+ std::vector<std::shared_ptr<Array>> columns;
+ if (cols == 0) {
+ // special case: null batch, no data, just length
+ for (size_t i = 0; i < batches.size(); ++i) {
+ length += batches[i]->num_rows();
+ }
+ } else {
+ for (int col = 0; col < cols; ++col) {
+ ArrayVector data;
+ for (size_t i = 0; i < batches.size(); ++i) {
+ auto cur_schema = batches[i]->schema();
+ if (!schema->Equals(cur_schema)) {
+ return Status::Invalid(
+ "RecordBatch index ", i, " schema is ", cur_schema->ToString(),
+ ", did not match index 0 recordbatch schema: ",
schema->ToString());
+ }
+ auto column_data = batches[i]->column(col);
+ data.push_back(column_data);
+ }
+ auto array = Concatenate(data, pool).ValueOrDie();
+ length = array->length();
Review Comment:
Instead, please compute the length explicitly before this loop, as in the
zero-columns case.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]