This is an automated email from the ASF dual-hosted git repository.

marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 4aa1ad4f1b [VL] Refine LocalPartitionWriter implementation (#9982)
4aa1ad4f1b is described below

commit 4aa1ad4f1b361f533838fbb43928f73c2d4f934f
Author: Rong Ma <[email protected]>
AuthorDate: Mon Jun 16 14:14:19 2025 +0100

    [VL] Refine LocalPartitionWriter implementation (#9982)
---
 cpp/core/shuffle/LocalPartitionWriter.cc | 187 ++++++++++++++++---------------
 cpp/core/shuffle/LocalPartitionWriter.h  |  11 +-
 cpp/core/shuffle/Spill.h                 |   4 +-
 3 files changed, 108 insertions(+), 94 deletions(-)

diff --git a/cpp/core/shuffle/LocalPartitionWriter.cc b/cpp/core/shuffle/LocalPartitionWriter.cc
index a8f6629af8..0a87c79fe8 100644
--- a/cpp/core/shuffle/LocalPartitionWriter.cc
+++ b/cpp/core/shuffle/LocalPartitionWriter.cc
@@ -69,21 +69,20 @@ class LocalPartitionWriter::LocalSpiller {
     }
   }
 
-  arrow::Status flush() {
-    if (flushed_) {
-      return arrow::Status::OK();
-    }
-    flushed_ = true;
+  arrow::Status spill(uint32_t partitionId, std::unique_ptr<BlockPayload> 
payload) {
+    ARROW_ASSIGN_OR_RAISE(auto start, os_->Tell());
+    RETURN_NOT_OK(payload->serialize(os_.get()));
 
-    if (compressedOs_ != nullptr) {
-      RETURN_NOT_OK(compressedOs_->Flush());
-    }
-    ARROW_ASSIGN_OR_RAISE(const auto pos, os_->Tell());
+    ARROW_ASSIGN_OR_RAISE(auto end, os_->Tell());
 
-    diskSpill_->insertPayload(lastPid_, Payload::kRaw, 0, nullptr, pos - 
writePos_, pool_, nullptr);
+    DLOG(INFO) << "LocalSpiller: Spilled partition " << partitionId << " file 
start: " << start << ", file end: " << end
+               << ", file: " << spillFile_;
 
-    DLOG(INFO) << "LocalSpiller: Spilled partition " << lastPid_ << " file 
start: " << writePos_
-               << ", file end: " << pos << ", file: " << spillFile_;
+    compressTime_ += payload->getCompressTime();
+    spillTime_ += payload->getWriteTime();
+
+    diskSpill_->insertPayload(
+        partitionId, payload->type(), payload->numRows(), 
payload->isValidityBuffer(), end - start, pool_, codec_);
 
     return arrow::Status::OK();
   }
@@ -91,33 +90,25 @@ class LocalPartitionWriter::LocalSpiller {
   arrow::Status spill(uint32_t partitionId, std::unique_ptr<InMemoryPayload> 
payload) {
     ScopedTimer timer(&spillTime_);
 
-    if (lastPid_ != partitionId) {
-      // Record the write position of the new partition.
-      ARROW_ASSIGN_OR_RAISE(writePos_, os_->Tell());
-      lastPid_ = partitionId;
-    }
-
+    curPid_ = partitionId;
     flushed_ = false;
+
     auto* raw = compressedOs_ != nullptr ? compressedOs_.get() : os_.get();
     RETURN_NOT_OK(payload->serialize(raw));
 
     return arrow::Status::OK();
   }
 
-  arrow::Status spill(uint32_t partitionId, std::unique_ptr<BlockPayload> 
payload) {
-    ARROW_ASSIGN_OR_RAISE(auto start, os_->Tell());
-    RETURN_NOT_OK(payload->serialize(os_.get()));
-
-    ARROW_ASSIGN_OR_RAISE(auto end, os_->Tell());
-
-    DLOG(INFO) << "LocalSpiller: Spilled partition " << partitionId << " file 
start: " << start << ", file end: " << end
-               << ", file: " << spillFile_;
-
-    compressTime_ += payload->getCompressTime();
-    spillTime_ += payload->getWriteTime();
+  arrow::Status flush() {
+    if (flushed_) {
+      return arrow::Status::OK();
+    }
+    flushed_ = true;
 
-    diskSpill_->insertPayload(
-        partitionId, payload->type(), payload->numRows(), 
payload->isValidityBuffer(), end - start, pool_, codec_);
+    if (compressedOs_ != nullptr) {
+      RETURN_NOT_OK(compressedOs_->Flush());
+    }
+    RETURN_NOT_OK(insertSpill());
 
     return arrow::Status::OK();
   }
@@ -126,19 +117,13 @@ class LocalPartitionWriter::LocalSpiller {
     ARROW_RETURN_IF(finished_, arrow::Status::Invalid("Calling finish() on a 
finished LocalSpiller."));
     ARROW_RETURN_IF(os_->closed(), arrow::Status::Invalid("Spill file os has 
been closed."));
 
-    if (lastPid_ != -1) {
+    if (curPid_ != -1) {
       if (compressedOs_ != nullptr) {
         compressTime_ = compressedOs_->compressTime();
         spillTime_ -= compressTime_;
         RETURN_NOT_OK(compressedOs_->Close());
       }
-
-      if (!isFinal_) {
-        ARROW_ASSIGN_OR_RAISE(auto pos, os_->Tell());
-        diskSpill_->insertPayload(lastPid_, Payload::kRaw, 0, nullptr, pos - 
writePos_, pool_, nullptr);
-        DLOG(INFO) << "LocalSpiller: Spilled partition " << lastPid_ << " file 
start: " << writePos_
-                   << ", file end: " << pos << ", file: " << spillFile_;
-      }
+      RETURN_NOT_OK(insertSpill());
     }
 
     if (!isFinal_) {
@@ -158,6 +143,18 @@ class LocalPartitionWriter::LocalSpiller {
   }
 
  private:
+  arrow::Status insertSpill() {
+    ARROW_ASSIGN_OR_RAISE(const auto pos, os_->Tell());
+    GLUTEN_DCHECK(pos >= writePos_, "Current write position should not be less 
than the last write position.");
+    if (pos > writePos_) {
+      diskSpill_->insertPayload(curPid_, Payload::kRaw, 0, nullptr, pos - 
writePos_, pool_, nullptr);
+      DLOG(INFO) << "LocalSpiller: Spilled partition " << curPid_ << " file 
start: " << writePos_
+                 << ", file end: " << pos << ", file: " << spillFile_;
+      writePos_ = pos;
+    }
+    return arrow::Status::OK();
+  }
+
   bool isFinal_;
 
   std::shared_ptr<arrow::io::OutputStream> os_;
@@ -174,7 +171,7 @@ class LocalPartitionWriter::LocalSpiller {
   bool finished_{false};
   int64_t spillTime_{0};
   int64_t compressTime_{0};
-  int32_t lastPid_{-1};
+  int32_t curPid_{-1};
 };
 
 class LocalPartitionWriter::PayloadMerger {
@@ -429,9 +426,30 @@ std::string LocalPartitionWriter::nextSpilledFileDir() {
 }
 
 arrow::Status LocalPartitionWriter::clearResource() {
-  RETURN_NOT_OK(dataFileOs_->Close());
-  // When bufferedWrite = true, dataFileOs_->Close doesn't release underlying 
buffer.
-  dataFileOs_.reset();
+  if (dataFileOs_ != nullptr) {
+    RETURN_NOT_OK(dataFileOs_->Close());
+    // When bufferedWrite = true, dataFileOs_->Close doesn't release 
underlying buffer.
+    dataFileOs_.reset();
+  }
+
+  // Check all spills are merged.
+  size_t spillId = 0;
+  for (const auto& spill : spills_) {
+    compressTime_ += spill->compressTime();
+    spillTime_ += spill->spillTime();
+    for (auto pid = 0; pid < numPartitions_; ++pid) {
+      if (spill->hasNextPayload(pid)) {
+        return arrow::Status::Invalid(
+            "Merging from spill " + std::to_string(spillId) + " is not 
exhausted. pid: " + std::to_string(pid));
+      }
+    }
+    if (std::filesystem::exists(spill->spillFile()) && 
!std::filesystem::remove(spill->spillFile())) {
+      LOG(WARNING) << "Error while deleting spill file " << spill->spillFile();
+    }
+    ++spillId;
+  }
+  spills_.clear();
+
   return arrow::Status::OK();
 }
 
@@ -447,23 +465,23 @@ void LocalPartitionWriter::init() {
   subDirSelection_.assign(localDirs_.size(), 0);
 }
 
-arrow::Result<int64_t> LocalPartitionWriter::mergeSpills(uint32_t partitionId) 
{
+arrow::Result<int64_t> LocalPartitionWriter::mergeSpills(uint32_t partitionId, 
arrow::io::OutputStream* os) {
   int64_t bytesEvicted = 0;
   int32_t spillIndex = 0;
 
   for (const auto& spill : spills_) {
-    ARROW_ASSIGN_OR_RAISE(auto startPos, dataFileOs_->Tell());
+    ARROW_ASSIGN_OR_RAISE(auto startPos, os->Tell());
 
     spill->openForRead(options_.shuffleFileBufferSize);
 
     // Read if partition exists in the spilled file. Then write to the final 
data file.
     while (auto payload = spill->nextPayload(partitionId)) {
-      RETURN_NOT_OK(payload->serialize(dataFileOs_.get()));
+      RETURN_NOT_OK(payload->serialize(os));
       compressTime_ += payload->getCompressTime();
       writeTime_ += payload->getWriteTime();
     }
 
-    ARROW_ASSIGN_OR_RAISE(auto endPos, dataFileOs_->Tell());
+    ARROW_ASSIGN_OR_RAISE(auto endPos, os->Tell());
     auto bytesWritten = endPos - startPos;
 
     DLOG(INFO) << "Partition " << partitionId << " spilled from spillResult " 
<< spillIndex++ << " of bytes "
@@ -476,6 +494,13 @@ arrow::Result<int64_t> LocalPartitionWriter::mergeSpills(uint32_t partitionId) {
   return bytesEvicted;
 }
 
+arrow::Status LocalPartitionWriter::writeCachedPayloads(uint32_t partitionId, 
arrow::io::OutputStream* os) const {
+  if (payloadCache_ != nullptr) {
+    RETURN_NOT_OK(payloadCache_->write(partitionId, os));
+  }
+  return arrow::Status::OK();
+}
+
 arrow::Status LocalPartitionWriter::stop(ShuffleWriterMetrics* metrics) {
   if (stopped_) {
     return arrow::Status::OK();
@@ -483,13 +508,12 @@ arrow::Status LocalPartitionWriter::stop(ShuffleWriterMetrics* metrics) {
   stopped_ = true;
 
   if (useSpillFileAsDataFile_) {
-    RETURN_NOT_OK(spiller_->flush());
     ARROW_ASSIGN_OR_RAISE(auto spill, spiller_->finish());
 
     // Merge the remaining partitions from spills.
     if (!spills_.empty()) {
       for (auto pid = lastEvictPid_ + 1; pid < numPartitions_; ++pid) {
-        ARROW_ASSIGN_OR_RAISE(partitionLengths_[pid], mergeSpills(pid));
+        ARROW_ASSIGN_OR_RAISE(partitionLengths_[pid], mergeSpills(pid, 
dataFileOs_.get()));
       }
     }
 
@@ -502,22 +526,7 @@ arrow::Status LocalPartitionWriter::stop(ShuffleWriterMetrics* metrics) {
     compressTime_ += spill->compressTime();
   } else {
     RETURN_NOT_OK(finishSpill());
-
-    if (merger_) {
-      for (auto pid = 0; pid < numPartitions_; ++pid) {
-        ARROW_ASSIGN_OR_RAISE(auto maybeMerged, merger_->finish(pid, false));
-        if (maybeMerged.has_value()) {
-          if (!payloadCache_) {
-            payloadCache_ = std::make_shared<PayloadCache>(
-                numPartitions_, codec_.get(), options_.compressionThreshold, 
payloadPool_.get());
-          }
-          // Spill can be triggered by compressing or building dictionaries.
-          RETURN_NOT_OK(payloadCache_->cache(pid, 
std::move(maybeMerged.value())));
-        }
-      }
-
-      merger_.reset();
-    }
+    RETURN_NOT_OK(finishMerger());
 
     ARROW_ASSIGN_OR_RAISE(dataFileOs_, openFile(dataFile_, 
options_.shuffleFileBufferSize));
 
@@ -529,35 +538,17 @@ arrow::Status LocalPartitionWriter::stop(ShuffleWriterMetrics* metrics) {
       auto startInFinalFile = endInFinalFile;
       // Iterator over all spilled files.
       // May trigger spill during compression.
-      RETURN_NOT_OK(mergeSpills(pid));
-
-      if (payloadCache_) {
-        RETURN_NOT_OK(payloadCache_->write(pid, dataFileOs_.get()));
-      }
+      RETURN_NOT_OK(mergeSpills(pid, dataFileOs_.get()));
+      RETURN_NOT_OK(writeCachedPayloads(pid, dataFileOs_.get()));
 
       ARROW_ASSIGN_OR_RAISE(endInFinalFile, dataFileOs_->Tell());
       partitionLengths_[pid] = endInFinalFile - startInFinalFile;
     }
   }
+  ARROW_ASSIGN_OR_RAISE(totalBytesWritten_, dataFileOs_->Tell());
+
   // Close Final file. Clear buffered resources.
   RETURN_NOT_OK(clearResource());
-  // Check all spills are merged.
-  auto s = 0;
-  for (const auto& spill : spills_) {
-    compressTime_ += spill->compressTime();
-    spillTime_ += spill->spillTime();
-    for (auto pid = 0; pid < numPartitions_; ++pid) {
-      if (spill->hasNextPayload(pid)) {
-        return arrow::Status::Invalid(
-            "Merging from spill " + std::to_string(s) + " is not exhausted. 
pid: " + std::to_string(pid));
-      }
-    }
-    if (std::filesystem::exists(spill->spillFile()) && 
!std::filesystem::remove(spill->spillFile())) {
-      LOG(WARNING) << "Error while deleting spill file " << spill->spillFile();
-    }
-    ++s;
-  }
-  spills_.clear();
 
   // Populate shuffle writer metrics.
   RETURN_NOT_OK(populateMetrics(metrics));
@@ -593,6 +584,24 @@ arrow::Status LocalPartitionWriter::finishSpill() {
   return arrow::Status::OK();
 }
 
+arrow::Status LocalPartitionWriter::finishMerger() {
+  if (merger_ != nullptr) {
+    for (auto pid = 0; pid < numPartitions_; ++pid) {
+      ARROW_ASSIGN_OR_RAISE(auto maybeMerged, merger_->finish(pid, false));
+      if (maybeMerged.has_value()) {
+        if (payloadCache_ == nullptr) {
+          payloadCache_ = std::make_shared<PayloadCache>(
+              numPartitions_, codec_.get(), options_.compressionThreshold, 
payloadPool_.get());
+        }
+        // Spill can be triggered by compressing or building dictionaries.
+        RETURN_NOT_OK(payloadCache_->cache(pid, 
std::move(maybeMerged.value())));
+      }
+    }
+    merger_.reset();
+  }
+  return arrow::Status::OK();
+}
+
 arrow::Status LocalPartitionWriter::hashEvict(
     uint32_t partitionId,
     std::unique_ptr<InMemoryPayload> inMemoryPayload,
@@ -649,7 +658,7 @@ LocalPartitionWriter::sortEvict(uint32_t partitionId, std::unique_ptr<InMemoryPa
     // are not merged here and will be merged in `stop()`.
     if (isFinal && !spills_.empty()) {
       for (auto pid = lastEvictPid_ + 1; pid <= partitionId; ++pid) {
-        ARROW_ASSIGN_OR_RAISE(partitionLengths_[pid], mergeSpills(pid));
+        ARROW_ASSIGN_OR_RAISE(partitionLengths_[pid], mergeSpills(pid, 
dataFileOs_.get()));
       }
     }
   }
@@ -725,7 +734,7 @@ arrow::Status LocalPartitionWriter::populateMetrics(ShuffleWriterMetrics* metric
   metrics->totalWriteTime += writeTime_;
   metrics->totalBytesToEvict += totalBytesToEvict_;
   metrics->totalBytesEvicted += totalBytesEvicted_;
-  metrics->totalBytesWritten += std::filesystem::file_size(dataFile_);
+  metrics->totalBytesWritten += totalBytesWritten_;
   metrics->partitionLengths = std::move(partitionLengths_);
   metrics->rawPartitionLengths = std::move(rawPartitionLengths_);
   return arrow::Status::OK();
diff --git a/cpp/core/shuffle/LocalPartitionWriter.h b/cpp/core/shuffle/LocalPartitionWriter.h
index 594a58f89b..8d7e4a2a0f 100644
--- a/cpp/core/shuffle/LocalPartitionWriter.h
+++ b/cpp/core/shuffle/LocalPartitionWriter.h
@@ -75,22 +75,26 @@ class LocalPartitionWriter : public PartitionWriter {
   // 3. After stop() called,
   arrow::Status reclaimFixedSize(int64_t size, int64_t* actual) override;
 
+ protected:
   class LocalSpiller;
 
   class PayloadMerger;
 
   class PayloadCache;
 
- private:
   void init();
 
   arrow::Status requestSpill(bool isFinal);
 
   arrow::Status finishSpill();
 
+  arrow::Status finishMerger();
+
   std::string nextSpilledFileDir();
 
-  arrow::Result<int64_t> mergeSpills(uint32_t partitionId);
+  arrow::Result<int64_t> mergeSpills(uint32_t partitionId, 
arrow::io::OutputStream* os);
+
+  arrow::Status writeCachedPayloads(uint32_t partitionId, 
arrow::io::OutputStream* os) const;
 
   arrow::Status clearResource();
 
@@ -109,10 +113,11 @@ class LocalPartitionWriter : public PartitionWriter {
   // configured local dirs for spilled file
   int32_t dirSelection_{0};
   std::vector<int32_t> subDirSelection_;
-  std::shared_ptr<arrow::io::OutputStream> dataFileOs_;
+  std::shared_ptr<arrow::io::OutputStream> dataFileOs_{nullptr};
 
   int64_t totalBytesToEvict_{0};
   int64_t totalBytesEvicted_{0};
+  int64_t totalBytesWritten_{0};
   std::vector<int64_t> partitionLengths_;
   std::vector<int64_t> rawPartitionLengths_;
 
diff --git a/cpp/core/shuffle/Spill.h b/cpp/core/shuffle/Spill.h
index 1bf55152ad..e6e83b0f4e 100644
--- a/cpp/core/shuffle/Spill.h
+++ b/cpp/core/shuffle/Spill.h
@@ -67,8 +67,8 @@ class Spill final {
   std::shared_ptr<gluten::MmapFileStream> is_;
   std::list<PartitionPayload> partitionPayloads_{};
   std::string spillFile_;
-  int64_t spillTime_;
-  int64_t compressTime_;
+  int64_t spillTime_{0};
+  int64_t compressTime_{0};
 
   arrow::io::InputStream* rawIs_{nullptr};
 };


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to