 - Allocating VLAs on the stack (or using alloca()) for large sizes
   could exceed the stack limit;

 - Isolating these buffers on the heap makes it easier for code
   sanitizers to detect potential bugs.

Signed-off-by: Gao Xiang <[email protected]>
---
 lib/compressor_libzstd.c | 83 +++++++++++++++++++++++++++++-----------
 1 file changed, 60 insertions(+), 23 deletions(-)

diff --git a/lib/compressor_libzstd.c b/lib/compressor_libzstd.c
index e53b1a63..3233d723 100644
--- a/lib/compressor_libzstd.c
+++ b/lib/compressor_libzstd.c
@@ -4,18 +4,24 @@
 #include "erofs/config.h"
 #include <zstd.h>
 #include <zstd_errors.h>
-#include <alloca.h>
+#include <stdlib.h>
 #include "compressor.h"
 #include "erofs/atomic.h"
 
+struct erofs_libzstd_context {
+       ZSTD_CCtx *cctx;                /* zstd compression context */
+       u8 *fitblk_buffer;              /* destsize scratch buffer, grown on demand */
+       unsigned int fitblk_bufsiz;     /* allocated size of fitblk_buffer */
+};
+
 static int libzstd_compress(const struct erofs_compress *c,
                            const void *src, unsigned int srcsize,
                            void *dst, unsigned int dstcapacity)
 {
-       ZSTD_CCtx *cctx = c->private_data;
+       struct erofs_libzstd_context *ctx = c->private_data;
        size_t csize;
 
-       csize = ZSTD_compress2(cctx, dst, dstcapacity, src, srcsize);
+       csize = ZSTD_compress2(ctx->cctx, dst, dstcapacity, src, srcsize);
        if (ZSTD_isError(csize)) {
                if (ZSTD_getErrorCode(csize) == ZSTD_error_dstSize_tooSmall)
                        return -ENOSPC;
@@ -29,12 +35,21 @@ static int libzstd_compress_destsize(const struct erofs_compress *c,
                                     const void *src, unsigned int *srcsize,
                                     void *dst, unsigned int dstsize)
 {
-       ZSTD_CCtx *cctx = c->private_data;
+       struct erofs_libzstd_context *ctx = c->private_data;
        size_t l = 0;           /* largest input that fits so far */
        size_t l_csize = 0;
        size_t r = *srcsize + 1; /* smallest input that doesn't fit so far */
        size_t m;
-       u8 *fitblk_buffer = alloca(dstsize + 32);
+
+       /* reuse the per-context scratch buffer; grow it only when too small */
+       if (dstsize + 32 > ctx->fitblk_bufsiz) {
+               u8 *buf = realloc(ctx->fitblk_buffer, dstsize + 32);
+
+               if (!buf)
+                       return -ENOMEM;
+               ctx->fitblk_bufsiz = dstsize + 32;
+               ctx->fitblk_buffer = buf;
+       }
 
        m = dstsize * 4;
        for (;;) {
@@ -43,7 +58,7 @@ static int libzstd_compress_destsize(const struct erofs_compress *c,
                m = max(m, l + 1);
                m = min(m, r - 1);
 
-               csize = ZSTD_compress2(cctx, fitblk_buffer,
+               csize = ZSTD_compress2(ctx->cctx, ctx->fitblk_buffer,
                                       dstsize + 32, src, m);
                if (ZSTD_isError(csize)) {
                        if (ZSTD_getErrorCode(csize) == ZSTD_error_dstSize_tooSmall)
@@ -53,7 +68,7 @@ static int libzstd_compress_destsize(const struct erofs_compress *c,
 
                if (csize > 0 && csize <= dstsize) {
                        /* Fits */
-                       memcpy(dst, fitblk_buffer, csize);
+                       memcpy(dst, ctx->fitblk_buffer, csize);
                        l = m;
                        l_csize = csize;
                        if (r <= l + 1 || csize + 1 >= dstsize)
@@ -78,9 +92,16 @@ doesnt_fit:
 
 static int compressor_libzstd_exit(struct erofs_compress *c)
 {
-       if (!c->private_data)
+       struct erofs_libzstd_context *ctx = c->private_data;
+
+       if (!ctx)
                return -EINVAL;
-       ZSTD_freeCCtx(c->private_data);
+
+       free(ctx->fitblk_buffer);
+       ZSTD_freeCCtx(ctx->cctx);
+       free(ctx);
+       /* init reuses a non-NULL private_data, so don't leave it dangling */
+       c->private_data = NULL;
        return 0;
 }
 
@@ -118,27 +137,41 @@ static int erofs_compressor_libzstd_setdictsize(struct erofs_compress *c,
 
 static int compressor_libzstd_init(struct erofs_compress *c)
 {
+       struct erofs_libzstd_context *ctx = c->private_data;
        static erofs_atomic_bool_t __warnonce;
-       ZSTD_CCtx *cctx = c->private_data;
-       size_t err;
+       ZSTD_CCtx *cctx;
+       size_t errcode;
+       int err;
 
-       ZSTD_freeCCtx(cctx);
+       if (ctx) {
+               ZSTD_freeCCtx(ctx->cctx);
+               ctx->cctx = NULL;
+               c->private_data = NULL;
+       } else {
+               ctx = calloc(1, sizeof(*ctx));
+               if (!ctx)
+                       return -ENOMEM;
+       }
        cctx = ZSTD_createCCtx();
-       if (!cctx)
-               return -ENOMEM;
+       if (!cctx) {
+               err = -ENOMEM;
+               goto out_err;
+       }
 
-       err = ZSTD_CCtx_setParameter(cctx, ZSTD_c_compressionLevel, c->compression_level);
-       if (ZSTD_isError(err)) {
+       err = -EINVAL;
+       errcode = ZSTD_CCtx_setParameter(cctx, ZSTD_c_compressionLevel, c->compression_level);
+       if (ZSTD_isError(errcode)) {
                erofs_err("failed to set compression level: %s",
-                         ZSTD_getErrorName(err));
-               return -EINVAL;
+                         ZSTD_getErrorName(errcode));
+               goto out_err;
        }
-       err = ZSTD_CCtx_setParameter(cctx, ZSTD_c_windowLog, ilog2(c->dict_size));
-       if (ZSTD_isError(err)) {
-               erofs_err("failed to set window log: %s", ZSTD_getErrorName(err));
-               return -EINVAL;
+       errcode = ZSTD_CCtx_setParameter(cctx, ZSTD_c_windowLog, ilog2(c->dict_size));
+       if (ZSTD_isError(errcode)) {
+               erofs_err("failed to set window log: %s", ZSTD_getErrorName(errcode));
+               goto out_err;
        }
-       c->private_data = cctx;
+       ctx->cctx = cctx;
+       c->private_data = ctx;
 
        if (!erofs_atomic_test_and_set(&__warnonce)) {
                erofs_warn("EXPERIMENTAL libzstd compressor in use. Note that `fitblk` isn't supported by upstream zstd for now.");
@@ -146,6 +179,12 @@ static int compressor_libzstd_init(struct erofs_compress *c)
                erofs_info("You could clarify further needs in zstd repository <https://github.com/facebook/zstd/issues> for reference too.");
        }
        return 0;
+out_err:
+       ZSTD_freeCCtx(cctx);
+       /* a reused context may still own a fitblk buffer from earlier calls */
+       free(ctx->fitblk_buffer);
+       free(ctx);
+       return err;
 }
 
 const struct erofs_compressor erofs_compressor_libzstd = {
-- 
2.43.5


Reply via email to