save-buffer commented on code in PR #13669:
URL: https://github.com/apache/arrow/pull/13669#discussion_r1068688882
##########
cpp/src/arrow/compute/exec/accumulation_queue.cc:
##########
@@ -39,20 +37,180 @@ void AccumulationQueue::Concatenate(AccumulationQueue&& that) {
this->batches_.reserve(this->batches_.size() + that.batches_.size());
std::move(that.batches_.begin(), that.batches_.end(),
std::back_inserter(this->batches_));
- this->row_count_ += that.row_count_;
that.Clear();
}
// Appends `batch` (moved in) to the end of the queue. This change drops the
// running row counter; totals are now computed by CalculateRowCount().
void AccumulationQueue::InsertBatch(ExecBatch batch) {
- row_count_ += batch.length;
batches_.emplace_back(std::move(batch));
}
+// Overwrites the batch stored at position `idx` with `batch` (moved in).
+// `idx` must name an existing slot; this is asserted with ARROW_DCHECK.
+void AccumulationQueue::SetBatch(size_t idx, ExecBatch batch) {
+  ARROW_DCHECK(idx < batches_.size());
+  ExecBatch& slot = batches_[idx];
+  slot = std::move(batch);
+}
+
+// Returns the total number of rows across all queued batches. The value is
+// recomputed on every call by summing each batch's `length`.
+size_t AccumulationQueue::CalculateRowCount() const {
+  size_t total_rows = 0;
+  for (size_t i = 0; i < batches_.size(); ++i) {
+    total_rows += static_cast<size_t>(batches_[i].length);
+  }
+  return total_rows;
+}
+
// Drops every queued batch. After this change there is no separate row counter
// to reset.
void AccumulationQueue::Clear() {
- row_count_ = 0;
batches_.clear();
}
-ExecBatch& AccumulationQueue::operator[](size_t i) { return batches_[i]; }
-} // namespace util
+// Binds the queue to `ctx` and prepares it for use:
+//  - partition locks are sized by ctx->max_concurrency() and kNumPartitions;
+//  - one read-back task group is registered per partition. Each read task
+//    moves batch `batch_index` out of queues_[part] and hands it to
+//    read_back_fn_[part]. The finish callback forwards to on_finished_[part].
+Status SpillingAccumulationQueue::Init(QueryContext* ctx) {
+  ctx_ = ctx;
+  partition_locks_.Init(ctx_->max_concurrency(), kNumPartitions);
+  for (size_t part = 0; part < kNumPartitions; ++part) {
+    auto read_task = [this, part](size_t thread_index, int64_t batch_index) {
+      return read_back_fn_[part](thread_index, static_cast<size_t>(batch_index),
+                                 std::move(queues_[part][batch_index]));
+    };
+    auto finish_task = [this, part](size_t thread_index) {
+      return on_finished_[part](thread_index);
+    };
+    task_group_read_[part] =
+        ctx_->RegisterTaskGroup(std::move(read_task), std::move(finish_task));
+  }
+  return Status::OK();
+}
+
+// Routes the rows of `batch` into per-partition builders.
+//
+// The last column of `batch` is read as one uint64_t hash per row; the hash
+// selects each row's partition. When a partition's builder becomes full it is
+// flushed and handled as follows:
+//  - the trailing hash column is split off;
+//  - the remaining columns are spilled via files_[part].SpillBatch when
+//    part < spilling_cursor_, and queued in queues_[part] otherwise;
+//  - the hashes are kept in hash_queues_[part] only when
+//    part >= hash_cursor_.
+// Partitions are processed under partition_locks_. The first non-OK status
+// from appending or spilling is returned.
+Status SpillingAccumulationQueue::InsertBatch(size_t thread_index, ExecBatch batch) {
+  Datum& hash_datum = batch.values.back();
+  const uint64_t* hashes =
+      reinterpret_cast<const uint64_t*>(hash_datum.array()->buffers[1]->data());
+  // `permutation` stores the indices of rows in the input batch sorted by partition.
+  // NOTE(review): row indices and part_starts are uint16_t, so this assumes
+  // batch.length fits in 16 bits. Confirm that callers bound input batch sizes.
+  std::vector<uint16_t> permutation(batch.length);
+  uint16_t part_starts[kNumPartitions + 1];
+  PartitionSort::Eval(
+      batch.length, kNumPartitions, part_starts,
+      /*partition_id=*/[&](int64_t i) { return partition_id(hashes[i]); },
+      /*output_fn=*/
+      [&permutation](int64_t input_pos, int64_t output_pos) {
+        permutation[output_pos] = static_cast<uint16_t>(input_pos);
+      });
+
+  int unprocessed_partition_ids[kNumPartitions];
+  RETURN_NOT_OK(partition_locks_.ForEachPartition(
+      thread_index, unprocessed_partition_ids,
+      /*is_prtn_empty=*/
+      [&](int part_id) { return part_starts[part_id + 1] == part_starts[part_id]; },
+      /*partition=*/
+      [&](int locked_part_id_int) {
+        size_t locked_part_id = static_cast<size_t>(locked_part_id_int);
+        auto& builder = builders_[locked_part_id];
+        uint64_t num_total_rows_to_append =
+            part_starts[locked_part_id + 1] - part_starts[locked_part_id];
+
+        size_t offset = static_cast<size_t>(part_starts[locked_part_id]);
+        while (num_total_rows_to_append > 0) {
+          // Append at most as many rows as still fit in the builder.
+          int num_rows_to_append =
+              std::min(static_cast<int>(num_total_rows_to_append),
+                       static_cast<int>(ExecBatchBuilder::num_rows_max() -
+                                        builder.num_rows()));
+
+          RETURN_NOT_OK(builder.AppendSelected(ctx_->memory_pool(), batch,
+                                               num_rows_to_append,
+                                               permutation.data() + offset,
+                                               batch.num_values()));
+
+          if (builder.is_full()) {
+            // Named `full_batch` so it does not shadow the input `batch`
+            // parameter that this lambda captures and appends from.
+            ExecBatch full_batch = builder.Flush();
+            const auto full_length = full_batch.length;
+            Datum hash = std::move(full_batch.values.back());
+            full_batch.values.pop_back();
+            if (locked_part_id < spilling_cursor_) {
+              RETURN_NOT_OK(
+                  files_[locked_part_id].SpillBatch(ctx_, std::move(full_batch)));
+            } else {
+              queues_[locked_part_id].InsertBatch(std::move(full_batch));
+            }
+
+            // Build the hash-only batch only when it is actually retained.
+            if (locked_part_id >= hash_cursor_) {
+              ExecBatch hash_batch({std::move(hash)}, full_length);
+              hash_queues_[locked_part_id].InsertBatch(std::move(hash_batch));
+            }
+          }
+          offset += num_rows_to_append;
+          num_total_rows_to_append -= num_rows_to_append;
+        }
+        return Status::OK();
+      }));
+  return Status::OK();
+}
+
+// Returns a pointer to the uint64_t row hashes of batch `batch_idx` in
+// `partition`. The partition must still keep its hashes
+// (partition >= hash_cursor_).
+//
+// Batch indices are laid out as follows:
+//  - indices [0, batch_count) refer to the flushed hash batches in
+//    hash_queues_[partition];
+//  - index batch_count refers to the rows still buffered in
+//    builders_[partition], whose last column holds the hashes.
+const uint64_t* SpillingAccumulationQueue::GetHashes(size_t partition, size_t batch_idx) {
+  ARROW_DCHECK(partition >= hash_cursor_.load());
+  const size_t num_queued_batches = hash_queues_[partition].batch_count();
+  // Only the builder's pending batch may follow the queued ones.
+  ARROW_DCHECK(batch_idx <= num_queued_batches);
+  if (batch_idx < num_queued_batches) {
+    const Datum& datum = hash_queues_[partition][batch_idx].values[0];
+    return reinterpret_cast<const uint64_t*>(datum.array()->buffers[1]->data());
+  }
+  // The hash column is the builder's last column. Buffer 1 holds the values.
+  size_t hash_idx = builders_[partition].num_cols();
+  KeyColumnArray kca = builders_[partition].column(hash_idx - 1);
+  return reinterpret_cast<const uint64_t*>(kca.data(1));
+}
+
+Status SpillingAccumulationQueue::GetPartition(
+ size_t thread_index, size_t partition,
+ std::function<Status(size_t, size_t, ExecBatch)> on_batch,
+ std::function<Status(size_t)> on_finished) {
+ bool is_in_memory = partition >= spilling_cursor_.load();
+ if (builders_[partition].num_rows() > 0) {
+ ExecBatch batch = builders_[partition].Flush();
+ Datum hash = std::move(batch.values.back());
+ batch.values.pop_back();
+ if (is_in_memory) {
Review Comment:
Good point. I think this was an artifact of a previous iteration where I
needed them to be inserted. I've changed it.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]