Now that zstd-compressed kernel and initrd support is upstream, it is
preferable to compress everything with a single compressor, so add
support for loading zstd-compressed firmware as well.

Tested on x86-64, sparc64 and mips64.

Signed-off-by: René Rebe <r...@exactcode.de>

diff --git a/drivers/base/firmware_loader/Kconfig b/drivers/base/firmware_loader/Kconfig
index 5b24f3959255..30d440bec257 100644
--- a/drivers/base/firmware_loader/Kconfig
+++ b/drivers/base/firmware_loader/Kconfig
@@ -169,6 +169,16 @@ config FW_LOADER_COMPRESS
          be compressed with either none or crc32 integrity check type (pass
          "-C crc32" option to xz command).
 
+# Optional Zstd decompression for firmware files; selects the paged
+# firmware buffer and the Zstd decompressor.
+config FW_LOADER_COMPRESS_ZSTD
+       bool "Enable Zstd compressed firmware support"
+       select FW_LOADER_PAGED_BUF
+       select ZSTD_DECOMPRESS
+       help
+         This option enables the support for loading Zstd compressed firmware
+         files. The caller of firmware API receives the decompressed file
+         content. The compressed file is loaded as a fallback, only after
+         loading the raw file failed at first.
+
 config FW_CACHE
        bool "Enable firmware caching during suspend"
        depends on PM_SLEEP
diff --git a/drivers/base/firmware_loader/main.c b/drivers/base/firmware_loader/main.c
index 9da0c9d5f538..cd3bd6f9a64b 100644
--- a/drivers/base/firmware_loader/main.c
+++ b/drivers/base/firmware_loader/main.c
@@ -33,7 +33,9 @@
 #include <linux/syscore_ops.h>
 #include <linux/reboot.h>
 #include <linux/security.h>
+#include <linux/decompress/mm.h>
 #include <linux/xz.h>
+#include <linux/zstd.h>
 
 #include <generated/utsrelease.h>
 
@@ -435,6 +437,185 @@ static int fw_decompress_xz(struct device *dev, struct fw_priv *fw_priv,
 }
 #endif /* CONFIG_FW_LOADER_COMPRESS */
 
