westonpace commented on a change in pull request #11556:
URL: https://github.com/apache/arrow/pull/11556#discussion_r741557476



##########
File path: cpp/src/arrow/dataset/dataset_writer.cc
##########
@@ -83,128 +84,163 @@ class Throttle {
   std::mutex mutex_;
 };
 
+struct DatasetWriterState {
+  DatasetWriterState(uint64_t rows_in_flight, uint64_t max_open_files,
+                     uint64_t max_rows_staged)
+      : rows_in_flight_throttle(rows_in_flight),
+        open_files_throttle(max_open_files),
+        staged_rows_count(0),
+        max_rows_staged(max_rows_staged) {}
+
+  bool StagingFull() const { return staged_rows_count.load() >= 
max_rows_staged; }
+
+  // Throttle for how many rows the dataset writer will allow to be in process 
memory
+  // When this is exceeded the dataset writer will pause / apply backpressure
+  Throttle rows_in_flight_throttle;
+  // Control for how many files the dataset writer will open.  When this is 
exceeded
+  // the dataset writer will pause and it will also close the largest open 
file.
+  Throttle open_files_throttle;
+  // Control for how many rows the dataset writer will allow to be staged.  A 
row is
+  // staged if it is waiting for more rows to reach minimum_batch_size.  If 
this is
+  // exceeded then the largest staged batch is unstaged (no backpressure is 
applied)
+  std::atomic<uint64_t> staged_rows_count;
+  // If too many rows get staged we will end up with poor performance and, if 
more rows
+  // are staged than max_rows_queued we will end up with deadlock.  To avoid 
this, once
+  // we have too many staged rows we just ignore min_rows_per_group
+  const uint64_t max_rows_staged;
+  // Mutex to guard access to the file visitors in the writer options
+  std::mutex visitors_mutex;
+};
+
 class DatasetWriterFileQueue : public util::AsyncDestroyable {
  public:
   explicit DatasetWriterFileQueue(const Future<std::shared_ptr<FileWriter>>& 
writer_fut,
                                   const FileSystemDatasetWriteOptions& options,
-                                  std::mutex* visitors_mutex)
-      : options_(options), visitors_mutex_(visitors_mutex) {
-    running_task_ = Future<>::Make();
-    writer_fut.AddCallback(
-        [this](const Result<std::shared_ptr<FileWriter>>& maybe_writer) {
-          if (maybe_writer.ok()) {
-            writer_ = *maybe_writer;
-            Flush();
-          } else {
-            Abort(maybe_writer.status());
-          }
-        });
+                                  DatasetWriterState* writer_state)
+      : options_(options), writer_state_(writer_state) {
+    // If this AddTask call fails (e.g. we're given an already failing future) 
then we
+    // will get the error later when we try and write to it.
+    ARROW_UNUSED(file_tasks_.AddTask([this, writer_fut] {
+      return writer_fut.Then(
+          [this](const std::shared_ptr<FileWriter>& writer) { writer_ = 
writer; });
+    }));
   }
 
-  Future<uint64_t> Push(std::shared_ptr<RecordBatch> batch) {
-    std::unique_lock<std::mutex> lk(mutex);
-    write_queue_.push_back(std::move(batch));
-    Future<uint64_t> write_future = Future<uint64_t>::Make();
-    write_futures_.push_back(write_future);
-    if (!running_task_.is_valid()) {
-      running_task_ = Future<>::Make();
-      FlushUnlocked(std::move(lk));
+  Result<std::shared_ptr<RecordBatch>> PopStagedBatch() {
+    std::vector<std::shared_ptr<RecordBatch>> batches_to_write;
+    uint64_t num_rows = 0;
+    while (!staged_batches_.empty()) {
+      std::shared_ptr<RecordBatch> next = std::move(staged_batches_.front());
+      staged_batches_.pop_front();
+      if (num_rows + next->num_rows() <= options_.max_rows_per_group) {
+        num_rows += next->num_rows();
+        batches_to_write.push_back(std::move(next));
+        if (num_rows == options_.max_rows_per_group) {
+          break;
+        }
+      } else {
+        uint64_t remaining = options_.max_rows_per_group - num_rows;
+        std::shared_ptr<RecordBatch> next_partial =
+            next->Slice(0, static_cast<int64_t>(remaining));
+        batches_to_write.push_back(std::move(next_partial));
+        std::shared_ptr<RecordBatch> next_remainder =
+            next->Slice(static_cast<int64_t>(remaining));
+        staged_batches_.push_front(std::move(next_remainder));
+        break;
+      }
     }
-    return write_future;
+    DCHECK_GT(batches_to_write.size(), 0);
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Table> table,
+                          Table::FromRecordBatches(batches_to_write));
+    return table->CombineChunksToBatch();
+  }
+
+  Status ScheduleBatch(std::shared_ptr<RecordBatch> batch) {
+    struct WriteTask {
+      Future<> operator()() { return self->WriteNext(std::move(batch)); }
+      DatasetWriterFileQueue* self;
+      std::shared_ptr<RecordBatch> batch;
+    };
+    return file_tasks_.AddTask(WriteTask{this, std::move(batch)});
+  }
+
+  Result<int64_t> PopAndDeliverStagedBatch() {
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<RecordBatch> next_batch, 
PopStagedBatch());
+    int64_t rows_popped = next_batch->num_rows();
+    rows_currently_staged_ -= next_batch->num_rows();
+    ARROW_RETURN_NOT_OK(ScheduleBatch(std::move(next_batch)));
+    return rows_popped;
+  }
+
+  // Stage batches, popping and delivering batches if enough data has arrived
+  Status Push(std::shared_ptr<RecordBatch> batch) {
+    uint64_t delta_staged = batch->num_rows();
+    rows_currently_staged_ += delta_staged;
+    staged_batches_.push_back(std::move(batch));
+    while (!staged_batches_.empty() &&
+           (writer_state_->StagingFull() ||
+            rows_currently_staged_ >= options_.min_rows_per_group)) {
+      ARROW_ASSIGN_OR_RAISE(int64_t rows_popped, PopAndDeliverStagedBatch());
+      delta_staged -= rows_popped;
+    }
+    // Note, delta_staged may be negative if we were able to deliver some data
+    writer_state_->staged_rows_count += delta_staged;
+    return Status::OK();
   }
 
   Future<> DoDestroy() override {
-    std::lock_guard<std::mutex> lg(mutex);
-    if (!running_task_.is_valid()) {
-      RETURN_NOT_OK(DoFinish());
-      return Future<>::MakeFinished();
+    if (!aborted_) {
+      writer_state_->staged_rows_count -= rows_currently_staged_;
+      while (!staged_batches_.empty()) {
+        RETURN_NOT_OK(PopAndDeliverStagedBatch());
+      }
     }
-    return running_task_.Then([this] { return DoFinish(); });
+    return file_tasks_.End().Then([this] { return DoFinish(); });
   }
 
  private:
-  Future<uint64_t> WriteNext() {
+  Future<> WriteNext(std::shared_ptr<RecordBatch> next) {
+    struct WriteTask {
+      Status operator()() {
+        int64_t rows_to_release = batch->num_rows();
+        Status status = self->writer_->Write(batch);
+        self->writer_state_->rows_in_flight_throttle.Release(rows_to_release);
+        return status;
+      }
+      DatasetWriterFileQueue* self;
+      std::shared_ptr<RecordBatch> batch;
+    };
     // May want to prototype / measure someday pushing the async write down 
further
     return DeferNotOk(
-        io::default_io_context().executor()->Submit([this]() -> 
Result<uint64_t> {
-          DCHECK(running_task_.is_valid());
-          std::unique_lock<std::mutex> lk(mutex);
-          const std::shared_ptr<RecordBatch>& to_write = write_queue_.front();
-          Future<uint64_t> on_complete = write_futures_.front();
-          uint64_t rows_to_write = to_write->num_rows();
-          lk.unlock();
-          Status status = writer_->Write(to_write);
-          lk.lock();
-          write_queue_.pop_front();
-          write_futures_.pop_front();
-          lk.unlock();
-          if (!status.ok()) {
-            on_complete.MarkFinished(status);
-          } else {
-            on_complete.MarkFinished(rows_to_write);
-          }
-          return rows_to_write;
-        }));
+        io::default_io_context().executor()->Submit(WriteTask{this, 
std::move(next)}));
   }
 
   Status DoFinish() {
     {
-      std::lock_guard<std::mutex> lg(*visitors_mutex_);
+      std::lock_guard<std::mutex> lg(writer_state_->visitors_mutex);
       RETURN_NOT_OK(options_.writer_pre_finish(writer_.get()));
     }
     RETURN_NOT_OK(writer_->Finish());
     {
-      std::lock_guard<std::mutex> lg(*visitors_mutex_);
+      std::lock_guard<std::mutex> lg(writer_state_->visitors_mutex);
       return options_.writer_post_finish(writer_.get());
     }
   }
 
   void Abort(Status err) {
-    std::vector<Future<uint64_t>> futures_to_abort;
-    Future<> old_running_task = running_task_;
-    {
-      std::lock_guard<std::mutex> lg(mutex);
-      write_queue_.clear();
-      futures_to_abort =
-          std::vector<Future<uint64_t>>(write_futures_.begin(), 
write_futures_.end());
-      write_futures_.clear();
-      running_task_ = Future<>();
-    }
-    for (auto& fut : futures_to_abort) {
-      fut.MarkFinished(err);
-    }
-    old_running_task.MarkFinished(std::move(err));
-  }
-
-  void Flush() {
-    std::unique_lock<std::mutex> lk(mutex);
-    FlushUnlocked(std::move(lk));
-  }
-
-  void FlushUnlocked(std::unique_lock<std::mutex> lk) {
-    if (write_queue_.empty()) {
-      Future<> old_running_task = running_task_;
-      running_task_ = Future<>();
-      lk.unlock();
-      old_running_task.MarkFinished();
-      return;
-    }
-    WriteNext().AddCallback([this](const Result<uint64_t>& res) {
-      if (res.ok()) {
-        Flush();
-      } else {
-        Abort(res.status());
-      }
-    });
+    aborted_ = true;

Review comment:
      Hmm, true, but that isn't necessarily a good thing. I suppose the
various task groups should probably auto-abort once they receive a failed
status; this would match the behavior of the synchronous task groups. At the
moment this behavior is probably OK (if a write call fails we will just keep
trying to write to the file), but it would be cleaner to bail out sooner. I've
added ARROW-14565 to address this in a follow-up, since it has implications elsewhere.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to