This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new ffdaef98c [CELEBORN-2097] Support Zstd Compression in CppClient
ffdaef98c is described below

commit ffdaef98c38b2d1494945f98dc09d7b20d3b29ac
Author: Jray <[email protected]>
AuthorDate: Fri Aug 29 18:58:22 2025 +0800

    [CELEBORN-2097] Support Zstd Compression in CppClient
    
    ### What changes were proposed in this pull request?
    This PR adds support for zstd compression in CppClient.
    
    ### Why are the changes needed?
    To support writing to Celeborn with CppClient.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    By compilation and UTs.
    
    Closes #3454 from Jraaay/feat/cpp_client_zstd_compression.
    
    Authored-by: Jray <[email protected]>
    Signed-off-by: SteNicholas <[email protected]>
---
 cpp/celeborn/client/CMakeLists.txt                 |  3 +-
 cpp/celeborn/client/compress/Compressor.cpp        |  5 +-
 cpp/celeborn/client/compress/ZstdCompressor.cpp    | 76 ++++++++++++++++++++++
 .../compress/{Compressor.cpp => ZstdCompressor.h}  | 39 ++++++-----
 cpp/celeborn/client/tests/CMakeLists.txt           |  3 +-
 cpp/celeborn/client/tests/ZstdCompressorTest.cpp   | 76 ++++++++++++++++++++++
 cpp/celeborn/conf/CelebornConf.cpp                 |  6 ++
 cpp/celeborn/conf/CelebornConf.h                   |  5 ++
 8 files changed, 193 insertions(+), 20 deletions(-)

diff --git a/cpp/celeborn/client/CMakeLists.txt b/cpp/celeborn/client/CMakeLists.txt
index c5534a3a8..2586f6855 100644
--- a/cpp/celeborn/client/CMakeLists.txt
+++ b/cpp/celeborn/client/CMakeLists.txt
@@ -21,7 +21,8 @@ add_library(
         compress/Lz4Decompressor.cpp
         compress/ZstdDecompressor.cpp
         compress/Compressor.cpp
-        compress/Lz4Compressor.cpp)
+        compress/Lz4Compressor.cpp
+        compress/ZstdCompressor.cpp)
 
 target_include_directories(client PUBLIC ${CMAKE_BINARY_DIR})
 
diff --git a/cpp/celeborn/client/compress/Compressor.cpp b/cpp/celeborn/client/compress/Compressor.cpp
index 849b1ff0d..cde6cad5a 100644
--- a/cpp/celeborn/client/compress/Compressor.cpp
+++ b/cpp/celeborn/client/compress/Compressor.cpp
@@ -18,6 +18,7 @@
 #include <stdexcept>
 
 #include "celeborn/client/compress/Lz4Compressor.h"
