save-buffer commented on code in PR #13669:
URL: https://github.com/apache/arrow/pull/13669#discussion_r972392589


##########
cpp/src/arrow/compute/exec/accumulation_queue.cc:
##########
@@ -48,11 +50,221 @@ void AccumulationQueue::InsertBatch(ExecBatch batch) {
   batches_.emplace_back(std::move(batch));
 }
 
// Replaces the batch stored at `idx` and keeps `row_count_` consistent with
// the new contents: the old batch's length is subtracted and the new batch's
// length added, both with relaxed atomics.
//
// NOTE(review): the subtract and add are two separate atomic operations, so a
// concurrent reader of `row_count_` can briefly observe the count with the old
// batch removed but the new one not yet added. If readers race with SetBatch,
// consider folding this into a single AtomicFetchAdd of the signed delta —
// TODO confirm `row_count_`'s type permits a negative delta.
void AccumulationQueue::SetBatch(size_t idx, ExecBatch batch)
{
    ARROW_DCHECK(idx < batches_.size());
    // Remove the outgoing batch's rows from the running total...
    arrow::util::AtomicFetchSub(&row_count_, batches_[idx].length, std::memory_order_relaxed);
    // ...and account for the incoming batch's rows.
    arrow::util::AtomicFetchAdd(&row_count_, batch.length, std::memory_order_relaxed);
    batches_[idx] = std::move(batch);
}
+
// Empties the queue: drops every buffered batch and resets the row counter.
// NOTE(review): `row_count_` is written here with a plain (non-atomic) store,
// unlike SetBatch which uses atomic fetch ops — this assumes no other thread
// touches the queue during Clear(); confirm callers serialize it.
void AccumulationQueue::Clear() {
  row_count_ = 0;
  batches_.clear();
}
 
-ExecBatch& AccumulationQueue::operator[](size_t i) { return batches_[i]; }
+    Status SpillingAccumulationQueue::Init(QueryContext *ctx)
+    {
+        ctx_ = ctx;
+        partition_locks_.Init(ctx_->max_concurrency(), kNumPartitions);
+        return Status::OK();
+    }
+
+    Status SpillingAccumulationQueue::InsertBatch(
+        size_t thread_index,
+        ExecBatch batch)
+    {
+        Datum &hash_datum = batch.values.back();
+        const uint64_t *hashes = reinterpret_cast<const uint64_t 
*>(hash_datum.array()->buffers[1]->data());
+        // `permutation` stores the indices of rows in the input batch sorted 
by partition.
+        std::vector<uint16_t> permutation(batch.length);
+        uint16_t part_starts[kNumPartitions + 1];
+        PartitionSort::Eval(
+            batch.length,
+            kNumPartitions,
+            part_starts,
+            [&](int64_t i)
+            {
+                return hashes[i] & (kNumPartitions - 1);
+            },
+            [&permutation](int64_t input_pos, int64_t output_pos)
+            {
+                permutation[output_pos] = static_cast<uint16_t>(input_pos);
+            });
+
+        int unprocessed_partition_ids[kNumPartitions];
+        RETURN_NOT_OK(partition_locks_.ForEachPartition(
+                          thread_index,
+                          unprocessed_partition_ids,
+                          [&](int part_id)
+                          {
+                              return part_starts[part_id + 1] == 
part_starts[part_id];
+                          },
+                          [&](int locked_part_id_int)
+                          {
+                              size_t locked_part_id = 
static_cast<size_t>(locked_part_id_int);
+                              uint64_t num_total_rows_to_append =
+                                  part_starts[locked_part_id + 1] - 
part_starts[locked_part_id];
+
+                              size_t offset = 
static_cast<size_t>(part_starts[locked_part_id]);
+                              while(num_total_rows_to_append > 0)
+                              {
+                                  int num_rows_to_append = std::min(
+                                      
static_cast<int>(num_total_rows_to_append),
+                                      
static_cast<int>(ExecBatchBuilder::num_rows_max() - 
builders_[locked_part_id].num_rows()));
+
+                                  
RETURN_NOT_OK(builders_[locked_part_id].AppendSelected(
+                                                    ctx_->memory_pool(),
+                                                    batch,
+                                                    num_rows_to_append,
+                                                    permutation.data() + 
offset,
+                                                    batch.num_values()));
+
+                                  if(builders_[locked_part_id].is_full())
+                                  {
+                                      ExecBatch batch = 
builders_[locked_part_id].Flush();
+                                      Datum hash = 
std::move(batch.values.back());
+                                      batch.values.pop_back();
+                                      ExecBatch hash_batch({ std::move(hash) 
}, batch.length);
+                                      if(locked_part_id < spilling_cursor_)
+                                          
RETURN_NOT_OK(files_[locked_part_id].SpillBatch(
+                                                            ctx_,
+                                                            std::move(batch)));
+                                      else
+                                          
queues_[locked_part_id].InsertBatch(std::move(batch));
+
+                                      if(locked_part_id < hash_cursor_)
+                                          RETURN_NOT_OK(
+                                              
hash_files_[locked_part_id].SpillBatch(
+                                                  ctx_,
+                                                  std::move(hash_batch)));
+                                      else
+                                          
hash_queues_[locked_part_id].InsertBatch(std::move(hash_batch));
+
+                                  }
+                                  offset += num_rows_to_append;
+                                  num_total_rows_to_append -= 
num_rows_to_append;
+                              }
+                              return Status::OK();
+                          }));
+        return Status::OK();
+    }
+
+    const uint64_t *SpillingAccumulationQueue::GetHashes(size_t partition, 
size_t batch_idx)

Review Comment:
   I wrote a big comment above `BuildPartitionedBloomFilter` explaining this.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to