+/*
+ * Zstd-compressed firmware support
+ */
+#ifdef CONFIG_FW_LOADER_COMPRESS_ZSTD
+/*
+ * Translate a Zstd library return value into a kernel error code.
+ *
+ * @ret: value returned by a ZSTD_*() call
+ *
+ * Returns 0 when @ret is not a Zstd error. Otherwise a warning is logged
+ * and a negative errno is returned: -ENOMEM when the decompressor ran out
+ * of memory, -EINVAL for every other failure (wrong magic bytes, corrupt
+ * or otherwise undecodable data).
+ */
+static int handle_zstd_error(size_t ret)
+{
+       if (!ZSTD_isError(ret))
+               return 0;
+
+       switch (ZSTD_getErrorCode(ret)) {
+       case ZSTD_error_memory_allocation:
+               pr_warn("ZSTD decompressor ran out of memory\n");
+               return -ENOMEM;
+       case ZSTD_error_prefix_unknown:
+               pr_warn("Input is not in the ZSTD format (wrong magic bytes)\n");
+               return -EINVAL;
+       case ZSTD_error_dstSize_tooSmall:
+       case ZSTD_error_corruption_detected:
+       case ZSTD_error_checksum_wrong:
+               pr_warn("ZSTD-compressed data is corrupt\n");
+               return -EINVAL;
+       default:
+               pr_warn("ZSTD-compressed data is probably corrupt\n");
+               return -EINVAL;
+       }
+}
+
+/*
+ * Single-shot decompression into the pre-allocated buffer.
+ *
+ * @dev: device used for log messages
+ * @fw_priv: firmware buffer; output is written to fw_priv->data, bounded by
+ *           fw_priv->allocated_size, and fw_priv->size is set on success
+ * @in_size: number of bytes available at @in_buffer
+ * @in_buffer: Zstd-compressed input
+ *
+ * Returns 0 on success or a non-zero error code on failure.
+ */
+static int fw_decompress_zstd_single(struct device *dev,
+                                     struct fw_priv *fw_priv,
+                                     size_t in_size, const void *in_buffer)
+{
+       const size_t wksp_size = ZSTD_DCtxWorkspaceBound();
+       ZSTD_DCtx *dctx;
+       void *wksp;
+       size_t ret;
+       int err;
+
+       /* Check the allocation itself before handing it to the library. */
+       wksp = large_malloc(wksp_size);
+       if (!wksp) {
+               dev_warn(dev, "Out of memory while allocating ZSTD_DCtx\n");
+               return -ENOMEM;
+       }
+
+       dctx = ZSTD_initDCtx(wksp, wksp_size);
+       if (!dctx) {
+               dev_warn(dev, "Out of memory while allocating ZSTD_DCtx\n");
+               err = -ENOMEM;
+               goto out;
+       }
+
+       /*
+        * Find out how large the frame actually is, there may be junk at
+        * the end of the frame that ZSTD_decompressDCtx() can't handle.
+        */
+       ret = ZSTD_findFrameCompressedSize(in_buffer, in_size);
+       err = handle_zstd_error(ret);
+       if (err)
+               goto out;
+       in_size = ret;
+
+       ret = ZSTD_decompressDCtx(dctx, fw_priv->data, fw_priv->allocated_size,
+                                 in_buffer, in_size);
+       err = handle_zstd_error(ret);
+       if (err)
+               goto out;
+
+       fw_priv->size = ret;
+
+out:
+       large_free(wksp);
+       return err;
+}
+
+/*
+ * Streaming decompression into a paged buffer, which is mapped at the end.
+ *
+ * @dev: device used for log messages
+ * @fw_priv: firmware buffer; it is switched to paged mode, grown by one
+ *           page per decompression step and finally mapped with
+ *           fw_map_paged_buf(). fw_priv->size accumulates the output size.
+ * @in_size: number of bytes available at @in_buffer
+ * @in_buffer: Zstd-compressed input
+ *
+ * Returns 0 on success or a non-zero error code on failure.
+ */
+static int fw_decompress_zstd_pages(struct device *dev,
+                                    struct fw_priv *fw_priv,
+                                    size_t in_size, const void *in_buffer)
+{
+       ZSTD_inBuffer in;
+       ZSTD_outBuffer out;
+       ZSTD_frameParams params;
+       ZSTD_DStream *dstream = NULL;
+       struct page *page;
+       void *wksp = NULL;
+       size_t wksp_size;
+       size_t in_before;
+       size_t ret;
+       int err;
+
+       fw_priv->is_paged_buf = true;
+       fw_priv->size = 0;
+
+       /* All input is handed to the decompressor as a single buffer. */
+       in.src = in_buffer;
+       in.pos = 0;
+       in.size = in_size;
+
+       /*
+        * We need to know the window size to allocate the ZSTD_DStream.
+        * Since we are streaming, the workspace has to hold the sliding
+        * window, so size it from the frame header instead of assuming the
+        * maximum (1 << ZSTD_WINDOWLOG_MAX).
+        */
+       ret = ZSTD_getFrameParams(&params, in.src, in.size);
+       err = handle_zstd_error(ret);
+       if (err)
+               goto out;
+       if (ret != 0) {
+               dev_warn(dev, "ZSTD-compressed data has an incomplete frame header\n");
+               err = -EINVAL;
+               goto out;
+       }
+       if (params.windowSize > (1 << ZSTD_WINDOWLOG_MAX)) {
+               dev_warn(dev, "ZSTD-compressed data has too large a window size\n");
+               err = -EINVAL;
+               goto out;
+       }
+
+       /*
+        * Allocate the ZSTD_DStream now that we know how much memory is
+        * required.
+        */
+       wksp_size = ZSTD_DStreamWorkspaceBound(params.windowSize);
+       wksp = large_malloc(wksp_size);
+       if (wksp)
+               dstream = ZSTD_initDStream(params.windowSize, wksp, wksp_size);
+       if (!dstream) {
+               dev_warn(dev, "Out of memory while allocating ZSTD_DStream\n");
+               err = -ENOMEM;
+               goto out;
+       }
+
+       /*
+        * Decompression loop: add one page, decompress into it, repeat until
+        * the decompressor reports that the frame is complete.
+        */
+       do {
+               /* Allocate the output buffer */
+               if (fw_grow_paged_buf(fw_priv, fw_priv->nr_pages + 1)) {
+                       err = -ENOMEM;
+                       goto out;
+               }
+
+               /* Decompress into the newly allocated page */
+               page = fw_priv->pages[fw_priv->nr_pages - 1];
+               out.dst = kmap(page);
+               out.pos = 0;
+               out.size = PAGE_SIZE;
+               in_before = in.pos;
+
+               /* Returns zero when the frame is complete. */
+               ret = ZSTD_decompressStream(dstream, &out, &in);
+               kunmap(page);
+               err = handle_zstd_error(ret);
+               if (err)
+                       goto out;
+               fw_priv->size += out.pos;
+
+               /*
+                * Exhausted input alone does not mean truncation: the
+                * decompressor may still hold output that did not fit into
+                * the previous page. The data is truncated only when the
+                * frame is unfinished and this step neither consumed input
+                * nor produced output.
+                */
+               if (ret != 0 && in.pos == in_before && out.pos == 0) {
+                       dev_warn(dev, "ZSTD-compressed data is truncated\n");
+                       err = -EINVAL;
+                       goto out;
+               }
+       } while (ret != 0);
+
+       err = fw_map_paged_buf(fw_priv);
+out:
+       if (wksp)
+               large_free(wksp);
+       return err;
+}
+
+/*
+ * Entry point for Zstd firmware decompression: picks the single-shot path
+ * when fw_priv->data is already set and the paged path otherwise.
+ * Returns whatever the chosen helper returns.
+ */
+static int fw_decompress_zstd(struct device *dev, struct fw_priv *fw_priv,
+                              size_t in_size, const void *in_buffer)
+{
+       /* no pre-allocated buffer: decompress page by page */
+       if (!fw_priv->data)
+               return fw_decompress_zstd_pages(dev, fw_priv, in_size,
+                                               in_buffer);
+
+       return fw_decompress_zstd_single(dev, fw_priv, in_size, in_buffer);
+}
+#endif /* CONFIG_FW_LOADER_COMPRESS_ZSTD */
+
 /* direct firmware loading support */
 static char fw_path_para[256];
 static const char * const fw_path[] = {
@@ -773,6 +954,11 @@ _request_firmware(const struct firmware **firmware_p, const char *name,
                ret = fw_get_filesystem_firmware(device, fw->priv, ".xz",
                                                 fw_decompress_xz);
 #endif
+#ifdef CONFIG_FW_LOADER_COMPRESS_ZSTD
+       if (ret == -ENOENT)
+               ret = fw_get_filesystem_firmware(device, fw->priv, ".zst",
+                                                fw_decompress_zstd);
+#endif
 
        if (ret == -ENOENT)
                ret = firmware_fallback_platform(fw->priv, opt_flags);


-- 
  René Rebe, ExactCODE GmbH, Lietzenburger Str. 42, DE-10789 Berlin
  https://exactcode.com | https://t2sde.org | https://rene.rebe.de

Reply via email to