+#include "celeborn/client/compress/ZstdCompressor.h"
 #include "celeborn/utils/Exceptions.h"
 
 namespace celeborn {
@@ -31,8 +32,8 @@ std::unique_ptr<Compressor> Compressor::createCompressor(
     case protocol::CompressionCodec::LZ4:
       return std::make_unique<Lz4Compressor>();
     case protocol::CompressionCodec::ZSTD:
-      // TODO: impl zstd
-      CELEBORN_FAIL("Compression codec ZSTD is not supported.");
+      return std::make_unique<ZstdCompressor>(
+          conf.shuffleCompressionZstdCompressLevel());
     default:
       CELEBORN_FAIL("Unknown compression codec.");
   }
diff --git a/cpp/celeborn/client/compress/ZstdCompressor.cpp b/cpp/celeborn/client/compress/ZstdCompressor.cpp
new file mode 100644
index 000000000..3ce248644
--- /dev/null
+++ b/cpp/celeborn/client/compress/ZstdCompressor.cpp
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <zlib.h>
+#include <zstd.h>
+
+#include "celeborn/client/compress/ZstdCompressor.h"
+
+namespace celeborn {
+namespace client {
+namespace compress {
+
+ZstdCompressor::ZstdCompressor(const int compressionLevel) // Stores the level later passed to ZSTD_compress(); it is not validated here.
+    : compressionLevel_(compressionLevel) {}
+
+size_t ZstdCompressor::compress( // Writes a header plus payload at dst + dstOffset and returns the total bytes written.
+    const uint8_t* src,
+    const int srcOffset,
+    const int srcLength,
+    uint8_t* dst, // NOTE(review): presumably sized via getDstCapacity(srcLength) past dstOffset -- confirm at call sites
+    const int dstOffset) {
+  const auto srcPtr = src + srcOffset;
+  const auto dstPtr = dst + dstOffset;
+  const auto dstDataPtr = dstPtr + kHeaderLength; // payload begins kHeaderLength bytes after the header start
+
+  uLong check = crc32(0L, Z_NULL, 0); // zlib idiom: obtain the initial CRC value
+  check = crc32(check, srcPtr, srcLength); // checksum is taken over the uncompressed input
+
+  std::copy_n(kMagic, kMagicLength, dstPtr); // header starts with the magic bytes
+
+  size_t compressedLength = ZSTD_compress(
+      dstDataPtr,
+      ZSTD_compressBound(srcLength), // capacity is assumed, not checked against the real size of dst
+      srcPtr,
+      srcLength,
+      compressionLevel_);
+
+  int compressionMethod;
+  if (ZSTD_isError(compressedLength) || // fall back when zstd reports an error, or ...
+      compressedLength >= static_cast<size_t>(srcLength)) { // ... when the output is not smaller than the input
+    compressionMethod = kCompressionMethodRaw;
+    compressedLength = srcLength;
+    std::copy_n(srcPtr, srcLength, dstDataPtr); // store the input verbatim, overwriting any zstd output
+  } else {
+    compressionMethod = kCompressionMethodZstd;
+  }
+
+  dstPtr[kMagicLength] = static_cast<uint8_t>(compressionMethod); // method byte directly follows the magic
+  writeIntLE(compressedLength, dstPtr, kMagicLength + 1); // payload length; presumably a 4-byte little-endian write -- confirm in writeIntLE
+  writeIntLE(srcLength, dstPtr, kMagicLength + 5); // original (uncompressed) length
+  writeIntLE(static_cast<int>(check), dstPtr, kMagicLength + 9); // CRC32 of the input, narrowed to int
+
+  return kHeaderLength + compressedLength;
+}
+
+size_t ZstdCompressor::getDstCapacity(const int length) { // Destination size to reserve for `length` input bytes.
+  return ZSTD_compressBound(length) + kHeaderLength; // zstd's bound for the payload plus this frame's header
+}
+
+} // namespace compress
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/compress/Compressor.cpp b/cpp/celeborn/client/compress/ZstdCompressor.h
similarity index 60%
copy from cpp/celeborn/client/compress/Compressor.cpp
copy to cpp/celeborn/client/compress/ZstdCompressor.h
index 849b1ff0d..2fd8aa02b 100644
--- a/cpp/celeborn/client/compress/Compressor.cpp
+++ b/cpp/celeborn/client/compress/ZstdCompressor.h
@@ -15,28 +15,35 @@
  * limitations under the License.
  */
 
-#include <stdexcept>
+#pragma once
 
-#include "celeborn/client/compress/Lz4Compressor.h"
-#include "celeborn/utils/Exceptions.h"
+#include "celeborn/client/compress/Compressor.h"
+#include "celeborn/client/compress/ZstdTrait.h"
 
 namespace celeborn {
 namespace client {
 namespace compress {
 
-std::unique_ptr<Compressor> Compressor::createCompressor(
-    const conf::CelebornConf& conf) {
-  const auto codec = conf.shuffleCompressionCodec();
-  switch (codec) {
-    case protocol::CompressionCodec::LZ4:
-      return std::make_unique<Lz4Compressor>();
-    case protocol::CompressionCodec::ZSTD:
-      // TODO: impl zstd
-      CELEBORN_FAIL("Compression codec ZSTD is not supported.");
-    default:
-      CELEBORN_FAIL("Unknown compression codec.");
-  }
-}
+class ZstdCompressor final : public Compressor, ZstdTrait { // Compressor backed by zstd; ZstdTrait is inherited privately (class default).
+ public:
+  explicit ZstdCompressor(int compressionLevel); // level is forwarded to ZSTD_compress() on every compress() call
+  ~ZstdCompressor() override = default;
+
+  size_t compress( // compresses src[srcOffset, srcOffset + srcLength) into dst at dstOffset; returns bytes written
+      const uint8_t* src,
+      int srcOffset,
+      int srcLength,
+      uint8_t* dst,
+      int dstOffset) override;
+
+  size_t getDstCapacity(int length) override; // destination size to reserve for `length` input bytes
+
+  ZstdCompressor(const ZstdCompressor&) = delete; // non-copyable
+  ZstdCompressor& operator=(const ZstdCompressor&) = delete;
+
+ private:
+  const int compressionLevel_; // fixed at construction
+};
 
 } // namespace compress
 } // namespace client
diff --git a/cpp/celeborn/client/tests/CMakeLists.txt b/cpp/celeborn/client/tests/CMakeLists.txt
index 9ac740474..d8a98e2b6 100644
--- a/cpp/celeborn/client/tests/CMakeLists.txt
+++ b/cpp/celeborn/client/tests/CMakeLists.txt
@@ -18,7 +18,8 @@ add_executable(
         WorkerPartitionReaderTest.cpp
         Lz4DecompressorTest.cpp
         ZstdDecompressorTest.cpp
-        Lz4CompressorTest.cpp)
+        Lz4CompressorTest.cpp
+        ZstdCompressorTest.cpp)
 
 add_test(NAME celeborn_client_test COMMAND celeborn_client_test)
 
diff --git a/cpp/celeborn/client/tests/ZstdCompressorTest.cpp b/cpp/celeborn/client/tests/ZstdCompressorTest.cpp
new file mode 100644
index 000000000..c4cf9ce7c
--- /dev/null
+++ b/cpp/celeborn/client/tests/ZstdCompressorTest.cpp
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "celeborn/client/compress/ZstdCompressor.h"
+#include "client/compress/ZstdDecompressor.h"
+
+using namespace celeborn;
+using namespace celeborn::client;
+using namespace celeborn::protocol;
+
+TEST(ZstdCompressorTest, CompressWithZstd) { // Round-trips a repetitive string through compressor and decompressor.
+  for (int compressionLevel = -5; compressionLevel <= 22; compressionLevel++) { // sweep every level from -5 to 22
+    compress::ZstdCompressor compressor(compressionLevel);
+    const std::string toCompressData =
+        "Helloooooooooooo Celeborn!!!!!!!!!!!!!!";
+
+    const auto maxLength = compressor.getDstCapacity(toCompressData.size());
+    std::vector<uint8_t> compressedData(maxLength); // sized to the compressor's own bound
+    compressor.compress(
+        reinterpret_cast<const uint8_t*>(toCompressData.data()),
+        0,
+        toCompressData.size(),
+        compressedData.data(),
+        0);
+
+    compress::ZstdDecompressor decompressor;
+    const auto oriLength = decompressor.getOriginalLen(compressedData.data());
+    std::vector<uint8_t> decompressedData(oriLength + 1); // one extra byte for the terminator below
+    decompressedData[oriLength] = '\0'; // lets the buffer be compared as a C string
+    const bool success = decompressor.decompress(
+        compressedData.data(), decompressedData.data(), 0);
+    EXPECT_TRUE(success);
+    EXPECT_EQ(reinterpret_cast<char*>(decompressedData.data()), toCompressData);
+  }
+}
+
+TEST(ZstdCompressorTest, CompressWithRaw) { // Round-trips a short string; per the test name it targets the raw (stored) path.
+  for (int compressionLevel = -5; compressionLevel <= 22; compressionLevel++) { // sweep every level from -5 to 22
+    compress::ZstdCompressor compressor(compressionLevel);
+    const std::string toCompressData = "Hello Celeborn!";
+
+    const auto maxLength = compressor.getDstCapacity(toCompressData.size());
+    std::vector<uint8_t> compressedData(maxLength); // sized to the compressor's own bound
+    compressor.compress(
+        reinterpret_cast<const uint8_t*>(toCompressData.data()),
+        0,
+        toCompressData.size(),
+        compressedData.data(),
+        0);
+
+    compress::ZstdDecompressor decompressor;
+    const auto oriLength = decompressor.getOriginalLen(compressedData.data());
+    std::vector<uint8_t> decompressedData(oriLength + 1); // one extra byte for the terminator below
+    decompressedData[oriLength] = '\0'; // lets the buffer be compared as a C string
+    const bool success = decompressor.decompress(
+        compressedData.data(), decompressedData.data(), 0);
+    EXPECT_TRUE(success);
+    EXPECT_EQ(reinterpret_cast<char*>(decompressedData.data()), toCompressData);
+  }
+}
diff --git a/cpp/celeborn/conf/CelebornConf.cpp b/cpp/celeborn/conf/CelebornConf.cpp
index 4f85c19a0..e21d39032 100644
--- a/cpp/celeborn/conf/CelebornConf.cpp
+++ b/cpp/celeborn/conf/CelebornConf.cpp
@@ -143,6 +143,7 @@ const std::unordered_map<std::string, folly::Optional<std::string>>
         STR_PROP(
             kShuffleCompressionCodec,
             protocol::toString(protocol::CompressionCodec::NONE)),
+        NUM_PROP(kShuffleCompressionZstdCompressLevel, 1),
         // NUM_PROP(kNumExample, 50'000),
         // BOOL_PROP(kBoolExample, false),
 };
@@ -210,5 +211,10 @@ protocol::CompressionCodec CelebornConf::shuffleCompressionCodec() const {
   return protocol::toCompressionCodec(
       optionalProperty(kShuffleCompressionCodec).value());
 }
+
+int CelebornConf::shuffleCompressionZstdCompressLevel() const { // Returns the value stored under kShuffleCompressionZstdCompressLevel as an int.
+  return std::stoi( // std::stoi throws if the stored text is not a parsable int; no range check is applied here
+      optionalProperty(kShuffleCompressionZstdCompressLevel).value());
+}
 } // namespace conf
 } // namespace celeborn
diff --git a/cpp/celeborn/conf/CelebornConf.h b/cpp/celeborn/conf/CelebornConf.h
index 783bc96e1..5aa3c6f9e 100644
--- a/cpp/celeborn/conf/CelebornConf.h
+++ b/cpp/celeborn/conf/CelebornConf.h
@@ -64,6 +64,9 @@ class CelebornConf : public BaseConf {
   static constexpr std::string_view kShuffleCompressionCodec{
       "celeborn.client.shuffle.compression.codec"};
 
+  static constexpr std::string_view kShuffleCompressionZstdCompressLevel{
+      "celeborn.client.shuffle.compression.zstd.level"};
+
   CelebornConf();
 
   CelebornConf(const std::string& filename);
@@ -89,6 +92,8 @@ class CelebornConf : public BaseConf {
   int clientFetchMaxReqsInFlight() const;
 
   protocol::CompressionCodec shuffleCompressionCodec() const;
+
+  int shuffleCompressionZstdCompressLevel() const;
 };
 } // namespace conf
 } // namespace celeborn

Reply via email to