hhr293 commented on code in PR #12299:
URL: https://github.com/apache/gluten/pull/12299#discussion_r3456674114
##########
cpp/core/utils/tac/ffor.hpp:
##########
@@ -418,62 +452,53 @@ inline size_t compress64(const uint64_t* input, size_t num, uint8_t* output) {
}
// Template-based decompress with alignment dispatch.
+// decodeBlock requires aligned uint64_t* input; when the caller's input
+// buffer is unaligned, we stage through tmpIn and memcpy in.
template <bool InAligned, bool OutAligned>
inline size_t decompress64Impl(const uint8_t* input, size_t inputSize,
uint64_t* output) {
- alignas(64) uint64_t tmpIn[kMaxValuesPerBlock];
+ alignas(64) uint64_t tmpIn[kMaxValuesPerBlock + 2];
alignas(64) uint64_t tmpOut[kMaxValuesPerBlock];
const uint8_t* inPtr = input;
const uint8_t* inEnd = input + inputSize;
size_t nDecoded = 0;
while (inPtr + kHeaderSize <= inEnd) {
- uint8_t bw;
- uint8_t count;
- uint64_t base;
- readHeader(inPtr, bw, count, base);
- inPtr += kHeaderSize;
-
- if (bw == kBwTailMarker) {
- if (count > 0) {
- // memcpy handles any alignment, no special case needed.
- std::memcpy(reinterpret_cast<uint8_t*>(output) + nDecoded *
sizeof(uint64_t), inPtr, count * sizeof(uint64_t));
+ if (inPtr[0] == kBwTailMarker) {
+ const uint8_t count = inPtr[1];
+ inPtr += kHeaderSize;
+ const size_t tailBytes = count * sizeof(uint64_t);
+ if (count > 0 && inPtr + tailBytes <= inEnd) {
+ std::memcpy(
+ reinterpret_cast<uint8_t*>(output) + nDecoded * sizeof(uint64_t),
inPtr, tailBytes);
nDecoded += count;
}
break;
}
-
- size_t blockVals = static_cast<size_t>(count) * kLanes;
- size_t compBytes = compressedWords(blockVals, bw) * sizeof(uint64_t);
-
- if (inPtr + compBytes > inEnd) {
+ const size_t blockVals = static_cast<size_t>(inPtr[1]) * kLanes;
+ if (blockVals == 0 || blockVals > kMaxValuesPerBlock) {
break;
}
+ const size_t remaining = static_cast<size_t>(inEnd - inPtr);
+ uint64_t* decDst = OutAligned ? output + nDecoded : tmpOut;
- // Decode: pick aligned src/dst.
- const uint64_t* decIn;
+ size_t consumed;
if constexpr (InAligned) {
- decIn = reinterpret_cast<const uint64_t*>(inPtr);
+ consumed = decodeBlock(reinterpret_cast<const uint64_t*>(inPtr),
remaining, blockVals, decDst);
} else {
- std::memcpy(tmpIn, inPtr, compBytes);
- decIn = tmpIn;
+ const size_t n = std::min(remaining, sizeof(tmpIn));
+ std::memcpy(tmpIn, inPtr, n);
+ consumed = decodeBlock(tmpIn, n, blockVals, decDst);
}
-
- uint64_t* decOut;
- if constexpr (OutAligned) {
- decOut = output + nDecoded;
- } else {
- decOut = tmpOut;
+ if (consumed == 0) {
+ break;
}
-
- decodeRt(decIn, decOut, base, blockVals, bw);
+ inPtr += consumed;
if constexpr (!OutAligned) {
std::memcpy(
reinterpret_cast<uint8_t*>(output) + nDecoded * sizeof(uint64_t),
tmpOut, blockVals * sizeof(uint64_t));
}
-
- inPtr += compBytes;
nDecoded += blockVals;
}
Review Comment:
fix it in a new commit
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]