This is an automated email from the ASF dual-hosted git repository.

taiyangli pushed a commit to branch revert-6432-columnar-shuffle-writer
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git

commit 6c51714ece2a3b8c61b2fef54c9d08a7dbee6894
Author: 李扬 <[email protected]>
AuthorDate: Mon Jul 15 13:36:32 2024 +0800

    Revert "[CELEBORN] CHCelebornColumnarShuffleWriter supports celeborn.client.s…"
    
    This reverts commit 30bddd0c9ba57c828202283a48349b6d1f11b230.
---
 .../shuffle/CHCelebornColumnarShuffleWriter.scala  | 88 +++++++++++----------
 .../shuffle/CelebornColumnarShuffleWriter.scala    | 36 ++-------
 .../VeloxCelebornColumnarShuffleWriter.scala       | 91 ++++++++++++++--------
 3 files changed, 112 insertions(+), 103 deletions(-)

diff --git a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala
index e5cd3d22f..7276e0f2c 100644
--- a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala
+++ b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarShuffleWriter.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 
 import org.apache.celeborn.client.ShuffleClient
 import org.apache.celeborn.common.CelebornConf
-import org.apache.celeborn.common.protocol.ShuffleMode
 
 import java.io.IOException
 import java.util.Locale
@@ -56,16 +55,61 @@ class CHCelebornColumnarShuffleWriter[K, V](
 
   private var splitResult: CHSplitResult = _
 
+  private val nativeBufferSize: Int = 
GlutenConfig.getConf.shuffleWriterBufferSize
+
   @throws[IOException]
   override def internalWrite(records: Iterator[Product2[K, V]]): Unit = {
+    if (!records.hasNext) {
+      handleEmptyIterator()
+      return
+    }
+
+    if (nativeShuffleWriter == -1L) {
+      nativeShuffleWriter = jniWrapper.makeForRSS(
+        dep.nativePartitioning,
+        shuffleId,
+        mapId,
+        nativeBufferSize,
+        customizedCompressCodec,
+        GlutenConfig.getConf.chColumnarShuffleSpillThreshold,
+        CHBackendSettings.shuffleHashAlgorithm,
+        celebornPartitionPusher,
+        GlutenConfig.getConf.chColumnarThrowIfMemoryExceed,
+        GlutenConfig.getConf.chColumnarFlushBlockBufferBeforeEvict,
+        GlutenConfig.getConf.chColumnarForceExternalSortShuffle,
+        GlutenConfig.getConf.chColumnarForceMemorySortShuffle
+      )
+      CHNativeMemoryAllocators.createSpillable(
+        "CelebornShuffleWriter",
+        new Spiller() {
+          override def spill(self: MemoryTarget, phase: Spiller.Phase, size: 
Long): Long = {
+            if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
+              return 0L
+            }
+            if (nativeShuffleWriter == -1L) {
+              throw new IllegalStateException(
+                "Fatal: spill() called before a celeborn shuffle writer " +
+                  "is created. This behavior should be" +
+                  "optimized by moving memory " +
+                  "allocations from make() to split()")
+            }
+            logInfo(s"Gluten shuffle writer: Trying to push $size bytes of 
data")
+            val spilled = jniWrapper.evict(nativeShuffleWriter)
+            logInfo(s"Gluten shuffle writer: Spilled $spilled / $size bytes of 
data")
+            spilled
+          }
+        }
+      )
+    }
     while (records.hasNext) {
       val cb = records.next()._2.asInstanceOf[ColumnarBatch]
       if (cb.numRows == 0 || cb.numCols == 0) {
         logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} 
cols")
       } else {
-        initShuffleWriter()
         val col = cb.column(0).asInstanceOf[CHColumnVector]
-        jniWrapper.split(nativeShuffleWriter, col.getBlockAddress)
+        val block = col.getBlockAddress
+        jniWrapper
+          .split(nativeShuffleWriter, block)
         dep.metrics("numInputRows").add(cb.numRows)
         dep.metrics("inputBatches").add(1)
         // This metric is important, AQE use it to decide if EliminateLimit
@@ -73,7 +117,6 @@ class CHCelebornColumnarShuffleWriter[K, V](
       }
     }
 
-    assert(nativeShuffleWriter != -1L)
     splitResult = jniWrapper.stop(nativeShuffleWriter)
 
     dep.metrics("splitTime").add(splitResult.getSplitTime)
@@ -92,43 +135,6 @@ class CHCelebornColumnarShuffleWriter[K, V](
     mapStatus = MapStatus(blockManager.shuffleServerId, 
splitResult.getRawPartitionLengths, mapId)
   }
 
-  override def createShuffleWriter(columnarBatch: ColumnarBatch): Unit = {
-    nativeShuffleWriter = jniWrapper.makeForRSS(
-      dep.nativePartitioning,
-      shuffleId,
-      mapId,
-      nativeBufferSize,
-      customizedCompressCodec,
-      GlutenConfig.getConf.chColumnarShuffleSpillThreshold,
-      CHBackendSettings.shuffleHashAlgorithm,
-      celebornPartitionPusher,
-      GlutenConfig.getConf.chColumnarThrowIfMemoryExceed,
-      GlutenConfig.getConf.chColumnarFlushBlockBufferBeforeEvict,
-      GlutenConfig.getConf.chColumnarForceExternalSortShuffle,
-      GlutenConfig.getConf.chColumnarForceMemorySortShuffle
-        || ShuffleMode.SORT.name.equalsIgnoreCase(shuffleWriterType)
-    )
-    CHNativeMemoryAllocators.createSpillable(
-      "CelebornShuffleWriter",
-      new Spiller() {
-        override def spill(self: MemoryTarget, phase: Spiller.Phase, size: 
Long): Long = {
-          if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
-            return 0L
-          }
-          if (nativeShuffleWriter == -1L) {
-            throw new IllegalStateException(
-              "Fatal: spill() called before a celeborn shuffle writer is 
created. " +
-                "This behavior should be optimized by moving memory 
allocations from make() to split()")
-          }
-          logInfo(s"Gluten shuffle writer: Trying to push $size bytes of data")
-          val spilled = jniWrapper.evict(nativeShuffleWriter)
-          logInfo(s"Gluten shuffle writer: Spilled $spilled / $size bytes of 
data")
-          spilled
-        }
-      }
-    )
-  }
-
   override def closeShuffleWriter(): Unit = {
     jniWrapper.close(nativeShuffleWriter)
   }
diff --git a/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala b/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala
index f5ed8c3d8..d58eeb195 100644
--- a/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala
+++ b/gluten-celeborn/common/src/main/scala/org/apache/spark/shuffle/CelebornColumnarShuffleWriter.scala
@@ -23,7 +23,6 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.SHUFFLE_COMPRESS
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.celeborn.CelebornShuffleHandle
-import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.storage.BlockManager
 
 import org.apache.celeborn.client.ShuffleClient
@@ -53,23 +52,12 @@ abstract class CelebornColumnarShuffleWriter[K, V](
 
   protected val mapId: Int = context.partitionId()
 
-  protected lazy val nativeBufferSize: Int = {
-    val bufferSize = GlutenConfig.getConf.shuffleWriterBufferSize
-    val maxBatchSize = GlutenConfig.getConf.maxBatchSize
-    if (bufferSize > maxBatchSize) {
-      logInfo(
-        s"${GlutenConfig.SHUFFLE_WRITER_BUFFER_SIZE.key} ($bufferSize) exceeds 
max " +
-          s" batch size. Limited to 
${GlutenConfig.COLUMNAR_MAX_BATCH_SIZE.key} ($maxBatchSize).")
-      maxBatchSize
-    } else {
-      bufferSize
-    }
-  }
-
   protected val clientPushBufferMaxSize: Int = 
celebornConf.clientPushBufferMaxSize
 
   protected val clientPushSortMemoryThreshold: Long = 
celebornConf.clientPushSortMemoryThreshold
 
+  protected val clientSortMemoryMaxSize: Long = 
celebornConf.clientPushSortMemoryThreshold
+
   protected val shuffleWriterType: String =
     celebornConf.shuffleWriterMode.name.toLowerCase(Locale.ROOT)
 
@@ -108,12 +96,6 @@ abstract class CelebornColumnarShuffleWriter[K, V](
 
   @throws[IOException]
   final override def write(records: Iterator[Product2[K, V]]): Unit = {
-    if (!records.hasNext) {
-      partitionLengths = new Array[Long](dep.partitioner.numPartitions)
-      client.mapperEnd(shuffleId, mapId, context.attemptNumber, numMappers)
-      mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, 
mapId)
-      return
-    }
     internalWrite(records)
   }
 
@@ -140,18 +122,10 @@ abstract class CelebornColumnarShuffleWriter[K, V](
     }
   }
 
-  def createShuffleWriter(columnarBatch: ColumnarBatch): Unit = {}
-
   def closeShuffleWriter(): Unit = {}
 
   def getPartitionLengths: Array[Long] = partitionLengths
 
-  def initShuffleWriter(columnarBatch: ColumnarBatch): Unit = {
-    if (nativeShuffleWriter == -1L) {
-      createShuffleWriter(columnarBatch)
-    }
-  }
-
   def pushMergedDataToCeleborn(): Unit = {
     val pushMergedDataTime = System.nanoTime
     client.prepareForMergeData(shuffleId, mapId, context.attemptNumber())
@@ -159,4 +133,10 @@ abstract class CelebornColumnarShuffleWriter[K, V](
     client.mapperEnd(shuffleId, mapId, context.attemptNumber, numMappers)
     writeMetrics.incWriteTime(System.nanoTime - pushMergedDataTime)
   }
+
+  def handleEmptyIterator(): Unit = {
+    partitionLengths = new Array[Long](dep.partitioner.numPartitions)
+    client.mapperEnd(shuffleId, mapId, context.attemptNumber, numMappers)
+    mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, 
mapId)
+  }
 }
diff --git a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
index baf61b8a1..c93255eaa 100644
--- a/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
+++ b/gluten-celeborn/velox/src/main/scala/org/apache/spark/shuffle/VeloxCelebornColumnarShuffleWriter.scala
@@ -55,6 +55,25 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
 
   private var splitResult: SplitResult = _
 
+  private lazy val nativeBufferSize = {
+    val bufferSize = GlutenConfig.getConf.shuffleWriterBufferSize
+    val maxBatchSize = GlutenConfig.getConf.maxBatchSize
+    if (bufferSize > maxBatchSize) {
+      logInfo(
+        s"${GlutenConfig.SHUFFLE_WRITER_BUFFER_SIZE.key} ($bufferSize) exceeds 
max " +
+          s" batch size. Limited to 
${GlutenConfig.COLUMNAR_MAX_BATCH_SIZE.key} ($maxBatchSize).")
+      maxBatchSize
+    } else {
+      bufferSize
+    }
+  }
+
+  private val memoryLimit: Long = if ("sort".equals(shuffleWriterType)) {
+    Math.min(clientSortMemoryMaxSize, clientPushBufferMaxSize * numPartitions)
+  } else {
+    availableOffHeapPerTask()
+  }
+
   private def availableOffHeapPerTask(): Long = {
     val perTask =
       SparkMemoryUtil.getCurrentAvailableOffHeapMemory / 
SparkResourceUtil.getTaskSlots(conf)
@@ -63,13 +82,49 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
 
   @throws[IOException]
   override def internalWrite(records: Iterator[Product2[K, V]]): Unit = {
+    if (!records.hasNext) {
+      handleEmptyIterator()
+      return
+    }
+
     while (records.hasNext) {
       val cb = records.next()._2.asInstanceOf[ColumnarBatch]
       if (cb.numRows == 0 || cb.numCols == 0) {
         logInfo(s"Skip ColumnarBatch of ${cb.numRows} rows, ${cb.numCols} 
cols")
       } else {
-        initShuffleWriter(cb)
         val handle = ColumnarBatches.getNativeHandle(cb)
+        if (nativeShuffleWriter == -1L) {
+          nativeShuffleWriter = jniWrapper.makeForRSS(
+            dep.nativePartitioning,
+            nativeBufferSize,
+            customizedCompressionCodec,
+            compressionLevel,
+            bufferCompressThreshold,
+            GlutenConfig.getConf.columnarShuffleCompressionMode,
+            clientPushBufferMaxSize,
+            clientPushSortMemoryThreshold,
+            celebornPartitionPusher,
+            handle,
+            context.taskAttemptId(),
+            GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, 
context.partitionId),
+            "celeborn",
+            shuffleWriterType,
+            GlutenConfig.getConf.columnarShuffleReallocThreshold
+          )
+          runtime.addSpiller(new Spiller() {
+            override def spill(self: MemoryTarget, phase: Spiller.Phase, size: 
Long): Long = {
+              if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
+                return 0L
+              }
+              logInfo(s"Gluten shuffle writer: Trying to push $size bytes of 
data")
+              // fixme pass true when being called by self
+              val pushed =
+                jniWrapper.nativeEvict(nativeShuffleWriter, size, false)
+              logInfo(s"Gluten shuffle writer: Pushed $pushed / $size bytes of 
data")
+              pushed
+            }
+          })
+        }
         val startTime = System.nanoTime()
         jniWrapper.write(nativeShuffleWriter, cb.numRows, handle, 
availableOffHeapPerTask())
         dep.metrics("splitTime").add(System.nanoTime() - startTime)
@@ -80,8 +135,8 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
       }
     }
 
-    assert(nativeShuffleWriter != -1L)
     val startTime = System.nanoTime()
+    assert(nativeShuffleWriter != -1L)
     splitResult = jniWrapper.stop(nativeShuffleWriter)
 
     dep
@@ -100,38 +155,6 @@ class VeloxCelebornColumnarShuffleWriter[K, V](
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, 
mapId)
   }
 
-  override def createShuffleWriter(columnarBatch: ColumnarBatch): Unit = {
-    nativeShuffleWriter = jniWrapper.makeForRSS(
-      dep.nativePartitioning,
-      nativeBufferSize,
-      customizedCompressionCodec,
-      compressionLevel,
-      bufferCompressThreshold,
-      GlutenConfig.getConf.columnarShuffleCompressionMode,
-      clientPushBufferMaxSize,
-      clientPushSortMemoryThreshold,
-      celebornPartitionPusher,
-      ColumnarBatches.getNativeHandle(columnarBatch),
-      context.taskAttemptId(),
-      GlutenShuffleUtils.getStartPartitionId(dep.nativePartitioning, 
context.partitionId),
-      "celeborn",
-      shuffleWriterType,
-      GlutenConfig.getConf.columnarShuffleReallocThreshold
-    )
-    runtime.addSpiller(new Spiller() {
-      override def spill(self: MemoryTarget, phase: Spiller.Phase, size: 
Long): Long = {
-        if (!Spillers.PHASE_SET_SPILL_ONLY.contains(phase)) {
-          return 0L
-        }
-        logInfo(s"Gluten shuffle writer: Trying to push $size bytes of data")
-        // fixme pass true when being called by self
-        val pushed = jniWrapper.nativeEvict(nativeShuffleWriter, size, false)
-        logInfo(s"Gluten shuffle writer: Pushed $pushed / $size bytes of data")
-        pushed
-      }
-    })
-  }
-
   override def closeShuffleWriter(): Unit = {
     jniWrapper.close(nativeShuffleWriter)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to