This is an automated email from the ASF dual-hosted git repository.

zhifgli pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 01def93f Supports ZSTD (#254)
01def93f is described below

commit 01def93fff1a40a83676d07821754ba0d91d65f2
Author: Junfan Zhang <[email protected]>
AuthorDate: Wed Oct 26 19:05:03 2022 +0800

    Supports ZSTD (#254)
    
    ### What changes were proposed in this pull request?
    1. Introduce ZSTD compression
    2. Introduce an abstract codec interface
    3. Recycle the buffer to optimize performance
    
    
    ### Why are the changes needed?
    ZSTD offers a good tradeoff between compression ratio and de/compression
    speed. To reduce the size of the stored shuffle data, it is necessary to
    support this compression algorithm.
    
    
    ### Does this PR introduce _any_ user-facing change?
    Yes
    
    
    ### How was this patch tested?
    Manual tests and UTs
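    
    For reviewers, a minimal round-trip against the new codec abstraction
    (the Codec/Lz4Codec signatures are taken from the files added in this
    diff; the surrounding setup is illustrative only, and an empty RssConf
    falls back to the LZ4 codec via the factory's default branch):
    
        import java.nio.ByteBuffer;
        import java.nio.charset.StandardCharsets;
        import org.apache.uniffle.common.compression.Codec;
        import org.apache.uniffle.common.config.RssConf;
    
        // Sketch only: defaults select Lz4Codec, per Codec.newInstance below
        Codec codec = Codec.newInstance(new RssConf());
        byte[] raw = "hello uniffle".getBytes(StandardCharsets.UTF_8);
        byte[] compressed = codec.compress(raw);
        ByteBuffer dest = ByteBuffer.allocate(raw.length);
        codec.decompress(ByteBuffer.wrap(compressed), raw.length, dest, 0);
        // dest.array() now holds the original bytes of raw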
---
 client-mr/pom.xml                                  |  4 +
 .../hadoop/mapred/RssMapOutputCollector.java       |  3 +-
 .../hadoop/mapred/SortWriteBufferManager.java      | 12 ++-
 .../org/apache/hadoop/mapreduce/RssMRConfig.java   | 16 ++++
 .../hadoop/mapreduce/task/reduce/RssFetcher.java   | 27 ++++--
 .../hadoop/mapreduce/task/reduce/RssShuffle.java   |  2 +-
 .../hadoop/mapred/SortWriteBufferManagerTest.java  | 13 ++-
 .../hadoop/mapreduce/task/reduce/FetcherTest.java  | 26 ++++--
 .../org/apache/spark/shuffle/RssSparkConfig.java   | 16 ++++
 .../shuffle/reader/RssShuffleDataIterator.java     | 41 +++++----
 .../spark/shuffle/writer/WriteBufferManager.java   | 10 ++-
 .../shuffle/reader/AbstractRssReaderTest.java      |  5 +-
 .../shuffle/reader/RssShuffleDataIteratorTest.java |  6 +-
 .../shuffle/writer/WriteBufferManagerTest.java     |  3 +-
 .../apache/spark/shuffle/RssShuffleManager.java    | 14 ++-
 .../spark/shuffle/reader/RssShuffleReader.java     |  8 +-
 .../spark/shuffle/reader/RssShuffleReaderTest.java |  3 +-
 .../spark/shuffle/writer/RssShuffleWriterTest.java | 11 ++-
 .../apache/spark/shuffle/RssShuffleManager.java    |  5 +-
 .../spark/shuffle/reader/RssShuffleReader.java     |  8 +-
 .../spark/shuffle/reader/RssShuffleReaderTest.java |  8 +-
 .../spark/shuffle/writer/RssShuffleWriterTest.java | 27 ++----
 common/pom.xml                                     |  5 ++
 .../org/apache/uniffle/common/RssShuffleUtils.java | 36 --------
 .../apache/uniffle/common/compression/Codec.java   | 51 +++++++++++
 .../uniffle/common/compression/Lz4Codec.java       | 41 +++++++++
 .../uniffle/common/compression/NoOpCodec.java      | 35 ++++++++
 .../uniffle/common/compression/ZstdCodec.java      | 67 +++++++++++++++
 .../uniffle/common/config/RssClientConf.java       | 38 +++++++++
 .../apache/uniffle/common/RssShuffleUtilsTest.java | 99 ----------------------
 .../common/compression/CompressionTest.java        | 83 ++++++++++++++++++
 docs/client_guide.md                               |  2 +
 .../test/RepartitionWithLocalFileRssTest.java      | 38 +++++++++
 .../uniffle/test/SparkIntegrationTestBase.java     |  2 +-
 pom.xml                                            |  7 ++
 35 files changed, 549 insertions(+), 223 deletions(-)

diff --git a/client-mr/pom.xml b/client-mr/pom.xml
index 2c2332f4..b91c3160 100644
--- a/client-mr/pom.xml
+++ b/client-mr/pom.xml
@@ -105,6 +105,10 @@
             <artifactId>mockito-core</artifactId>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>com.github.luben</groupId>
+            <artifactId>zstd-jni</artifactId>
+        </dependency>
     </dependencies>
 
     <build>
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java b/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
index 308a560c..c9cb553f 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapred/RssMapOutputCollector.java
@@ -130,7 +130,8 @@ public class RssMapOutputCollector<K extends Object, V extends Object>
         isMemoryShuffleEnabled(storageType),
         sendThreadNum,
         sendThreshold,
-        maxBufferSize);
+        maxBufferSize,
+        RssMRConfig.toRssConf(rssJobConf));
   }
 
   private Map<Integer, List<ShuffleServerInfo>> createAssignmentMap(JobConf jobConf) {
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java b/client-mr/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
index aa4da547..36ade47e 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapred/SortWriteBufferManager.java
@@ -43,9 +43,10 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.client.response.SendShuffleDataResult;
-import org.apache.uniffle.common.RssShuffleUtils;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.ChecksumUtils;
 import org.apache.uniffle.common.util.ThreadUtils;
@@ -90,6 +91,8 @@ public class SortWriteBufferManager<K, V> {
   private long sortTime = 0;
   private final long maxBufferSize;
   private final ExecutorService sendExecutorService;
+  private final RssConf rssConf;
+  private final Codec codec;
 
   public SortWriteBufferManager(
       long maxMemSize,
@@ -114,7 +117,8 @@ public class SortWriteBufferManager<K, V> {
       boolean isMemoryShuffleEnabled,
       int sendThreadNum,
       double sendThreshold,
-      long maxBufferSize) {
+      long maxBufferSize,
+      RssConf rssConf) {
     this.maxMemSize = maxMemSize;
     this.taskAttemptId = taskAttemptId;
     this.batch = batch;
@@ -140,6 +144,8 @@ public class SortWriteBufferManager<K, V> {
     this.sendExecutorService  = Executors.newFixedThreadPool(
         sendThreadNum,
         ThreadUtils.getThreadFactory("send-thread-%d"));
+    this.rssConf = rssConf;
+    this.codec = Codec.newInstance(rssConf);
   }
 
   // todo: Single Buffer should also have its size limit
@@ -309,7 +315,7 @@ public class SortWriteBufferManager<K, V> {
     int partitionId = wb.getPartitionId();
     final int uncompressLength = data.length;
     long start = System.currentTimeMillis();
-    final byte[] compressed = RssShuffleUtils.compressData(data);
+    final byte[] compressed = codec.compress(data);
     final long crc32 = ChecksumUtils.getCrc32(compressed);
     compressTime += System.currentTimeMillis() - start;
     final long blockId = RssMRUtils.getBlockId((long)partitionId, taskAttemptId, getNextSeqNo(partitionId));
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
index eb518162..d89b4f12 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/RssMRConfig.java
@@ -17,11 +17,14 @@
 
 package org.apache.hadoop.mapreduce;
 
+import java.util.Map;
 import java.util.Set;
 
 import com.google.common.collect.ImmutableSet;
+import org.apache.hadoop.mapred.JobConf;
 
 import org.apache.uniffle.client.util.RssClientConfig;
+import org.apache.uniffle.common.config.RssConf;
 
 public class RssMRConfig {
 
@@ -164,4 +167,17 @@ public class RssMRConfig {
 
   public static final Set<String> RSS_MANDATORY_CLUSTER_CONF =
       ImmutableSet.of(RSS_STORAGE_TYPE, RSS_REMOTE_STORAGE_PATH);
+
+  public static RssConf toRssConf(JobConf jobConf) {
+    RssConf rssConf = new RssConf();
+    for (Map.Entry<String, String> entry : jobConf) {
+      String key = entry.getKey();
+      if (!key.startsWith(MR_RSS_CONFIG_PREFIX)) {
+        continue;
+      }
+      key = key.substring(MR_RSS_CONFIG_PREFIX.length());
+      rssConf.setString(key, entry.getValue());
+    }
+    return rssConf;
+  }
 }
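
A sketch of the prefix stripping above (illustrative only; the concrete value
of MR_RSS_CONFIG_PREFIX is defined elsewhere in RssMRConfig and is not shown
in this hunk):

    JobConf jobConf = new JobConf();
    jobConf.set(RssMRConfig.MR_RSS_CONFIG_PREFIX + "test.key", "value");
    RssConf rssConf = RssMRConfig.toRssConf(jobConf);
    // rssConf now maps "test.key" -> "value", with the prefix removed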
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java
index 128bfb9f..8e0859cc 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssFetcher.java
@@ -35,7 +35,8 @@ import org.apache.hadoop.util.Progress;
 
 import org.apache.uniffle.client.api.ShuffleReadClient;
 import org.apache.uniffle.client.response.CompressedShuffleBlock;
-import org.apache.uniffle.common.RssShuffleUtils;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.ByteUnit;
 
@@ -84,14 +85,17 @@ public class RssFetcher<K,V> {
   private long startWait;
   private int waitCount = 0;
   private byte[] uncompressedData = null;
+  private RssConf rssConf;
+  private Codec codec;
 
   RssFetcher(JobConf job, TaskAttemptID reduceId,
-             TaskStatus status,
-             MergeManager<K,V> merger,
-             Progress progress,
-             Reporter reporter, ShuffleClientMetrics metrics,
-             ShuffleReadClient shuffleReadClient,
-             long totalBlockCount) {
+      TaskStatus status,
+      MergeManager<K, V> merger,
+      Progress progress,
+      Reporter reporter, ShuffleClientMetrics metrics,
+      ShuffleReadClient shuffleReadClient,
+      long totalBlockCount,
+      RssConf rssConf) {
     this.jobConf = job;
     this.reporter = reporter;
     this.status = status;
@@ -114,6 +118,9 @@ public class RssFetcher<K,V> {
 
     this.shuffleReadClient = shuffleReadClient;
     this.totalBlockCount = totalBlockCount;
+
+    this.rssConf = rssConf;
+    this.codec = Codec.newInstance(rssConf);
   }
 
   public void fetchAllRssBlocks() throws IOException, InterruptedException {
@@ -150,8 +157,10 @@ public class RssFetcher<K,V> {
     // uncompress the block
     if (!hasPendingData && compressedData != null) {
       final long startDecompress = System.currentTimeMillis();
-      uncompressedData = RssShuffleUtils.decompressData(
-          compressedData, compressedBlock.getUncompressLength(), false).array();
+      int uncompressedLen = compressedBlock.getUncompressLength();
+      ByteBuffer decompressedBuffer = ByteBuffer.allocate(uncompressedLen);
+      codec.decompress(compressedData, uncompressedLen, decompressedBuffer, 0);
+      uncompressedData = decompressedBuffer.array();
       unCompressionLength += compressedBlock.getUncompressLength();
       long decompressDuration = System.currentTimeMillis() - startDecompress;
       decompressTime += decompressDuration;
diff --git a/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java b/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
index 1d30df96..e5af9795 100644
--- a/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
+++ b/client-mr/src/main/java/org/apache/hadoop/mapreduce/task/reduce/RssShuffle.java
@@ -197,7 +197,7 @@ public class RssShuffle<K, V> implements ShuffleConsumerPlugin<K, V>, ExceptionR
           readerJobConf, new MRIdHelper());
       ShuffleReadClient shuffleReadClient = ShuffleClientFactory.getInstance().createShuffleReadClient(request);
       RssFetcher fetcher = new RssFetcher(mrJobConf, reduceId, taskStatus, merger, copyPhase, reporter, metrics,
-          shuffleReadClient, blockIdBitmap.getLongCardinality());
+          shuffleReadClient, blockIdBitmap.getLongCardinality(), RssMRConfig.toRssConf(rssJobConf));
       fetcher.fetchAllRssBlocks();
       LOG.info("In reduce: " + reduceId
           + ", Rss MR client fetches blocks from RSS server successfully");
diff --git a/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java b/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
index 029b1e0e..305a9dcb 100644
--- a/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
+++ b/client-mr/src/test/java/org/apache/hadoop/mapred/SortWriteBufferManagerTest.java
@@ -38,6 +38,7 @@ import org.apache.uniffle.common.RemoteStorageInfo;
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -79,7 +80,8 @@ public class SortWriteBufferManagerTest {
         true,
         5,
         0.2f,
-        1024000L);
+        1024000L,
+        new RssConf());
     Random random = new Random();
     for (int i = 0; i < 1000; i++) {
       byte[] key = new byte[20];
@@ -128,7 +130,8 @@ public class SortWriteBufferManagerTest {
         true,
         5,
         0.2f,
-        1024000L);
+        1024000L,
+        new RssConf());
     byte[] key = new byte[20];
     byte[] value = new byte[1024];
     random.nextBytes(key);
@@ -176,7 +179,8 @@ public class SortWriteBufferManagerTest {
         true,
         5,
         0.2f,
-        100L);
+        100L,
+        new RssConf());
     Random random = new Random();
     for (int i = 0; i < 1000; i++) {
       byte[] key = new byte[20];
@@ -223,7 +227,8 @@ public class SortWriteBufferManagerTest {
         true,
         5,
         0.2f,
-        1024000L);
+        1024000L,
+        new RssConf());
     Random random = new Random();
     for (int i = 0; i < 1000; i++) {
       byte[] key = new byte[20];
diff --git a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
index ec630e24..b5404e59 100644
--- a/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
+++ b/client-mr/src/test/java/org/apache/hadoop/mapreduce/task/reduce/FetcherTest.java
@@ -65,10 +65,12 @@ import org.apache.uniffle.client.response.CompressedShuffleBlock;
 import org.apache.uniffle.client.response.SendShuffleDataResult;
 import org.apache.uniffle.common.PartitionRange;
 import org.apache.uniffle.common.RemoteStorageInfo;
-import org.apache.uniffle.common.RssShuffleUtils;
 import org.apache.uniffle.common.ShuffleAssignmentsInfo;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.compression.Lz4Codec;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -88,6 +90,8 @@ public class FetcherTest {
   static List<byte[]> data;
   static MergeManagerImpl<Text, Text> merger;
 
+  static Codec codec = new Lz4Codec();
+
   @Test
   public void writeAndReadDataTestWithRss() throws Throwable {
     fs = FileSystem.getLocal(conf);
@@ -97,7 +101,7 @@ public class FetcherTest {
         null, null, new Progress(), new MROutputFiles());
     ShuffleReadClient shuffleReadClient = new MockedShuffleReadClient(data);
     RssFetcher fetcher = new RssFetcher(jobConf, reduceId1, taskStatus, merger, new Progress(),
-        reporter, metrics, shuffleReadClient, 3);
+        reporter, metrics, shuffleReadClient, 3, new RssConf());
     fetcher.fetchAllRssBlocks();
 
 
@@ -128,7 +132,7 @@ public class FetcherTest {
         null, null, new Progress(), new MROutputFiles());
     ShuffleReadClient shuffleReadClient = new MockedShuffleReadClient(data);
     RssFetcher fetcher = new RssFetcher(jobConf, reduceId1, taskStatus, merger, new Progress(),
-        reporter, metrics, shuffleReadClient, 3);
+        reporter, metrics, shuffleReadClient, 3, new RssConf());
     fetcher.fetchAllRssBlocks();
 
 
@@ -161,7 +165,7 @@ public class FetcherTest {
       null, null, new Progress(), new MROutputFiles(), expectedFails);
     ShuffleReadClient shuffleReadClient = new MockedShuffleReadClient(data);
     RssFetcher fetcher = new RssFetcher(jobConf, reduceId1, taskStatus, merger, new Progress(),
-        reporter, metrics, shuffleReadClient, 3);
+        reporter, metrics, shuffleReadClient, 3, new RssConf());
     fetcher.fetchAllRssBlocks();
 
     RawKeyValueIterator iterator = merger.close();
@@ -276,7 +280,8 @@ public class FetcherTest {
         true,
         5,
         0.2f,
-        1024000L);
+        1024000L,
+        new RssConf());
 
     for (String key : keysToValues.keySet()) {
       String value = keysToValues.get(key);
@@ -357,7 +362,14 @@ public class FetcherTest {
           successBlockIds.add(blockInfo.getBlockId());
         }
         shuffleBlockInfoList.forEach(block -> {
-          data.add(RssShuffleUtils.decompressData(block.getData(), block.getUncompressLength()));
+          ByteBuffer uncompressedBuffer = ByteBuffer.allocate(block.getUncompressLength());
+          codec.decompress(
+              ByteBuffer.wrap(block.getData()),
+              block.getUncompressLength(),
+              uncompressedBuffer,
+              0
+          );
+          data.add(uncompressedBuffer.array());
         });
         return new SendShuffleDataResult(successBlockIds, Sets.newHashSet());
       }
@@ -440,7 +452,7 @@ public class FetcherTest {
     MockedShuffleReadClient(List<byte[]> data) {
       this.blocks = new LinkedList<>();
       data.forEach(bytes -> {
-        byte[] compressed = RssShuffleUtils.compressData(bytes);
+        byte[] compressed = codec.compress(bytes);
         blocks.add(new CompressedShuffleBlock(ByteBuffer.wrap(compressed), bytes.length));
       });
     }
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 5f39eb5d..71b4c283 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -20,13 +20,16 @@ package org.apache.spark.shuffle;
 import java.util.Set;
 
 import com.google.common.collect.ImmutableSet;
+import org.apache.spark.SparkConf;
 import org.apache.spark.internal.config.ConfigBuilder;
 import org.apache.spark.internal.config.ConfigEntry;
 import org.apache.spark.internal.config.TypedConfigBuilder;
+import scala.Tuple2;
 import scala.runtime.AbstractFunction1;
 
 import org.apache.uniffle.client.util.RssClientConfig;
 import org.apache.uniffle.common.config.ConfigUtils;
+import org.apache.uniffle.common.config.RssConf;
 
 public class RssSparkConfig {
 
@@ -286,4 +289,17 @@ public class RssSparkConfig {
   public static TypedConfigBuilder<String> createStringBuilder(ConfigBuilder builder) {
     return builder.stringConf();
   }
+
+  public static RssConf toRssConf(SparkConf sparkConf) {
+    RssConf rssConf = new RssConf();
+    for (Tuple2<String, String> tuple : sparkConf.getAll()) {
+      String key = tuple._1;
+      if (!key.startsWith(SPARK_RSS_CONFIG_PREFIX)) {
+        continue;
+      }
+      key = key.substring(SPARK_RSS_CONFIG_PREFIX.length());
+      rssConf.setString(key, tuple._2);
+    }
+    return rssConf;
+  }
 }
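
The Spark side mirrors the MR helper above (illustrative sketch; it assumes
SPARK_RSS_CONFIG_PREFIX is an accessible String constant on this class, which
this hunk does not show):

    SparkConf sparkConf = new SparkConf();
    sparkConf.set(RssSparkConfig.SPARK_RSS_CONFIG_PREFIX + "test.key", "value");
    RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
    // rssConf now maps "test.key" -> "value", with the prefix removed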
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
index 23e03641..7ba3e066 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
@@ -28,6 +28,7 @@ import org.apache.spark.executor.ShuffleReadMetrics;
 import org.apache.spark.serializer.DeserializationStream;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.RssSparkConfig;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import scala.Product2;
@@ -38,8 +39,9 @@ import scala.runtime.BoxedUnit;
 
 import org.apache.uniffle.client.api.ShuffleReadClient;
 import org.apache.uniffle.client.response.CompressedShuffleBlock;
-import org.apache.uniffle.common.RssShuffleUtils;
-import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.client.util.RssClientConfig;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
 
 public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C>> {
 
@@ -57,19 +59,29 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
   private ByteBufInputStream byteBufInputStream = null;
   private long unCompressionLength = 0;
   private ByteBuffer uncompressedData;
+  private Codec codec;
 
   public RssShuffleDataIterator(
       Serializer serializer,
       ShuffleReadClient shuffleReadClient,
-      ShuffleReadMetrics shuffleReadMetrics) {
+      ShuffleReadMetrics shuffleReadMetrics,
+      RssConf rssConf) {
     this.serializerInstance = serializer.newInstance();
     this.shuffleReadClient = shuffleReadClient;
     this.shuffleReadMetrics = shuffleReadMetrics;
+    this.codec = Codec.newInstance(rssConf);
+    // todo: support off-heap bytebuffer
+    this.uncompressedData = ByteBuffer.allocate(
+        (int) rssConf.getSizeAsBytes(
+            RssClientConfig.RSS_WRITER_BUFFER_SIZE,
+            RssSparkConfig.RSS_WRITER_BUFFER_SIZE.defaultValueString()
+        )
+    );
   }
 
-  public Iterator<Tuple2<Object, Object>> createKVIterator(ByteBuffer data) {
+  public Iterator<Tuple2<Object, Object>> createKVIterator(ByteBuffer data, int size) {
     clearDeserializationStream();
-    byteBufInputStream = new ByteBufInputStream(Unpooled.wrappedBuffer(data), true);
+    byteBufInputStream = new ByteBufInputStream(Unpooled.wrappedBuffer(data.array(), 0, size), true);
     deserializationStream = serializerInstance.deserializeStream(byteBufInputStream);
     return deserializationStream.asKeyValueIterator();
   }
@@ -109,24 +121,20 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
       shuffleReadMetrics.incFetchWaitTime(fetchDuration);
       if (compressedData != null) {
         shuffleReadMetrics.incRemoteBytesRead(compressedData.limit() - compressedData.position());
-        // Directbytebuffers are not collected in time will cause executor easy
-        // be killed by cluster managers(such as YARN) for using too much offheap memory
-        if (uncompressedData != null && uncompressedData.isDirect()) {
-          try {
-            RssShuffleUtils.destroyDirectByteBuffer(uncompressedData);
-          } catch (Exception e) {
-            throw new RssException("Destroy DirectByteBuffer failed!", e);
-          }
+
+        int uncompressedLen = compressedBlock.getUncompressLength();
+        if (uncompressedData == null || uncompressedData.capacity() < uncompressedLen) {
+          uncompressedData = ByteBuffer.allocate(uncompressedLen);
         }
+        uncompressedData.clear();
         long startDecompress = System.currentTimeMillis();
-        uncompressedData = RssShuffleUtils.decompressData(
-            compressedData, compressedBlock.getUncompressLength());
+        codec.decompress(compressedData, uncompressedLen, uncompressedData, 0);
         unCompressionLength += compressedBlock.getUncompressLength();
         long decompressDuration = System.currentTimeMillis() - startDecompress;
         decompressTime += decompressDuration;
         // create new iterator for shuffle data
         long startSerialization = System.currentTimeMillis();
-        recordsIterator = createKVIterator(uncompressedData);
+        recordsIterator = createKVIterator(uncompressedData, uncompressedLen);
         long serializationDuration = System.currentTimeMillis() - startSerialization;
         readTime += fetchDuration;
         serializeTime += serializationDuration;
@@ -155,6 +163,7 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
       shuffleReadClient.close();
     }
     shuffleReadClient = null;
+    uncompressedData = null;
     return BoxedUnit.UNIT;
   }
 
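
The hunks above implement the "recycle the buffer" item from the description:
one heap buffer is kept across blocks and grown only when a block's
uncompressed length exceeds the current capacity. The pattern in isolation
(names and loop shape are hypothetical, not the iterator's exact code):

    ByteBuffer reusable = ByteBuffer.allocate(64 * 1024);  // initial size: a guess
    while (hasMoreBlocks()) {                              // hypothetical loop
      int len = nextUncompressedLength();                  // hypothetical accessor
      if (reusable.capacity() < len) {
        reusable = ByteBuffer.allocate(len);               // grow only when needed
      }
      reusable.clear();                                    // then reuse in place
      codec.decompress(nextCompressedBlock(), len, reusable, 0);
    }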
diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index ffb6000b..5c10ac67 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -37,9 +37,10 @@ import org.slf4j.LoggerFactory;
 import scala.reflect.ClassTag$;
 
 import org.apache.uniffle.client.util.ClientUtils;
-import org.apache.uniffle.common.RssShuffleUtils;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.ChecksumUtils;
 
@@ -77,6 +78,7 @@ public class WriteBufferManager extends MemoryConsumer {
   private long uncompressedDataLen = 0;
   private long requireMemoryInterval;
   private int requireMemoryRetryMax;
+  private Codec codec;
 
   public WriteBufferManager(
       int shuffleId,
@@ -85,7 +87,8 @@ public class WriteBufferManager extends MemoryConsumer {
       Serializer serializer,
       Map<Integer, List<ShuffleServerInfo>> partitionToServers,
       TaskMemoryManager taskMemoryManager,
-      ShuffleWriteMetrics shuffleWriteMetrics) {
+      ShuffleWriteMetrics shuffleWriteMetrics,
+      RssConf rssConf) {
     super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP);
     this.bufferSize = bufferManagerOptions.getBufferSize();
     this.spillSize = bufferManagerOptions.getBufferSpillThreshold();
@@ -102,6 +105,7 @@ public class WriteBufferManager extends MemoryConsumer {
     this.requireMemoryRetryMax = bufferManagerOptions.getRequireMemoryRetryMax();
     this.arrayOutputStream = new WrappedByteArrayOutputStream(serializerBufferSize);
     this.serializeStream = instance.serializeStream(arrayOutputStream);
+    this.codec = Codec.newInstance(rssConf);
   }
 
   public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object value) {
@@ -170,7 +174,7 @@ public class WriteBufferManager extends MemoryConsumer {
     byte[] data = wb.getData();
     final int uncompressLength = data.length;
     long start = System.currentTimeMillis();
-    final byte[] compressed = RssShuffleUtils.compressData(data);
+    final byte[] compressed = codec.compress(data);
     final long crc32 = ChecksumUtils.getCrc32(compressed);
     compressTime += System.currentTimeMillis() - start;
     final long blockId = ClientUtils.getBlockId(partitionId, taskAttemptId, getNextSeqNo(partitionId));
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
index 422c6d04..fd290cb4 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/AbstractRssReaderTest.java
@@ -34,8 +34,9 @@ import scala.collection.Iterator;
 import scala.reflect.ClassTag$;
 
 import org.apache.uniffle.client.util.ClientUtils;
-import org.apache.uniffle.common.RssShuffleUtils;
 import org.apache.uniffle.common.ShufflePartitionedBlock;
+import org.apache.uniffle.common.compression.Codec;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.util.ChecksumUtils;
 import org.apache.uniffle.storage.HdfsTestBase;
 import org.apache.uniffle.storage.handler.api.ShuffleWriteHandler;
@@ -90,7 +91,7 @@ public abstract class AbstractRssReaderTest extends HdfsTestBase {
   }
 
   protected ShufflePartitionedBlock createShuffleBlock(byte[] data, long blockId) {
-    byte[] compressData = RssShuffleUtils.compressData(data);
+    byte[] compressData = Codec.newInstance(new RssConf()).compress(data);
     long crc = ChecksumUtils.getCrc32(compressData);
     return new ShufflePartitionedBlock(compressData.length, data.length, crc, blockId, 0,
         compressData);
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
index f4f55c18..78c3375b 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/reader/RssShuffleDataIteratorTest.java
@@ -38,6 +38,7 @@ import org.apache.uniffle.client.api.ShuffleReadClient;
 import org.apache.uniffle.client.impl.ShuffleReadClientImpl;
 import org.apache.uniffle.client.util.ClientUtils;
 import org.apache.uniffle.client.util.DefaultIdHelper;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.util.ChecksumUtils;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.storage.handler.impl.HdfsShuffleWriteHandler;
@@ -96,7 +97,7 @@ public class RssShuffleDataIteratorTest extends AbstractRssReaderTest {
         10, 10000, basePath, blockIdBitmap, taskIdBitmap, Lists.newArrayList(),
         new Configuration(), new DefaultIdHelper());
     return new RssShuffleDataIterator(KRYO_SERIALIZER, readClient,
-        new ShuffleReadMetrics());
+        new ShuffleReadMetrics(), new RssConf());
   }
 
   @Test
@@ -119,7 +120,6 @@ public class RssShuffleDataIteratorTest extends AbstractRssReaderTest {
 
     validateResult(rssShuffleDataIterator, expectedData, 20);
     assertEquals(20, rssShuffleDataIterator.getShuffleReadMetrics().recordsRead());
-    assertEquals(256, rssShuffleDataIterator.getShuffleReadMetrics().remoteBytesRead());
     assertTrue(rssShuffleDataIterator.getShuffleReadMetrics().fetchWaitTime() > 0);
   }
 
@@ -250,7 +250,7 @@ public class RssShuffleDataIteratorTest extends AbstractRssReaderTest {
     ShuffleReadClient mockClient = mock(ShuffleReadClient.class);
     doNothing().when(mockClient).close();
     RssShuffleDataIterator dataIterator =
-        new RssShuffleDataIterator(KRYO_SERIALIZER, mockClient, new ShuffleReadMetrics());
+        new RssShuffleDataIterator(KRYO_SERIALIZER, mockClient, new ShuffleReadMetrics(), new RssConf());
     dataIterator.cleanup();
     verify(mockClient, times(1)).close();
   }
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 665f5d2d..3fc20398 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -29,6 +29,7 @@ import org.apache.spark.shuffle.RssSparkConfig;
 import org.junit.jupiter.api.Test;
 
 import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.config.RssConf;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -47,7 +48,7 @@ public class WriteBufferManagerTest {
     BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
     WriteBufferManager wbm = new WriteBufferManager(
         0, 0, bufferOptions, kryoSerializer,
-        Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics());
+        Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), new RssConf());
     WriteBufferManager spyManager = spy(wbm);
     doReturn(512L).when(spyManager).acquireMemory(anyLong());
     return spyManager;
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 8f076040..26022a54 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -305,9 +305,15 @@ public class RssShuffleManager implements ShuffleManager {
       BufferManagerOptions bufferOptions = new BufferManagerOptions(sparkConf);
       ShuffleWriteMetrics writeMetrics = context.taskMetrics().shuffleWriteMetrics();
       WriteBufferManager bufferManager = new WriteBufferManager(
-          shuffleId, context.taskAttemptId(), bufferOptions, rssHandle.getDependency().serializer(),
-          rssHandle.getPartitionToServers(), context.taskMemoryManager(),
-          writeMetrics);
+          shuffleId,
+          context.taskAttemptId(),
+          bufferOptions,
+          rssHandle.getDependency().serializer(),
+          rssHandle.getPartitionToServers(),
+          context.taskMemoryManager(),
+          writeMetrics,
+          RssSparkConfig.toRssConf(sparkConf)
+      );
       taskToBufferManager.put(taskId, bufferManager);
 
       return new RssShuffleWriter(rssHandle.getAppId(), shuffleId, taskId, context.taskAttemptId(), bufferManager,
@@ -360,7 +366,7 @@ public class RssShuffleManager implements ShuffleManager {
           rssShuffleHandle, shuffleRemoteStoragePath, indexReadLimit,
           readerHadoopConf,
           storageType, (int) readBufferSize, partitionNumPerRange, partitionNum,
-          blockIdBitmap, taskIdBitmap);
+          blockIdBitmap, taskIdBitmap, RssSparkConfig.toRssConf(sparkConf));
     } else {
       throw new RuntimeException("Unexpected ShuffleHandle:" + handle.getClass().getName());
     }
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index ef97bea3..a32ba226 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -44,6 +44,7 @@ import org.apache.uniffle.client.api.ShuffleReadClient;
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
 import org.apache.uniffle.client.request.CreateShuffleReadClientRequest;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
 
 public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
 
@@ -67,6 +68,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
   private Roaring64NavigableMap taskIdBitmap;
   private List<ShuffleServerInfo> shuffleServerInfoList;
   private Configuration hadoopConf;
+  private RssConf rssConf;
 
   public RssShuffleReader(
       int startPartition,
@@ -81,7 +83,8 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
       int partitionNumPerRange,
       int partitionNum,
       Roaring64NavigableMap blockIdBitmap,
-      Roaring64NavigableMap taskIdBitmap) {
+      Roaring64NavigableMap taskIdBitmap,
+      RssConf rssConf) {
     this.appId = rssShuffleHandle.getAppId();
     this.startPartition = startPartition;
     this.endPartition = endPartition;
@@ -101,6 +104,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
     this.hadoopConf = hadoopConf;
     this.shuffleServerInfoList =
         (List<ShuffleServerInfo>) (rssShuffleHandle.getPartitionToServers().get(startPartition));
+    this.rssConf = rssConf;
   }
 
   @Override
@@ -113,7 +117,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
     ShuffleReadClient shuffleReadClient = ShuffleClientFactory.getInstance().createShuffleReadClient(request);
     RssShuffleDataIterator rssShuffleDataIterator = new RssShuffleDataIterator<K, C>(
         shuffleDependency.serializer(), shuffleReadClient,
-        context.taskMetrics().shuffleReadMetrics());
+        context.taskMetrics().shuffleReadMetrics(), rssConf);
     CompletionIterator completionIterator =
         CompletionIterator$.MODULE$.apply(rssShuffleDataIterator, new AbstractFunction0<BoxedUnit>() {
           @Override
diff --git a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index 473ce609..ce33f47c 100644
--- a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++ b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -31,6 +31,7 @@ import org.junit.jupiter.api.Test;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import scala.Option;
 
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.storage.handler.impl.HdfsShuffleWriteHandler;
 import org.apache.uniffle.storage.util.StorageType;
 
@@ -73,7 +74,7 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest {
 
     RssShuffleReader rssShuffleReaderSpy = spy(new RssShuffleReader<String, String>(0, 1, contextMock,
         handleMock, basePath, 1000, conf, StorageType.HDFS.name(),
-        1000, 2, 10, blockIdBitmap, taskIdBitmap));
+        1000, 2, 10, blockIdBitmap, taskIdBitmap, new RssConf()));
 
     validateResult(rssShuffleReaderSpy.read(), expectedData, 10);
   }
diff --git a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 084a731c..f71900ce 100644
--- a/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ b/client-spark/spark2/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -47,6 +47,7 @@ import scala.collection.mutable.MutableList;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.storage.util.StorageType;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -91,7 +92,7 @@ public class RssShuffleWriterTest {
     BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
     WriteBufferManager bufferManager = new WriteBufferManager(
         0, 0, bufferOptions, kryoSerializer,
-        Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics());
+        Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), new RssConf());
     WriteBufferManager bufferManagerSpy = spy(bufferManager);
     doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong());
 
@@ -197,7 +198,7 @@ public class RssShuffleWriterTest {
     BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
     WriteBufferManager bufferManager = new WriteBufferManager(
         0, 0, bufferOptions, kryoSerializer,
-        partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics);
+        partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics, new RssConf());
     WriteBufferManager bufferManagerSpy = spy(bufferManager);
     doReturn(1000000L).when(bufferManagerSpy).acquireMemory(anyLong());
 
@@ -219,12 +220,14 @@ public class RssShuffleWriterTest {
 
     assertTrue(rssShuffleWriterSpy.getShuffleWriteMetrics().shuffleWriteTime() > 0);
     assertEquals(6, rssShuffleWriterSpy.getShuffleWriteMetrics().shuffleRecordsWritten());
-    assertEquals(144, rssShuffleWriterSpy.getShuffleWriteMetrics().shuffleBytesWritten());
+    assertEquals(
+        shuffleBlockInfos.stream().mapToInt(ShuffleBlockInfo::getLength).sum(),
+        rssShuffleWriterSpy.getShuffleWriteMetrics().shuffleBytesWritten()
+    );
 
     assertEquals(6, shuffleBlockInfos.size());
     for (ShuffleBlockInfo shuffleBlockInfo : shuffleBlockInfos) {
       assertEquals(0, shuffleBlockInfo.getShuffleId());
-      assertEquals(24, shuffleBlockInfo.getLength());
       assertEquals(22, shuffleBlockInfo.getUncompressLength());
       if (shuffleBlockInfo.getPartitionId() == 0) {
         assertEquals(shuffleBlockInfo.getShuffleServerInfos(), ssi12);
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 41c2f4d7..ea29a4cd 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -331,7 +331,7 @@ public class RssShuffleManager implements ShuffleManager {
     WriteBufferManager bufferManager = new WriteBufferManager(
         shuffleId, context.taskAttemptId(), bufferOptions, rssHandle.getDependency().serializer(),
         rssHandle.getPartitionToServers(), context.taskMemoryManager(),
-        writeMetrics);
+        writeMetrics, RssSparkConfig.toRssConf(sparkConf));
     taskToBufferManager.put(taskId, bufferManager);
     LOG.info("RssHandle appId {} shuffleId {} ", rssHandle.getAppId(), 
rssHandle.getShuffleId());
     return new RssShuffleWriter(rssHandle.getAppId(), shuffleId, taskId, 
context.taskAttemptId(), bufferManager,
@@ -459,7 +459,8 @@ public class RssShuffleManager implements ShuffleManager {
         partitionNum,
         RssUtils.generatePartitionToBitmap(blockIdBitmap, startPartition, endPartition),
         taskIdBitmap,
-        readMetrics);
+        readMetrics,
+        RssSparkConfig.toRssConf(sparkConf));
   }
 
   private Roaring64NavigableMap getExpectedTasksByExecutorId(
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index a565cfe4..2806ce82 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -50,6 +50,7 @@ import org.apache.uniffle.client.api.ShuffleReadClient;
 import org.apache.uniffle.client.factory.ShuffleClientFactory;
 import org.apache.uniffle.client.request.CreateShuffleReadClientRequest;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
 
 public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
   private static final Logger LOG = LoggerFactory.getLogger(RssShuffleReader.class);
@@ -74,6 +75,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
   private int mapStartIndex;
   private int mapEndIndex;
   private ShuffleReadMetrics readMetrics;
+  private RssConf rssConf;
 
   public RssShuffleReader(
       int startPartition,
@@ -90,7 +92,8 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
       int partitionNum,
       Map<Integer, Roaring64NavigableMap> partitionToExpectBlocks,
       Roaring64NavigableMap taskIdBitmap,
-      ShuffleReadMetrics readMetrics) {
+      ShuffleReadMetrics readMetrics,
+      RssConf rssConf) {
     this.appId = rssShuffleHandle.getAppId();
     this.startPartition = startPartition;
     this.endPartition = endPartition;
@@ -111,6 +114,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
     this.hadoopConf = hadoopConf;
     this.readMetrics = readMetrics;
     this.partitionToShuffleServers = rssShuffleHandle.getPartitionToServers();
+    this.rssConf = rssConf;
   }
 
   @Override
@@ -201,7 +205,7 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
         ShuffleReadClient shuffleReadClient = ShuffleClientFactory.getInstance().createShuffleReadClient(request);
         RssShuffleDataIterator iterator = new RssShuffleDataIterator<K, C>(
             shuffleDependency.serializer(), shuffleReadClient,
-            readMetrics);
+            readMetrics, rssConf);
         CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> completionIterator =
             CompletionIterator$.MODULE$.apply(iterator, () -> iterator.cleanup());
         iterators.add(completionIterator);
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
index 70938c88..5f8eceeb 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/reader/RssShuffleReaderTest.java
@@ -32,6 +32,7 @@ import org.junit.jupiter.api.Test;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import scala.Option;
 
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.storage.handler.impl.HdfsShuffleWriteHandler;
 import org.apache.uniffle.storage.util.StorageType;
 
@@ -93,7 +94,7 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest {
         1,
         partitionToExpectBlocks,
         taskIdBitmap,
-        new ShuffleReadMetrics()));
+        new ShuffleReadMetrics(), new RssConf()));
     validateResult(rssShuffleReaderSpy.read(), expectedData, 10);
 
     writeTestData(writeHandler1, 2, 4, expectedData,
@@ -114,7 +115,8 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest {
         2,
         partitionToExpectBlocks,
         taskIdBitmap,
-        new ShuffleReadMetrics()));
+        new ShuffleReadMetrics(), new RssConf())
+    );
     validateResult(rssShuffleReaderSpy1.read(), expectedData, 18);
 
     RssShuffleReader rssShuffleReaderSpy2 = spy(new RssShuffleReader<String, String>(
@@ -132,7 +134,7 @@ public class RssShuffleReaderTest extends AbstractRssReaderTest {
         2,
         partitionToExpectBlocks,
         Roaring64NavigableMap.bitmapOf(),
-        new ShuffleReadMetrics()));
+        new ShuffleReadMetrics(), new RssConf()));
     validateResult(rssShuffleReaderSpy2.read(), Maps.newHashMap(), 0);
   }
 
diff --git a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
index 1b7afcd9..98ffc8a6 100644
--- a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
+++ b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/RssShuffleWriterTest.java
@@ -49,6 +49,7 @@ import scala.collection.mutable.MutableList;
 import org.apache.uniffle.client.api.ShuffleWriteClient;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
+import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.storage.util.StorageType;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
@@ -98,7 +99,7 @@ public class RssShuffleWriterTest {
     BufferManagerOptions bufferOptions = new BufferManagerOptions(conf);
     WriteBufferManager bufferManager = new WriteBufferManager(
         0, 0, bufferOptions, kryoSerializer,
-        Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics());
+        Maps.newHashMap(), mockTaskMemoryManager, new ShuffleWriteMetrics(), new RssConf());
     WriteBufferManager bufferManagerSpy = spy(bufferManager);
 
     RssShuffleWriter rssShuffleWriter = new RssShuffleWriter("appId", 0, "taskId", 1L,
@@ -206,7 +207,7 @@ public class RssShuffleWriterTest {
     ShuffleWriteMetrics shuffleWriteMetrics = new ShuffleWriteMetrics();
     WriteBufferManager bufferManager = new WriteBufferManager(
         0, 0, bufferOptions, kryoSerializer,
-        partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics);
+        partitionToServers, mockTaskMemoryManager, shuffleWriteMetrics, new RssConf());
     WriteBufferManager bufferManagerSpy = spy(bufferManager);
     RssShuffleWriter rssShuffleWriter = new RssShuffleWriter("appId", 0, "taskId", 1L,
         bufferManagerSpy, shuffleWriteMetrics, manager, conf, mockShuffleWriteClient, mockHandle);
@@ -228,26 +229,14 @@ public class RssShuffleWriterTest {
 
     assertTrue(shuffleWriteMetrics.writeTime() > 0);
     assertEquals(6, shuffleWriteMetrics.recordsWritten());
-    // Spark3 and Spark2 use different version lz4, their length is different
-    // it can happen that 2 different platforms compress the same data differently,
-    // yet the decoded outcome remains identical to original.
-    // https://github.com/lz4/lz4/issues/812
-    if (TestUtils.isMacOnAppleSilicon()) {
-      assertEquals(144, shuffleWriteMetrics.bytesWritten());
-    } else {
-      assertEquals(120, shuffleWriteMetrics.bytesWritten());
-    }
+
+    assertEquals(
+        shuffleBlockInfos.stream().mapToInt(ShuffleBlockInfo::getLength).sum(),
+        shuffleWriteMetrics.bytesWritten()
+    );
 
     assertEquals(6, shuffleBlockInfos.size());
     for (ShuffleBlockInfo shuffleBlockInfo : shuffleBlockInfos) {
-      // it can happen that 2 different platforms compress the same data differently,
-      // yet the decoded outcome remains identical to original.
-      // https://github.com/lz4/lz4/issues/812
-      if (TestUtils.isMacOnAppleSilicon()) {
-        assertEquals(24, shuffleBlockInfo.getLength());
-      } else {
-        assertEquals(20, shuffleBlockInfo.getLength());
-      }
       assertEquals(22, shuffleBlockInfo.getUncompressLength());
       assertEquals(0, shuffleBlockInfo.getShuffleId());
       if (shuffleBlockInfo.getPartitionId() == 0) {
diff --git a/common/pom.xml b/common/pom.xml
index 20c5049b..f043eb9c 100644
--- a/common/pom.xml
+++ b/common/pom.xml
@@ -94,6 +94,11 @@
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-minicluster</artifactId>
     </dependency>
+    <dependency>
+      <groupId>com.github.luben</groupId>
+      <artifactId>zstd-jni</artifactId>
+      <scope>provided</scope>
+    </dependency>
   </dependencies>
 
   <build>
diff --git a/common/src/main/java/org/apache/uniffle/common/RssShuffleUtils.java b/common/src/main/java/org/apache/uniffle/common/RssShuffleUtils.java
index 58db058e..42788aa8 100644
--- a/common/src/main/java/org/apache/uniffle/common/RssShuffleUtils.java
+++ b/common/src/main/java/org/apache/uniffle/common/RssShuffleUtils.java
@@ -22,44 +22,8 @@ import java.lang.reflect.Method;
 import java.nio.ByteBuffer;
 
 import com.google.common.base.Preconditions;
-import net.jpountz.lz4.LZ4Compressor;
-import net.jpountz.lz4.LZ4Factory;
-import net.jpountz.lz4.LZ4FastDecompressor;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 public class RssShuffleUtils {
-
-  private static final Logger LOG = LoggerFactory.getLogger(RssShuffleUtils.class);
-
-  public static byte[] compressData(byte[] data) {
-    LZ4Compressor compressor = LZ4Factory.fastestInstance().fastCompressor();
-    return compressor.compress(data);
-  }
-
-  public static byte[] decompressData(byte[] data, int uncompressLength) {
-    LZ4FastDecompressor fastDecompressor = LZ4Factory.fastestInstance().fastDecompressor();
-    byte[] uncompressData = new byte[uncompressLength];
-    fastDecompressor.decompress(data, 0, uncompressData, 0, uncompressLength);
-    return uncompressData;
-  }
-
-  public static ByteBuffer decompressData(ByteBuffer data, int uncompressLength) {
-    return decompressData(data, uncompressLength, true);
-  }
-
-  public static ByteBuffer decompressData(ByteBuffer data, int uncompressLength, boolean useDirectMem) {
-    LZ4FastDecompressor fastDecompressor = LZ4Factory.fastestInstance().fastDecompressor();
-    ByteBuffer uncompressData;
-    if (useDirectMem) {
-      uncompressData = ByteBuffer.allocateDirect(uncompressLength);
-    } else {
-      uncompressData = ByteBuffer.allocate(uncompressLength);
-    }
-    fastDecompressor.decompress(data, data.position(), uncompressData, 0, uncompressLength);
-    return uncompressData;
-  }
-  
   /**
    * DirectByteBuffers are garbage collected by using a phantom reference and a
   * reference queue. Every once a while, the JVM checks the reference queue and
diff --git a/common/src/main/java/org/apache/uniffle/common/compression/Codec.java b/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
new file mode 100644
index 00000000..9ff7d85d
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/compression/Codec.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.compression;
+
+import java.nio.ByteBuffer;
+
+import org.apache.uniffle.common.config.RssConf;
+
+import static org.apache.uniffle.common.config.RssClientConf.COMPRESSION_TYPE;
+import static org.apache.uniffle.common.config.RssClientConf.ZSTD_COMPRESSION_LEVEL;
+
+public abstract class Codec {
+
+  public static Codec newInstance(RssConf rssConf) {
+    Type type = rssConf.get(COMPRESSION_TYPE);
+    switch (type) {
+      case ZSTD:
+        return new ZstdCodec(rssConf.get(ZSTD_COMPRESSION_LEVEL));
+      case NOOP:
+        return new NoOpCodec();
+      case LZ4:
+      default:
+        return new Lz4Codec();
+    }
+  }
+
+  public abstract void decompress(ByteBuffer src, int uncompressedLen, ByteBuffer dest, int destOffset);
+
+  public abstract byte[] compress(byte[] src);
+
+  public enum Type {
+    LZ4,
+    ZSTD,
+    NOOP,
+  }
+}
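
A hedged example of bypassing the config-driven factory above and picking the
ZSTD codec directly (ZstdCodec(int level) is the constructor added below; the
level value 3 is an arbitrary example, and in normal use the level comes from
ZSTD_COMPRESSION_LEVEL through newInstance):

    Codec zstd = new ZstdCodec(3);
    byte[] compressed = zstd.compress(payload);  // payload: any byte[]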
diff --git a/common/src/main/java/org/apache/uniffle/common/compression/Lz4Codec.java b/common/src/main/java/org/apache/uniffle/common/compression/Lz4Codec.java
new file mode 100644
index 00000000..59b6df6f
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/compression/Lz4Codec.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.compression;
+
+import java.nio.ByteBuffer;
+
+import net.jpountz.lz4.LZ4Factory;
+
+public class Lz4Codec extends Codec {
+
+  private LZ4Factory lz4Factory;
+
+  public Lz4Codec() {
+    this.lz4Factory = LZ4Factory.fastestInstance();
+  }
+
+  @Override
+  public void decompress(ByteBuffer src, int uncompressedLen, ByteBuffer dest, int destOffset) {
+    lz4Factory.fastDecompressor().decompress(src, src.position(), dest, destOffset, uncompressedLen);
+  }
+
+  @Override
+  public byte[] compress(byte[] src) {
+    return lz4Factory.fastCompressor().compress(src);
+  }
+}
diff --git a/common/src/main/java/org/apache/uniffle/common/compression/NoOpCodec.java b/common/src/main/java/org/apache/uniffle/common/compression/NoOpCodec.java
new file mode 100644
index 00000000..99c7cb4e
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/compression/NoOpCodec.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.compression;
+
+import java.nio.ByteBuffer;
+
+public class NoOpCodec extends Codec {
+
+  @Override
+  public void decompress(ByteBuffer src, int uncompressedLen, ByteBuffer dest, int destOffset) {
+    dest.put(src);
+  }
+
+  @Override
+  public byte[] compress(byte[] src) {
+    byte[] dst = new byte[src.length];
+    System.arraycopy(src, 0, dst, 0, src.length);
+    return dst;
+  }
+}
diff --git a/common/src/main/java/org/apache/uniffle/common/compression/ZstdCodec.java b/common/src/main/java/org/apache/uniffle/common/compression/ZstdCodec.java
new file mode 100644
index 00000000..0c596af8
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/compression/ZstdCodec.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.compression;
+
+import java.nio.ByteBuffer;
+
+import com.github.luben.zstd.Zstd;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.exception.RssException;
+
+public class ZstdCodec extends Codec {
+  private static final Logger LOGGER = LoggerFactory.getLogger(ZstdCodec.class);
+
+  private final int compressionLevel;
+
+  public ZstdCodec(int level) {
+    this.compressionLevel = level;
+    LOGGER.info("Initializing zstd compressor.");
+  }
+
+  @Override
+  public void decompress(ByteBuffer src, int uncompressedLen, ByteBuffer dst, int dstOffset) {
+    if (src.isDirect() && dst.isDirect()) {
+      long size = Zstd.decompressDirectByteBuffer(
+          dst, dstOffset, uncompressedLen,
+          src, src.position(), src.limit() - src.position()
+      );
+      if (size != uncompressedLen) {
+        throw new RssException(
+            "This should not happen that the decompressed data size is not 
equals to original size.");
+      }
+      return;
+    }
+
+    if (!src.isDirect() && !dst.isDirect()) {
+      Zstd.decompressByteArray(
+          dst.array(), dstOffset, uncompressedLen,
+          src.array(), src.position(), src.limit() - src.position()
+      );
+      return;
+    }
+
+    throw new IllegalStateException("Zstd decompression requires src and dst ByteBuffers of the same kind (both direct or both heap).");
+  }
+
+  @Override
+  public byte[] compress(byte[] src) {
+    return Zstd.compress(src, compressionLevel);
+  }
+}
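
The zstd-jni API used here has separate direct-buffer and byte-array entry points, which is why ZstdCodec rejects mixed buffer kinds. A sketch of the direct-buffer path (class name and sample data are illustrative):

    import java.nio.ByteBuffer;

    import com.github.luben.zstd.Zstd;

    public class ZstdDirectBufferSketch {
      public static void main(String[] args) {
        byte[] data = {1, 2, 3, 4, 5};
        byte[] compressed = Zstd.compress(data, 3);  // level 3, the default above

        // Both src and dst are direct; a mixed pair would be rejected by
        // ZstdCodec with an IllegalStateException.
        ByteBuffer src = ByteBuffer.allocateDirect(compressed.length);
        src.put(compressed);
        src.flip();
        ByteBuffer dst = ByteBuffer.allocateDirect(data.length);
        Zstd.decompressDirectByteBuffer(dst, 0, data.length, src, src.position(), src.remaining());
      }
    }
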
diff --git a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
new file mode 100644
index 00000000..99d82e03
--- /dev/null
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.config;
+
+import org.apache.uniffle.common.compression.Codec;
+
+import static org.apache.uniffle.common.compression.Codec.Type.LZ4;
+
+public class RssClientConf {
+
+  public static final ConfigOption<Codec.Type> COMPRESSION_TYPE = ConfigOptions
+      .key("rss.client.io.compression.codec")
+      .enumType(Codec.Type.class)
+      .defaultValue(LZ4)
+      .withDescription("The compression codec is used to compress the shuffle 
data. "
+          + "Default codec is `LZ4`, `ZSTD` also can be used.");
+
+  public static final ConfigOption<Integer> ZSTD_COMPRESSION_LEVEL = ConfigOptions
+      .key("rss.client.io.compression.zstd.level")
+      .intType()
+      .defaultValue(3)
+      .withDescription("The zstd compression level; the default level is 3.");
+}
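
Both options can also be set programmatically on an RssConf, mirroring what CompressionTest below does (the level value 6 is only an example; the default is 3):

    import org.apache.uniffle.common.compression.Codec;
    import org.apache.uniffle.common.config.RssClientConf;
    import org.apache.uniffle.common.config.RssConf;

    public class CompressionConfSketch {
      public static void main(String[] args) {
        RssConf conf = new RssConf();
        conf.set(RssClientConf.COMPRESSION_TYPE, Codec.Type.ZSTD);
        conf.set(RssClientConf.ZSTD_COMPRESSION_LEVEL, 6);  // example level; default is 3

        Codec codec = Codec.newInstance(conf);  // yields a ZstdCodec at level 6
      }
    }
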
diff --git a/common/src/test/java/org/apache/uniffle/common/RssShuffleUtilsTest.java b/common/src/test/java/org/apache/uniffle/common/RssShuffleUtilsTest.java
deleted file mode 100644
index f6f00a17..00000000
--- a/common/src/test/java/org/apache/uniffle/common/RssShuffleUtilsTest.java
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.uniffle.common;
-
-import java.lang.reflect.Field;
-import java.nio.Buffer;
-import java.nio.ByteBuffer;
-
-import org.apache.commons.lang3.RandomUtils;
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.params.ParameterizedTest;
-import org.junit.jupiter.params.provider.ValueSource;
-import sun.misc.Unsafe;
-
-import static org.junit.jupiter.api.Assertions.assertArrayEquals;
-import static org.junit.jupiter.api.Assertions.assertFalse;
-
-public class RssShuffleUtilsTest {
-
-  @ParameterizedTest
-  @ValueSource(ints = {1, 1024, 128 * 1024, 512 * 1024, 1024 * 1024, 4 * 1024 * 1024})
-  public void testCompression(int size) {
-    byte[] data = RandomUtils.nextBytes(size);
-    byte[] compressed = RssShuffleUtils.compressData(data);
-    byte[] decompressed = RssShuffleUtils.decompressData(compressed, size);
-    assertArrayEquals(data, decompressed);
-
-    ByteBuffer decompressedBB = RssShuffleUtils.decompressData(ByteBuffer.wrap(compressed), size);
-    byte[] buffer = new byte[size];
-    decompressedBB.get(buffer);
-    assertArrayEquals(data, buffer);
-
-    ByteBuffer decompressedBB2 = RssShuffleUtils.decompressData(ByteBuffer.wrap(compressed), size, false);
-    byte[] buffer2 = new byte[size];
-    decompressedBB2.get(buffer2);
-    assertArrayEquals(data, buffer2);
-  }
-
-  @Test
-  public void testDestroyDirectByteBuffer() throws Exception {
-    int size = 10;
-    byte b = 1;
-    ByteBuffer byteBuffer = ByteBuffer.allocateDirect(size);
-    for (int i = 0; i < size; i++) {
-      byteBuffer.put(b);
-    }
-    byteBuffer.flip();
-
-    // Get valid native pointer through `address` in `DirectByteBuffer`
-    Unsafe unsafe = getUnsafe();
-    long addressInByteBuffer = address(byteBuffer);
-    long originalAddress = unsafe.getAddress(addressInByteBuffer);
-
-    RssShuffleUtils.destroyDirectByteBuffer(byteBuffer);
-
-    // The memory may not be released fast enough.
-    // If native pointer changes, `address` in `DirectByteBuffer` is invalid
-    while (unsafe.getAddress(addressInByteBuffer) == originalAddress) {
-      Thread.sleep(200);
-    }
-    boolean same = true;
-    byte[] read = new byte[size];
-    byteBuffer.get(read);
-    for (byte br : read) {
-      if (b != br) {
-        same = false;
-        break;
-      }
-    }
-    assertFalse(same);
-  }
-
-  private Unsafe getUnsafe() throws NoSuchFieldException, IllegalAccessException {
-    Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe");
-    unsafeField.setAccessible(true);
-    return (Unsafe) unsafeField.get(null);
-  }
-
-  private long address(ByteBuffer buffer) throws NoSuchFieldException, IllegalAccessException {
-    Field addressField = Buffer.class.getDeclaredField("address");
-    addressField.setAccessible(true);
-    return (long) addressField.get(buffer);
-  }
-}
diff --git a/common/src/test/java/org/apache/uniffle/common/compression/CompressionTest.java b/common/src/test/java/org/apache/uniffle/common/compression/CompressionTest.java
new file mode 100644
index 00000000..cb2fdc6f
--- /dev/null
+++ b/common/src/test/java/org/apache/uniffle/common/compression/CompressionTest.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common.compression;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.lang3.RandomUtils;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import org.apache.uniffle.common.config.RssConf;
+
+import static org.apache.uniffle.common.config.RssClientConf.COMPRESSION_TYPE;
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+
+public class CompressionTest {
+
+  static List<Arguments> testCompression() {
+    int[] sizes = {1, 1024, 128 * 1024, 512 * 1024, 1024 * 1024, 4 * 1024 * 1024};
+    Codec.Type[] types = {Codec.Type.ZSTD, Codec.Type.LZ4};
+
+    List<Arguments> arguments = new ArrayList<>();
+    for (int size : sizes) {
+      for (Codec.Type type : types) {
+        arguments.add(
+            Arguments.of(size, type)
+        );
+      }
+    }
+    return arguments;
+  }
+
+  @ParameterizedTest
+  @MethodSource
+  public void testCompression(int size, Codec.Type type) {
+    byte[] data = RandomUtils.nextBytes(size);
+    RssConf conf = new RssConf();
+    conf.set(COMPRESSION_TYPE, type);
+
+    // case1: heap bytebuffer
+    Codec codec = Codec.newInstance(conf);
+    byte[] compressed = codec.compress(data);
+
+    ByteBuffer dest = ByteBuffer.allocate(size);
+    codec.decompress(ByteBuffer.wrap(compressed), size, dest, 0);
+
+    assertArrayEquals(data, dest.array());
+
+    // case2: non-heap bytebuffer
+    ByteBuffer src = ByteBuffer.allocateDirect(compressed.length);
+    src.put(compressed);
+    src.flip();
+    ByteBuffer dst = ByteBuffer.allocateDirect(size);
+    codec.decompress(src, size, dst, 0);
+    byte[] res = new byte[size];
+    dst.get(res);
+    assertArrayEquals(data, res);
+
+    // case3: use the recycled bytebuffer
+    ByteBuffer recycledDst = ByteBuffer.allocate(size + 10);
+    codec.decompress(ByteBuffer.wrap(compressed), size, recycledDst, 0);
+    recycledDst.get(res);
+    assertArrayEquals(data, res);
+  }
+}
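
The third case exercises a reused destination buffer that is larger than the payload. A sketch of that recycling pattern (decompressAll and the per-block length bookkeeping are hypothetical helpers, not part of this change):

    import java.nio.ByteBuffer;
    import java.util.List;

    import org.apache.uniffle.common.compression.Codec;

    public class RecycledBufferSketch {
      // Hypothetical helper: decompress many blocks through one reused buffer.
      static void decompressAll(Codec codec, List<byte[]> blocks,
                                List<Integer> uncompressedLens, int maxLen) {
        ByteBuffer recycled = ByteBuffer.allocate(maxLen);
        for (int i = 0; i < blocks.size(); i++) {
          recycled.clear();  // reset position/limit before each reuse
          codec.decompress(ByteBuffer.wrap(blocks.get(i)),
              uncompressedLens.get(i), recycled, 0);
          // consume recycled contents in [0, uncompressedLens.get(i)) here
        }
      }
    }
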
diff --git a/docs/client_guide.md b/docs/client_guide.md
index 9b5a208a..c945802d 100644
--- a/docs/client_guide.md
+++ b/docs/client_guide.md
@@ -89,6 +89,8 @@ These configurations are shared by all types of clients.
 |<client_type>.rss.client.assignment.tags|-|The comma-separated list of tags for deciding assignment shuffle servers. Notice that the SHUFFLE_SERVER_VERSION will always as the assignment tag whether this conf is set or not|
 |<client_type>.rss.client.data.commit.pool.size|The number of assigned shuffle servers|The thread size for sending commit to shuffle servers|
 |<client_type>.rss.client.assignment.shuffle.nodes.max|-1|The number of required assignment shuffle servers. If it is less than 0 or equals to 0 or greater than the coordinator's config of "rss.coordinator.shuffle.nodes.max", it will use the size of "rss.coordinator.shuffle.nodes.max" default|
+|<client_type>.rss.client.io.compression.codec|lz4|The compression codec used to compress the shuffle data. The default codec is `lz4`; `zstd` can also be used.|
+|<client_type>.rss.client.io.compression.zstd.level|3|The zstd compression level; the default level is 3|
 Notice:
 
 1. `<client_type>` should be `spark` or `mapreduce`
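
For a Spark client these keys are prefixed with `spark.`, matching how RepartitionWithLocalFileRssTest below sets them; for example, to switch to zstd at the default level in spark-defaults.conf (an example selection, not the defaults):

    spark.rss.client.io.compression.codec zstd
    spark.rss.client.io.compression.zstd.level 3
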
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithLocalFileRssTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithLocalFileRssTest.java
index 82586d91..f97f043e 100644
--- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithLocalFileRssTest.java
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/RepartitionWithLocalFileRssTest.java
@@ -18,18 +18,25 @@
 package org.apache.uniffle.test;
 
 import java.io.File;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Map;
+import java.util.concurrent.TimeUnit;
 
 import com.google.common.collect.Maps;
 import com.google.common.io.Files;
+import com.google.common.util.concurrent.Uninterruptibles;
 import org.apache.spark.SparkConf;
 import org.apache.spark.shuffle.RssSparkConfig;
 import org.junit.jupiter.api.BeforeAll;
 
+import org.apache.uniffle.common.compression.Codec;
 import org.apache.uniffle.coordinator.CoordinatorConf;
 import org.apache.uniffle.server.ShuffleServerConf;
 import org.apache.uniffle.storage.util.StorageType;
 
+import static org.apache.uniffle.common.config.RssClientConf.COMPRESSION_TYPE;
+
 public class RepartitionWithLocalFileRssTest extends RepartitionTest {
 
   @BeforeAll
@@ -53,4 +60,35 @@ public class RepartitionWithLocalFileRssTest extends RepartitionTest {
   @Override
   public void updateRssStorage(SparkConf sparkConf) {
   }
+
+  /**
+   * Test different compression types with the localfile rss mode.
+   * @throws Exception if the Spark application fails
+   */
+  @Override
+  public void run() throws Exception {
+    String fileName = generateTestFile();
+    SparkConf sparkConf = createSparkConf();
+    Uninterruptibles.sleepUninterruptibly(2, TimeUnit.SECONDS);
+
+    List<Map> results = new ArrayList<>();
+    Map resultWithoutRss = runSparkApp(sparkConf, fileName);
+    results.add(resultWithoutRss);
+
+    updateSparkConfWithRss(sparkConf);
+    updateSparkConfCustomer(sparkConf);
+    for (Codec.Type type :
+        new Codec.Type[]{
+            Codec.Type.NOOP,
+            Codec.Type.ZSTD,
+            Codec.Type.LZ4}) {
+      sparkConf.set("spark." + COMPRESSION_TYPE.key().toLowerCase(), 
type.name());
+      Map resultWithRss = runSparkApp(sparkConf, fileName);
+      results.add(resultWithRss);
+    }
+
+    for (int i = 1; i < results.size(); i++) {
+      verifyTestResult(results.get(0), results.get(i));
+    }
+  }
 }
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java
index 1ea90007..0cc5a17d 100644
--- a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/SparkIntegrationTestBase.java
@@ -99,7 +99,7 @@ public abstract class SparkIntegrationTestBase extends IntegrationTestBase {
     sparkConf.set(RssSparkConfig.RSS_HEARTBEAT_INTERVAL.key(), "2000");
   }
 
-  private void verifyTestResult(Map expected, Map actual) {
+  protected void verifyTestResult(Map expected, Map actual) {
     assertEquals(expected.size(), actual.size());
     for (Object expectedKey : expected.keySet()) {
       assertEquals(expected.get(expectedKey), actual.get(expectedKey));
diff --git a/pom.xml b/pom.xml
index c18d1be7..61e999c2 100644
--- a/pom.xml
+++ b/pom.xml
@@ -75,6 +75,7 @@
     <spotbugs.version>4.7.0</spotbugs.version>
     <spotbugs-maven-plugin.version>4.7.0.0</spotbugs-maven-plugin.version>
     <system-rules.version>1.19.0</system-rules.version>
+    <zstd-jni.version>1.5.2-3</zstd-jni.version>
     <test.redirectToFile>true</test.redirectToFile>
     <trimStackTrace>false</trimStackTrace>
   </properties>
@@ -600,6 +601,12 @@
         <artifactId>mockito-core</artifactId>
         <version>${mockito.version}</version>
       </dependency>
+
+      <dependency>
+        <groupId>com.github.luben</groupId>
+        <artifactId>zstd-jni</artifactId>
+        <version>${zstd-jni.version}</version>
+      </dependency>
     </dependencies>
   </dependencyManagement>
 
