This is an automated email from the ASF dual-hosted git repository.

ggregory pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-compress.git


The following commit(s) were added to refs/heads/master by this push:
     new ab8316c57 [COMPRESS-648] Add ability to restrict autodetection in CompressorStreamFactory (#433)
ab8316c57 is described below

commit ab8316c57b0a1ce62b6c58c7d459132e6e735be9
Author: Yakov Shafranovich <yako...@users.noreply.github.com>
AuthorDate: Fri Nov 10 15:36:28 2023 -0500

    [COMPRESS-648] Add ability to restrict autodetection in CompressorStreamFactory (#433)
    
    * Changes for COMPRESS-648
    
    * Added comment
    
    * refactored
    
    * Removed line breaks and changed one method to package-private
    
    * refactoring test
    
    * Removed unused import
    
    ---------
    
    Co-authored-by: Yakov Shafranovich <yako...@amazon.com>
---
 .../compressors/CompressorStreamFactory.java       |  62 ++++++++--
 .../compress/compressors/DetectCompressorTest.java | 131 ++++++++++++++++++---
 2 files changed, 164 insertions(+), 29 deletions(-)

diff --git a/src/main/java/org/apache/commons/compress/compressors/CompressorStreamFactory.java b/src/main/java/org/apache/commons/compress/compressors/CompressorStreamFactory.java
index fc7c7fe5e..3f604cfcb 100644
--- a/src/main/java/org/apache/commons/compress/compressors/CompressorStreamFactory.java
+++ b/src/main/java/org/apache/commons/compress/compressors/CompressorStreamFactory.java
@@ -224,10 +224,29 @@ public class CompressorStreamFactory implements CompressorStreamProvider {
      * @since 1.14
      */
     public static String detect(final InputStream inputStream) throws 
CompressorException {
+        final Set<String> defaultCompressorNamesForDetection = 
Sets.newHashSet(BZIP2, GZIP, PACK200, SNAPPY_FRAMED, Z, DEFLATE, XZ, LZMA, 
LZ4_FRAMED, ZSTANDARD);
+        return detect(inputStream, defaultCompressorNamesForDetection);
+    }
+
+    /**
+     * Try to detect the type of compressor stream while limiting the type to 
the provided set of compressor names.
+     *
+     * @param inputStream input stream
+     * @param compressorNames compressor names to limit autodetection
+     * @return type of compressor stream detected
+     * @throws CompressorException if no compressor stream type was detected
+     *                             or if something else went wrong
+     * @throws IllegalArgumentException if stream is null or does not support 
mark
+     */
+    static String detect(final InputStream inputStream, final Set<String> 
compressorNames) throws CompressorException {
         if (inputStream == null) {
             throw new IllegalArgumentException("Stream must not be null.");
         }
 
+        if (compressorNames == null || compressorNames.isEmpty()) {
+            throw new IllegalArgumentException("Compressor names cannot be 
null or empty");
+        }
+
         if (!inputStream.markSupported()) {
             throw new IllegalArgumentException("Mark is not supported.");
         }
@@ -242,43 +261,44 @@ public class CompressorStreamFactory implements CompressorStreamProvider {
             throw new CompressorException("IOException while reading 
signature.", e);
         }
 
-        if (BZip2CompressorInputStream.matches(signature, signatureLength)) {
+        if (compressorNames.contains(BZIP2) && 
BZip2CompressorInputStream.matches(signature, signatureLength)) {
             return BZIP2;
         }
 
-        if (GzipCompressorInputStream.matches(signature, signatureLength)) {
+        if (compressorNames.contains(GZIP) && 
GzipCompressorInputStream.matches(signature, signatureLength)) {
             return GZIP;
         }
 
-        if (Pack200CompressorInputStream.matches(signature, signatureLength)) {
+        if (compressorNames.contains(PACK200) && 
Pack200CompressorInputStream.matches(signature, signatureLength)) {
             return PACK200;
         }
 
-        if (FramedSnappyCompressorInputStream.matches(signature, 
signatureLength)) {
+        if (compressorNames.contains(SNAPPY_FRAMED) &&
+                FramedSnappyCompressorInputStream.matches(signature, 
signatureLength)) {
             return SNAPPY_FRAMED;
         }
 
-        if (ZCompressorInputStream.matches(signature, signatureLength)) {
+        if (compressorNames.contains(Z) && 
ZCompressorInputStream.matches(signature, signatureLength)) {
             return Z;
         }
 
-        if (DeflateCompressorInputStream.matches(signature, signatureLength)) {
+        if (compressorNames.contains(DEFLATE) && 
DeflateCompressorInputStream.matches(signature, signatureLength)) {
             return DEFLATE;
         }
 
-        if (XZUtils.matches(signature, signatureLength)) {
+        if (compressorNames.contains(XZ) && XZUtils.matches(signature, 
signatureLength)) {
             return XZ;
         }
 
-        if (LZMAUtils.matches(signature, signatureLength)) {
+        if (compressorNames.contains(LZMA) && LZMAUtils.matches(signature, 
signatureLength)) {
             return LZMA;
         }
 
-        if (FramedLZ4CompressorInputStream.matches(signature, 
signatureLength)) {
+        if (compressorNames.contains(LZ4_FRAMED) && 
FramedLZ4CompressorInputStream.matches(signature, signatureLength)) {
             return LZ4_FRAMED;
         }
 
-        if (ZstdUtils.matches(signature, signatureLength)) {
+        if (compressorNames.contains(ZSTANDARD) && 
ZstdUtils.matches(signature, signatureLength)) {
             return ZSTANDARD;
         }
 
@@ -502,6 +522,7 @@ public class CompressorStreamFactory implements CompressorStreamProvider {
         this.decompressConcatenated = decompressUntilEOF;
         this.memoryLimitInKb = memoryLimitInKb;
     }
+
     /**
      * Create a compressor input stream from an input stream, auto-detecting 
the
      * compressor type from the first few bytes of the stream. The InputStream
@@ -520,6 +541,27 @@ public class CompressorStreamFactory implements CompressorStreamProvider {
         return createCompressorInputStream(detect(in), in);
     }
 
+    /**
+     * Create a compressor input stream from an input stream, auto-detecting 
the
+     * compressor type from the first few bytes of the stream while limiting 
the detected type
+     * to the provided set of compressor names. The InputStream must support 
marks, like BufferedInputStream.
+     *
+     * @param in
+     *            the input stream
+     * @param compressorNames
+     *            compressor names to limit autodetection
+     * @return the compressor input stream
+     * @throws CompressorException
+     *             if the autodetected compressor is not in the provided set 
of compressor names
+     * @throws IllegalArgumentException
+     *             if the stream is null or does not support mark
+     * @since 1.26
+     */
+    public CompressorInputStream createCompressorInputStream(final InputStream 
in, final Set<String> compressorNames)
+            throws CompressorException {
+        return createCompressorInputStream(detect(in, compressorNames), in);
+    }
+
     /**
      * Creates a compressor input stream from a compressor name and an input
      * stream.
diff --git a/src/test/java/org/apache/commons/compress/compressors/DetectCompressorTest.java b/src/test/java/org/apache/commons/compress/compressors/DetectCompressorTest.java
index 1c7e2ab81..c8970294a 100644
--- a/src/test/java/org/apache/commons/compress/compressors/DetectCompressorTest.java
+++ b/src/test/java/org/apache/commons/compress/compressors/DetectCompressorTest.java
@@ -30,6 +30,10 @@ import java.io.ByteArrayInputStream;
 import java.io.IOException;
 import java.io.InputStream;
 import java.nio.file.Files;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.stream.Stream;
 
 import org.apache.commons.compress.MemoryLimitException;
 import 
org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream;
@@ -41,6 +45,9 @@ import 
org.apache.commons.compress.compressors.zstandard.ZstdCompressorInputStre
 import org.apache.commons.compress.utils.ByteUtils;
 import org.apache.commons.io.input.BrokenInputStream;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
 
 @SuppressWarnings("deprecation") // deliberately tests 
setDecompressConcatenated
 public final class DetectCompressorTest {
@@ -96,26 +103,25 @@ public final class DetectCompressorTest {
     };
 
     @SuppressWarnings("resource") // Caller closes.
-    private CompressorInputStream createStreamFor(final String resource)
-            throws CompressorException, IOException {
+    private CompressorInputStream createStreamFor(final String resource) 
throws CompressorException, IOException {
         return factory.createCompressorInputStream(
                    new BufferedInputStream(Files.newInputStream(
                        getFile(resource).toPath())));
     }
 
     @SuppressWarnings("resource") // Caller closes.
-    private CompressorInputStream createStreamFor(final String resource, final 
CompressorStreamFactory factory)
-            throws CompressorException, IOException {
-        return factory.createCompressorInputStream(
-                   new BufferedInputStream(Files.newInputStream(
-                       getFile(resource).toPath())));
+    private CompressorInputStream createStreamFor(final String resource, final 
Set<String> compressorNames) throws CompressorException, IOException {
+        return factory.createCompressorInputStream(new 
BufferedInputStream(Files.newInputStream(getFile(resource).toPath())), 
compressorNames);
+    }
+
+    @SuppressWarnings("resource") // Caller closes.
+    private CompressorInputStream createStreamFor(final String resource, final 
CompressorStreamFactory factory) throws CompressorException, IOException {
+        return factory.createCompressorInputStream(new 
BufferedInputStream(Files.newInputStream(getFile(resource).toPath())));
     }
 
     private InputStream createStreamFor(final String fileName, final int 
memoryLimitInKb) throws Exception {
-        final CompressorStreamFactory fac = new CompressorStreamFactory(true,
-                memoryLimitInKb);
-        final InputStream is = new BufferedInputStream(
-                Files.newInputStream(getFile(fileName).toPath()));
+        final CompressorStreamFactory fac = new CompressorStreamFactory(true, 
memoryLimitInKb);
+        final InputStream is = new 
BufferedInputStream(Files.newInputStream(getFile(fileName).toPath()));
         try {
             return fac.createCompressorInputStream(is);
         } catch (final CompressorException e) {
@@ -129,9 +135,12 @@ public final class DetectCompressorTest {
     }
 
     private String detect(final String testFileName) throws IOException, 
CompressorException {
-        try (InputStream is = new BufferedInputStream(
-                Files.newInputStream(getFile(testFileName).toPath()))) {
-            return CompressorStreamFactory.detect(is);
+        return detect(testFileName, null);
+    }
+
+    private String detect(final String testFileName, final Set<String> 
compressorNames) throws IOException, CompressorException {
+        try (InputStream is = new 
BufferedInputStream(Files.newInputStream(getFile(testFileName).toPath()))) {
+            return compressorNames != null ? 
CompressorStreamFactory.detect(is, compressorNames) : 
CompressorStreamFactory.detect(is);
         }
     }
 
@@ -154,17 +163,58 @@ public final class DetectCompressorTest {
 
         assertThrows(CompressorException.class, () -> 
CompressorStreamFactory.detect(new BufferedInputStream(new 
ByteArrayInputStream(ByteUtils.EMPTY_BYTE_ARRAY))));
 
-        final IllegalArgumentException e = 
assertThrows(IllegalArgumentException.class, () -> 
CompressorStreamFactory.detect(null),
-                "shouldn't be able to detect null stream");
+        final IllegalArgumentException e = 
assertThrows(IllegalArgumentException.class, () -> 
CompressorStreamFactory.detect(null), "shouldn't be able to detect null 
stream");
         assertEquals("Stream must not be null.", e.getMessage());
 
-        final CompressorException ce = assertThrows(CompressorException.class, 
() -> CompressorStreamFactory.detect(new BufferedInputStream(new 
BrokenInputStream())),
-                "Expected IOException");
+        final CompressorException ce = assertThrows(CompressorException.class, 
() -> CompressorStreamFactory.detect(new BufferedInputStream(new 
BrokenInputStream())), "Expected IOException");
         assertEquals("IOException while reading signature.", ce.getMessage());
     }
 
     @Test
-    public void testDetection() throws Exception {
+    public void testDetectNullOrEmptyCompressorNames() throws Exception {
+        assertThrows(IllegalArgumentException.class, () -> 
CompressorStreamFactory.detect(createStreamFor("bla.txt.bz2"), (Set<String>) 
null));
+        assertThrows(IllegalArgumentException.class, () -> 
CompressorStreamFactory.detect(createStreamFor("bla.tgz"), new HashSet<>()));
+    }
+
+    public static Stream<Arguments> limitedByNameData() {
+        return Stream.of(
+                Arguments.of("bla.txt.bz2", CompressorStreamFactory.BZIP2),
+                Arguments.of("bla.tgz", CompressorStreamFactory.GZIP),
+                Arguments.of("bla.pack", CompressorStreamFactory.PACK200),
+                Arguments.of("bla.tar.xz", CompressorStreamFactory.XZ),
+                Arguments.of("bla.tar.deflatez", 
CompressorStreamFactory.DEFLATE),
+                Arguments.of("bla.tar.lz4", 
CompressorStreamFactory.LZ4_FRAMED),
+                Arguments.of("bla.tar.lzma", CompressorStreamFactory.LZMA),
+                Arguments.of("bla.tar.sz", 
CompressorStreamFactory.SNAPPY_FRAMED),
+                Arguments.of("bla.tar.Z", CompressorStreamFactory.Z),
+                Arguments.of("bla.tar.zst", CompressorStreamFactory.ZSTANDARD)
+        );
+    }
+
+    @ParameterizedTest
+    @MethodSource("limitedByNameData")
+    public void testDetectLimitedByName(final String filename, final String 
compressorName) throws Exception {
+        assertEquals(compressorName, detect(filename, 
Collections.singleton(compressorName)));
+    }
+
+    @Test
+    public void testDetectLimitedByNameNotFound() throws Exception {
+        Set<String> compressorNames = 
Collections.singleton(CompressorStreamFactory.DEFLATE);
+
+        assertThrows(CompressorException.class, () -> detect("bla.txt.bz2", 
compressorNames));
+        assertThrows(CompressorException.class, () -> detect("bla.tgz", 
compressorNames));
+        assertThrows(CompressorException.class, () -> detect("bla.pack", 
compressorNames));
+        assertThrows(CompressorException.class, () -> detect("bla.tar.xz", 
compressorNames));
+        assertThrows(CompressorException.class, () -> 
detect("bla.tar.deflatez", 
Collections.singleton(CompressorStreamFactory.BZIP2)));
+        assertThrows(CompressorException.class, () -> detect("bla.tar.lz4", 
compressorNames));
+        assertThrows(CompressorException.class, () -> detect("bla.tar.lzma", 
compressorNames));
+        assertThrows(CompressorException.class, () -> detect("bla.tar.sz", 
compressorNames));
+        assertThrows(CompressorException.class, () -> detect("bla.tar.Z", 
compressorNames));
+        assertThrows(CompressorException.class, () -> detect("bla.tar.zst", 
compressorNames));
+    }
+
+    @Test
+    public void testCreateWithAutoDetection() throws Exception {
         try (CompressorInputStream bzip2 = createStreamFor("bla.txt.bz2")) {
             assertNotNull(bzip2);
             assertTrue(bzip2 instanceof BZip2CompressorInputStream);
@@ -198,6 +248,49 @@ public final class DetectCompressorTest {
         assertThrows(CompressorException.class, () -> 
factory.createCompressorInputStream(new 
ByteArrayInputStream(ByteUtils.EMPTY_BYTE_ARRAY)));
     }
 
+    @Test
+    public void testCreateLimitedByName() throws Exception {
+        try (CompressorInputStream bzip2 = createStreamFor("bla.txt.bz2", 
Collections.singleton(CompressorStreamFactory.BZIP2))) {
+            assertNotNull(bzip2);
+            assertTrue(bzip2 instanceof BZip2CompressorInputStream);
+        }
+
+        try (CompressorInputStream gzip = createStreamFor("bla.tgz", 
Collections.singleton(CompressorStreamFactory.GZIP))) {
+            assertNotNull(gzip);
+            assertTrue(gzip instanceof GzipCompressorInputStream);
+        }
+
+        try (CompressorInputStream pack200 = createStreamFor("bla.pack", 
Collections.singleton(CompressorStreamFactory.PACK200))) {
+            assertNotNull(pack200);
+            assertTrue(pack200 instanceof Pack200CompressorInputStream);
+        }
+
+        try (CompressorInputStream xz = createStreamFor("bla.tar.xz", 
Collections.singleton(CompressorStreamFactory.XZ))) {
+            assertNotNull(xz);
+            assertTrue(xz instanceof XZCompressorInputStream);
+        }
+
+        try (CompressorInputStream zlib = createStreamFor("bla.tar.deflatez", 
Collections.singleton(CompressorStreamFactory.DEFLATE))) {
+            assertNotNull(zlib);
+            assertTrue(zlib instanceof DeflateCompressorInputStream);
+        }
+
+        try (CompressorInputStream zstd = createStreamFor("bla.tar.zst", 
Collections.singleton(CompressorStreamFactory.ZSTANDARD))) {
+            assertNotNull(zstd);
+            assertTrue(zstd instanceof ZstdCompressorInputStream);
+        }
+    }
+
+    @Test
+    public void testCreateLimitedByNameNotFound() throws Exception {
+        assertThrows(CompressorException.class, () -> 
createStreamFor("bla.txt.bz2", 
Collections.singleton(CompressorStreamFactory.BROTLI)));
+        assertThrows(CompressorException.class, () -> 
createStreamFor("bla.tgz", Collections.singleton(CompressorStreamFactory.Z)));
+        assertThrows(CompressorException.class, () -> 
createStreamFor("bla.pack", 
Collections.singleton(CompressorStreamFactory.SNAPPY_FRAMED)));
+        assertThrows(CompressorException.class, () -> 
createStreamFor("bla.tar.xz", 
Collections.singleton(CompressorStreamFactory.GZIP)));
+        assertThrows(CompressorException.class, () -> 
createStreamFor("bla.tar.deflatez", 
Collections.singleton(CompressorStreamFactory.PACK200)));
+        assertThrows(CompressorException.class, () -> 
createStreamFor("bla.tar.zst", 
Collections.singleton(CompressorStreamFactory.LZ4_FRAMED)));
+    }
+
     @Test
     public void testLZMAMemoryLimit() throws Exception {
         assertThrows(MemoryLimitException.class, () -> 
createStreamFor("COMPRESS-382", 100));

Reply via email to