This is an automated email from the ASF dual-hosted git repository.
nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new ffdaef98c [CELEBORN-2097] Support Zstd Compression in CppClient
ffdaef98c is described below
commit ffdaef98c38b2d1494945f98dc09d7b20d3b29ac
Author: Jray <[email protected]>
AuthorDate: Fri Aug 29 18:58:22 2025 +0800
[CELEBORN-2097] Support Zstd Compression in CppClient
### What changes were proposed in this pull request?
This PR adds support for zstd compression in CppClient.
### Why are the changes needed?
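To let CppClient write zstd-compressed shuffle data to Celeborn. Before this change, selecting the ZSTD codec failed with "Compression codec ZSTD is not supported."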
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
By compilation and unit tests (the new ZstdCompressorTest).
Closes #3454 from Jraaay/feat/cpp_client_zstd_compression.
Authored-by: Jray <[email protected]>
Signed-off-by: SteNicholas <[email protected]>
---
cpp/celeborn/client/CMakeLists.txt | 3 +-
cpp/celeborn/client/compress/Compressor.cpp | 5 +-
cpp/celeborn/client/compress/ZstdCompressor.cpp | 76 ++++++++++++++++++++++
.../compress/{Compressor.cpp => ZstdCompressor.h} | 39 ++++++-----
cpp/celeborn/client/tests/CMakeLists.txt | 3 +-
cpp/celeborn/client/tests/ZstdCompressorTest.cpp | 76 ++++++++++++++++++++++
cpp/celeborn/conf/CelebornConf.cpp | 6 ++
cpp/celeborn/conf/CelebornConf.h | 5 ++
8 files changed, 193 insertions(+), 20 deletions(-)
diff --git a/cpp/celeborn/client/CMakeLists.txt b/cpp/celeborn/client/CMakeLists.txt
index c5534a3a8..2586f6855 100644
--- a/cpp/celeborn/client/CMakeLists.txt
+++ b/cpp/celeborn/client/CMakeLists.txt
@@ -21,7 +21,8 @@ add_library(
compress/Lz4Decompressor.cpp
compress/ZstdDecompressor.cpp
compress/Compressor.cpp
- compress/Lz4Compressor.cpp)
+ compress/Lz4Compressor.cpp
+ compress/ZstdCompressor.cpp)
target_include_directories(client PUBLIC ${CMAKE_BINARY_DIR})
diff --git a/cpp/celeborn/client/compress/Compressor.cpp b/cpp/celeborn/client/compress/Compressor.cpp
index 849b1ff0d..cde6cad5a 100644
--- a/cpp/celeborn/client/compress/Compressor.cpp
+++ b/cpp/celeborn/client/compress/Compressor.cpp
@@ -18,6 +18,7 @@
#include <stdexcept>
#include "celeborn/client/compress/Lz4Compressor.h"
+#include "celeborn/client/compress/ZstdCompressor.h"
#include "celeborn/utils/Exceptions.h"
namespace celeborn {
@@ -31,8 +32,8 @@ std::unique_ptr<Compressor> Compressor::createCompressor(
case protocol::CompressionCodec::LZ4:
return std::make_unique<Lz4Compressor>();
case protocol::CompressionCodec::ZSTD:
- // TODO: impl zstd
- CELEBORN_FAIL("Compression codec ZSTD is not supported.");
+ return std::make_unique<ZstdCompressor>(
+ conf.shuffleCompressionZstdCompressLevel());
default:
CELEBORN_FAIL("Unknown compression codec.");
}
diff --git a/cpp/celeborn/client/compress/ZstdCompressor.cpp b/cpp/celeborn/client/compress/ZstdCompressor.cpp
new file mode 100644
index 000000000..3ce248644
--- /dev/null
+++ b/cpp/celeborn/client/compress/ZstdCompressor.cpp
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <zlib.h>
+#include <zstd.h>
+
+#include "celeborn/client/compress/ZstdCompressor.h"
+
+namespace celeborn {
+namespace client {
+namespace compress {
+
+ZstdCompressor::ZstdCompressor(const int compressionLevel)
+ : compressionLevel_(compressionLevel) {}
+
+size_t ZstdCompressor::compress(
+ const uint8_t* src,
+ const int srcOffset,
+ const int srcLength,
+ uint8_t* dst,
+ const int dstOffset) {
+ const auto srcPtr = src + srcOffset;
+ const auto dstPtr = dst + dstOffset;
+ const auto dstDataPtr = dstPtr + kHeaderLength;
+
+ uLong check = crc32(0L, Z_NULL, 0);
+ check = crc32(check, srcPtr, srcLength);
+
+ std::copy_n(kMagic, kMagicLength, dstPtr);
+
+ size_t compressedLength = ZSTD_compress(
+ dstDataPtr,
+ ZSTD_compressBound(srcLength),
+ srcPtr,
+ srcLength,
+ compressionLevel_);
+
+ int compressionMethod;
+ if (ZSTD_isError(compressedLength) ||
+ compressedLength >= static_cast<size_t>(srcLength)) {
+ compressionMethod = kCompressionMethodRaw;
+ compressedLength = srcLength;
+ std::copy_n(srcPtr, srcLength, dstDataPtr);
+ } else {
+ compressionMethod = kCompressionMethodZstd;
+ }
+
+ dstPtr[kMagicLength] = static_cast<uint8_t>(compressionMethod);
+ writeIntLE(compressedLength, dstPtr, kMagicLength + 1);
+ writeIntLE(srcLength, dstPtr, kMagicLength + 5);
+ writeIntLE(static_cast<int>(check), dstPtr, kMagicLength + 9);
+
+ return kHeaderLength + compressedLength;
+}
+
+size_t ZstdCompressor::getDstCapacity(const int length) {
+ return ZSTD_compressBound(length) + kHeaderLength;
+}
+
+} // namespace compress
+} // namespace client
+} // namespace celeborn
diff --git a/cpp/celeborn/client/compress/Compressor.cpp b/cpp/celeborn/client/compress/ZstdCompressor.h
similarity index 60%
copy from cpp/celeborn/client/compress/Compressor.cpp
copy to cpp/celeborn/client/compress/ZstdCompressor.h
index 849b1ff0d..2fd8aa02b 100644
--- a/cpp/celeborn/client/compress/Compressor.cpp
+++ b/cpp/celeborn/client/compress/ZstdCompressor.h
@@ -15,28 +15,35 @@
* limitations under the License.
*/
-#include <stdexcept>
+#pragma once
-#include "celeborn/client/compress/Lz4Compressor.h"
-#include "celeborn/utils/Exceptions.h"
+#include "celeborn/client/compress/Compressor.h"
+#include "celeborn/client/compress/ZstdTrait.h"
namespace celeborn {
namespace client {
namespace compress {
-std::unique_ptr<Compressor> Compressor::createCompressor(
- const conf::CelebornConf& conf) {
- const auto codec = conf.shuffleCompressionCodec();
- switch (codec) {
- case protocol::CompressionCodec::LZ4:
- return std::make_unique<Lz4Compressor>();
- case protocol::CompressionCodec::ZSTD:
- // TODO: impl zstd
- CELEBORN_FAIL("Compression codec ZSTD is not supported.");
- default:
- CELEBORN_FAIL("Unknown compression codec.");
- }
-}
+class ZstdCompressor final : public Compressor, ZstdTrait {
+ public:
+ explicit ZstdCompressor(int compressionLevel);
+ ~ZstdCompressor() override = default;
+
+ size_t compress(
+ const uint8_t* src,
+ int srcOffset,
+ int srcLength,
+ uint8_t* dst,
+ int dstOffset) override;
+
+ size_t getDstCapacity(int length) override;
+
+ ZstdCompressor(const ZstdCompressor&) = delete;
+ ZstdCompressor& operator=(const ZstdCompressor&) = delete;
+
+ private:
+ const int compressionLevel_;
+};
} // namespace compress
} // namespace client
diff --git a/cpp/celeborn/client/tests/CMakeLists.txt b/cpp/celeborn/client/tests/CMakeLists.txt
index 9ac740474..d8a98e2b6 100644
--- a/cpp/celeborn/client/tests/CMakeLists.txt
+++ b/cpp/celeborn/client/tests/CMakeLists.txt
@@ -18,7 +18,8 @@ add_executable(
WorkerPartitionReaderTest.cpp
Lz4DecompressorTest.cpp
ZstdDecompressorTest.cpp
- Lz4CompressorTest.cpp)
+ Lz4CompressorTest.cpp
+ ZstdCompressorTest.cpp)
add_test(NAME celeborn_client_test COMMAND celeborn_client_test)
diff --git a/cpp/celeborn/client/tests/ZstdCompressorTest.cpp b/cpp/celeborn/client/tests/ZstdCompressorTest.cpp
new file mode 100644
index 000000000..c4cf9ce7c
--- /dev/null
+++ b/cpp/celeborn/client/tests/ZstdCompressorTest.cpp
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+
+#include "celeborn/client/compress/ZstdCompressor.h"
+#include "client/compress/ZstdDecompressor.h"
+
+using namespace celeborn;
+using namespace celeborn::client;
+using namespace celeborn::protocol;
+
+TEST(ZstdCompressorTest, CompressWithZstd) {
+ for (int compressionLevel = -5; compressionLevel <= 22; compressionLevel++) {
+ compress::ZstdCompressor compressor(compressionLevel);
+ const std::string toCompressData =
+ "Helloooooooooooo Celeborn!!!!!!!!!!!!!!";
+
+ const auto maxLength = compressor.getDstCapacity(toCompressData.size());
+ std::vector<uint8_t> compressedData(maxLength);
+ compressor.compress(
+ reinterpret_cast<const uint8_t*>(toCompressData.data()),
+ 0,
+ toCompressData.size(),
+ compressedData.data(),
+ 0);
+
+ compress::ZstdDecompressor decompressor;
+ const auto oriLength = decompressor.getOriginalLen(compressedData.data());
+ std::vector<uint8_t> decompressedData(oriLength + 1);
+ decompressedData[oriLength] = '\0';
+ const bool success = decompressor.decompress(
+ compressedData.data(), decompressedData.data(), 0);
+ EXPECT_TRUE(success);
+ EXPECT_EQ(reinterpret_cast<char*>(decompressedData.data()), toCompressData);
+ }
+}
+
+TEST(ZstdCompressorTest, CompressWithRaw) {
+ for (int compressionLevel = -5; compressionLevel <= 22; compressionLevel++) {
+ compress::ZstdCompressor compressor(compressionLevel);
+ const std::string toCompressData = "Hello Celeborn!";
+
+ const auto maxLength = compressor.getDstCapacity(toCompressData.size());
+ std::vector<uint8_t> compressedData(maxLength);
+ compressor.compress(
+ reinterpret_cast<const uint8_t*>(toCompressData.data()),
+ 0,
+ toCompressData.size(),
+ compressedData.data(),
+ 0);
+
+ compress::ZstdDecompressor decompressor;
+ const auto oriLength = decompressor.getOriginalLen(compressedData.data());
+ std::vector<uint8_t> decompressedData(oriLength + 1);
+ decompressedData[oriLength] = '\0';
+ const bool success = decompressor.decompress(
+ compressedData.data(), decompressedData.data(), 0);
+ EXPECT_TRUE(success);
+ EXPECT_EQ(reinterpret_cast<char*>(decompressedData.data()), toCompressData);
+ }
+}
diff --git a/cpp/celeborn/conf/CelebornConf.cpp b/cpp/celeborn/conf/CelebornConf.cpp
index 4f85c19a0..e21d39032 100644
--- a/cpp/celeborn/conf/CelebornConf.cpp
+++ b/cpp/celeborn/conf/CelebornConf.cpp
@@ -143,6 +143,7 @@ const std::unordered_map<std::string, folly::Optional<std::string>>
STR_PROP(
kShuffleCompressionCodec,
protocol::toString(protocol::CompressionCodec::NONE)),
+ NUM_PROP(kShuffleCompressionZstdCompressLevel, 1),
// NUM_PROP(kNumExample, 50'000),
// BOOL_PROP(kBoolExample, false),
};
@@ -210,5 +211,10 @@ protocol::CompressionCodec CelebornConf::shuffleCompressionCodec() const {
return protocol::toCompressionCodec(
optionalProperty(kShuffleCompressionCodec).value());
}
+
+int CelebornConf::shuffleCompressionZstdCompressLevel() const {
+ return std::stoi(
+ optionalProperty(kShuffleCompressionZstdCompressLevel).value());
+}
} // namespace conf
} // namespace celeborn
diff --git a/cpp/celeborn/conf/CelebornConf.h b/cpp/celeborn/conf/CelebornConf.h
index 783bc96e1..5aa3c6f9e 100644
--- a/cpp/celeborn/conf/CelebornConf.h
+++ b/cpp/celeborn/conf/CelebornConf.h
@@ -64,6 +64,9 @@ class CelebornConf : public BaseConf {
static constexpr std::string_view kShuffleCompressionCodec{
"celeborn.client.shuffle.compression.codec"};
+ static constexpr std::string_view kShuffleCompressionZstdCompressLevel{
+ "celeborn.client.shuffle.compression.zstd.level"};
+
CelebornConf();
CelebornConf(const std::string& filename);
@@ -89,6 +92,8 @@ class CelebornConf : public BaseConf {
int clientFetchMaxReqsInFlight() const;
protocol::CompressionCodec shuffleCompressionCodec() const;
+
+ int shuffleCompressionZstdCompressLevel() const;
};
} // namespace conf
} // namespace celeborn