This is an automated email from the ASF dual-hosted git repository.

xianjingfeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new f7c6d2da2 [#1751] improvement: support gluten again (#1857)
f7c6d2da2 is described below

commit f7c6d2da237bd487d3cd0e21231108df90559cbe
Author: xianjingfeng <xianjingfeng...@gmail.com>
AuthorDate: Thu Jul 4 10:24:59 2024 +0800

    [#1751] improvement: support gluten again (#1857)
    
    ### What changes were proposed in this pull request?
    Support Gluten by re-exposing the client methods and fields its shuffle integration depends on.
    
    ### Why are the changes needed?
    Currently, Gluten fails to compile against the client built from the master branch of Uniffle.
    Fix: #1751
    
    ### Does this PR introduce any user-facing change?
    No.
    
    ### How was this patch tested?
    Existing UTs and manual testing
---
 .../spark/shuffle/writer/WriteBufferManager.java   |  5 ++++
 .../shuffle/manager/RssShuffleManagerBase.java     | 16 +++++++++++
 .../apache/spark/shuffle/RssShuffleManager.java    | 15 +----------
 .../spark/shuffle/writer/RssShuffleWriter.java     |  7 ++---
 .../apache/spark/shuffle/RssShuffleManager.java    | 31 +++-------------------
 .../spark/shuffle/writer/RssShuffleWriter.java     | 18 +++++++++----
 6 files changed, 42 insertions(+), 50 deletions(-)

diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 95add5048..bfd929777 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -324,6 +324,11 @@ public class WriteBufferManager extends MemoryConsumer {
     return shuffleBlockInfos;
   }
 
+  // Gluten needs this method.
+  public synchronized List<ShuffleBlockInfo> clear() {
+    return clear(bufferSpillRatio);
+  }
+
   // transform all [partition, records] to [partition, ShuffleBlockInfo] and clear cache
   public synchronized List<ShuffleBlockInfo> clear(double bufferSpillRatio) {
     List<ShuffleBlockInfo> result = Lists.newArrayList();
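
The new no-argument `clear()` simply delegates to the existing `clear(bufferSpillRatio)`, so an external caller such as Gluten can flush whatever is buffered without knowing the configured spill ratio. Below is a minimal usage sketch, not code from this patch: the helper class and method names are illustrative, and the `ShuffleBlockInfo` import path is assumed to be Uniffle's common module.

```java
import java.util.List;

import org.apache.spark.shuffle.writer.WriteBufferManager;
import org.apache.uniffle.common.ShuffleBlockInfo;

// Illustrative only, not part of this patch.
public class BufferDrainSketch {

  // Flush everything currently buffered and hand the resulting blocks to the caller,
  // without the caller having to know the manager's configured spill ratio.
  public static List<ShuffleBlockInfo> drainAll(WriteBufferManager bufferManager) {
    // Internally this is clear(bufferSpillRatio), per the hunk above.
    return bufferManager.clear();
  }
}
```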
diff --git a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
index 209ede25c..bbeec90dd 100644
--- a/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
+++ b/client-spark/common/src/main/java/org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.java
@@ -44,6 +44,7 @@ import org.apache.spark.MapOutputTrackerMaster;
 import org.apache.spark.SparkConf;
 import org.apache.spark.SparkEnv;
 import org.apache.spark.SparkException;
+import org.apache.spark.shuffle.RssShuffleHandle;
 import org.apache.spark.shuffle.RssSparkConfig;
 import org.apache.spark.shuffle.RssSparkShuffleUtils;
 import org.apache.spark.shuffle.RssStageInfo;
@@ -53,6 +54,7 @@ import org.apache.spark.shuffle.ShuffleManager;
 import org.apache.spark.shuffle.SparkVersionUtils;
 import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
 import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
+import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
 import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -576,6 +578,20 @@ public abstract class RssShuffleManagerBase implements RssShuffleManagerInterfac
         sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""), confItems);
   }
 
+  public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle<?, ?, ?> rssHandle) {
+    int shuffleId = rssHandle.getShuffleId();
+    if (shuffleManagerRpcServiceEnabled && rssStageRetryEnabled) {
+      // In Stage Retry mode, get the ShuffleServer list from the Driver based on the shuffleId.
+      return getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
+    } else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) {
+      // In Block Retry mode, get the ShuffleServer list from the Driver based on the shuffleId.
+      return getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
+    } else {
+      return new SimpleShuffleHandleInfo(
+          shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
+    }
+  }
+
   /**
    * In Stage Retry mode, obtain the Shuffle Server list from the Driver based on shuffleId.
    *
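
The `getShuffleHandleInfo(...)` helper added above centralizes the branching over stage retry, block retry, and the plain `SimpleShuffleHandleInfo` fallback that the spark2/spark3 managers below previously repeated at every call site. A minimal sketch of how a caller holding a manager reference might use it; the class and method names below are illustrative, not part of this patch.

```java
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.uniffle.shuffle.manager.RssShuffleManagerBase;

// Illustrative only, not part of this patch.
public class HandleResolveSketch {

  // One call now covers the stage-retry, block-retry, and simple-handle cases that
  // each getWriter/getReader implementation used to branch over itself.
  public static ShuffleHandleInfo resolve(
      RssShuffleManagerBase manager, RssShuffleHandle<?, ?, ?> handle) {
    return manager.getShuffleHandleInfo(handle);
  }
}
```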
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index de5d4da63..1e5bb4941 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -424,18 +424,6 @@ public class RssShuffleManager extends RssShuffleManagerBase {
 
       int shuffleId = rssHandle.getShuffleId();
       String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
-      ShuffleHandleInfo shuffleHandleInfo;
-      if (shuffleManagerRpcServiceEnabled && rssStageRetryEnabled) {
-        // In Stage Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId
-        shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
-      } else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) {
-        // In Block Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId
-        shuffleHandleInfo = getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
-      } else {
-        shuffleHandleInfo =
-            new SimpleShuffleHandleInfo(
-                shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
-      }
       ShuffleWriteMetrics writeMetrics = context.taskMetrics().shuffleWriteMetrics();
       return new RssShuffleWriter<>(
           rssHandle.getAppId(),
@@ -448,8 +436,7 @@ public class RssShuffleManager extends RssShuffleManagerBase {
           shuffleWriteClient,
           rssHandle,
           this::markFailedTask,
-          context,
-          shuffleHandleInfo);
+          context);
     } else {
       throw new RssException("Unexpected ShuffleHandle:" + handle.getClass().getName());
     }
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 65b66df3d..5ac6a7e9e 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -97,6 +97,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
   private String appId;
   private int numMaps;
   private int shuffleId;
+  private final ShuffleHandleInfo shuffleHandleInfo;
   private int bitmapSplitNum;
   private String taskId;
   private long taskAttemptId;
@@ -176,6 +177,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
     this.isMemoryShuffleEnabled =
         isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
     this.taskFailureCallback = taskFailureCallback;
+    this.shuffleHandleInfo = shuffleHandleInfo;
     this.taskContext = context;
     this.sparkConf = sparkConf;
   }
@@ -191,8 +193,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
       ShuffleWriteClient shuffleWriteClient,
       RssShuffleHandle<K, V, C> rssHandle,
       Function<String, Boolean> taskFailureCallback,
-      TaskContext context,
-      ShuffleHandleInfo shuffleHandleInfo) {
+      TaskContext context) {
     this(
         appId,
         shuffleId,
@@ -204,7 +205,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
         shuffleWriteClient,
         rssHandle,
         taskFailureCallback,
-        shuffleHandleInfo,
+        shuffleManager.getShuffleHandleInfo(rssHandle),
         context);
     BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
     final WriteBufferManager bufferManager =
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 1d5050790..bf42bf361 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -512,18 +512,7 @@ public class RssShuffleManager extends RssShuffleManagerBase {
     } else {
       writeMetrics = context.taskMetrics().shuffleWriteMetrics();
     }
-    ShuffleHandleInfo shuffleHandleInfo;
-    if (shuffleManagerRpcServiceEnabled && rssStageRetryEnabled) {
-      // In Stage Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId.
-      shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
-    } else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) {
-      // In Stage Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId.
-      shuffleHandleInfo = getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
-    } else {
-      shuffleHandleInfo =
-          new SimpleShuffleHandleInfo(
-              shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
-    }
+
     String taskId = "" + context.taskAttemptId() + "_" + context.attemptNumber();
     return new RssShuffleWriter<>(
         rssHandle.getAppId(),
@@ -536,8 +525,7 @@ public class RssShuffleManager extends RssShuffleManagerBase {
         shuffleWriteClient,
         rssHandle,
         this::markFailedTask,
-        context,
-        shuffleHandleInfo);
+        context);
   }
 
   @Override
@@ -656,20 +644,7 @@ public class RssShuffleManager extends RssShuffleManagerBase {
     RssShuffleHandle<K, ?, C> rssShuffleHandle = (RssShuffleHandle<K, ?, C>) handle;
     final int partitionNum = rssShuffleHandle.getDependency().partitioner().numPartitions();
     int shuffleId = rssShuffleHandle.getShuffleId();
-    ShuffleHandleInfo shuffleHandleInfo;
-    if (shuffleManagerRpcServiceEnabled && rssStageRetryEnabled) {
-      // In Stage Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId.
-      shuffleHandleInfo = getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
-    } else if (shuffleManagerRpcServiceEnabled && partitionReassignEnabled) {
-      // In Block Retry mode, Get the ShuffleServer list from the Driver based on the shuffleId.
-      shuffleHandleInfo = getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
-    } else {
-      shuffleHandleInfo =
-          new SimpleShuffleHandleInfo(
-              shuffleId,
-              rssShuffleHandle.getPartitionToServers(),
-              rssShuffleHandle.getRemoteStorage());
-    }
+    ShuffleHandleInfo shuffleHandleInfo = getShuffleHandleInfo(rssShuffleHandle);
     Map<ShuffleServerInfo, Set<Integer>> serverToPartitions =
         getPartitionDataServers(shuffleHandleInfo, startPartition, endPartition);
     long start = System.currentTimeMillis();
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 50eb47001..6660a5e7b 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -104,6 +104,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
 
   private final String appId;
   private final int shuffleId;
+  private final ShuffleHandleInfo shuffleHandleInfo;
   private WriteBufferManager bufferManager;
   private String taskId;
   private final int numMaps;
@@ -119,7 +120,8 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
   private final ShuffleWriteClient shuffleWriteClient;
   private final Set<ShuffleServerInfo> shuffleServersForData;
   private final long[] partitionLengths;
-  private final boolean isMemoryShuffleEnabled;
+  // Gluten needs this variable
+  protected final boolean isMemoryShuffleEnabled;
   private final Function<String, Boolean> taskFailureCallback;
   private final Set<Long> blockIds = Sets.newConcurrentHashSet();
   private TaskContext taskContext;
@@ -211,6 +213,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
     this.isMemoryShuffleEnabled =
         isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
     this.taskFailureCallback = taskFailureCallback;
+    this.shuffleHandleInfo = shuffleHandleInfo;
     this.taskContext = context;
     this.sparkConf = sparkConf;
     this.blockFailSentRetryEnabled =
@@ -233,8 +236,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
       ShuffleWriteClient shuffleWriteClient,
       RssShuffleHandle<K, V, C> rssHandle,
       Function<String, Boolean> taskFailureCallback,
-      TaskContext context,
-      ShuffleHandleInfo shuffleHandleInfo) {
+      TaskContext context) {
     this(
         appId,
         shuffleId,
@@ -246,7 +248,7 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
         shuffleWriteClient,
         rssHandle,
         taskFailureCallback,
-        shuffleHandleInfo,
+        shuffleManager.getShuffleHandleInfo(rssHandle),
         context);
     BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
     final WriteBufferManager bufferManager =
@@ -288,7 +290,8 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
     }
   }
 
-  private void writeImpl(Iterator<Product2<K, V>> records) {
+  // Gluten needs this method.
+  protected void writeImpl(Iterator<Product2<K, V>> records) {
     List<ShuffleBlockInfo> shuffleBlockInfos;
     boolean isCombine = shuffleDependency.mapSideCombine();
 
@@ -454,6 +457,11 @@ public class RssShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
     return futures;
   }
 
+  // Gluten needs this method
+  protected void internalCheckBlockSendResult() {
+    this.checkBlockSendResult(this.blockIds);
+  }
+
   @VisibleForTesting
   protected void checkBlockSendResult(Set<Long> blockIds) {
     boolean interrupted = false;
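
Taken together, the members made `protected` in the spark3 writer (`writeImpl`, `isMemoryShuffleEnabled`, `internalCheckBlockSendResult`) are the hooks an engine-side writer such as Gluten's can drive. The sketch below is not Gluten's actual writer: it is a hypothetical same-package helper that only demonstrates which calls this patch makes reachable, and it assumes `writeImpl` takes a `scala.collection.Iterator`, matching `ShuffleWriter#write`.

```java
package org.apache.spark.shuffle.writer;

import scala.Product2;
import scala.collection.Iterator;

// Illustrative only, not part of this patch. Placed in the same package so the
// protected members are visible, mirroring what a Gluten-side subclass would see.
public class GlutenHooksSketch {

  public static <K, V, C> void writeAndVerify(
      RssShuffleWriter<K, V, C> writer, Iterator<Product2<K, V>> records) {
    // The now-protected flag is readable by engine-side code that needs to know
    // whether a memory-only storage type is in use.
    boolean memoryShuffle = writer.isMemoryShuffleEnabled;

    // Feed records through the (now protected) core write path.
    writer.writeImpl(records);

    // Wait until every buffered block is confirmed sent; this wraps the existing
    // checkBlockSendResult(blockIds) without exposing the private block-id set.
    writer.internalCheckBlockSendResult();
  }
}
```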
