This is an automated email from the ASF dual-hosted git repository. ggregory pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/commons-compress.git
The following commit(s) were added to refs/heads/master by this push: new ab8316c57 [COMPRESS-648] Add ability to restrict autodetection in CompressorStreamFactory (#433) ab8316c57 is described below commit ab8316c57b0a1ce62b6c58c7d459132e6e735be9 Author: Yakov Shafranovich <yako...@users.noreply.github.com> AuthorDate: Fri Nov 10 15:36:28 2023 -0500 [COMPRESS-648] Add ability to restrict autodetection in CompressorStreamFactory (#433) * Changes for COMPRESS-648 * Added comment * refactored * Removed line breaks and changed one method to package-private * refactoring test * Removed unused import --------- Co-authored-by: Yakov Shafranovich <yako...@amazon.com> --- .../compressors/CompressorStreamFactory.java | 62 ++++++++-- .../compress/compressors/DetectCompressorTest.java | 131 ++++++++++++++++++--- 2 files changed, 164 insertions(+), 29 deletions(-) diff --git a/src/main/java/org/apache/commons/compress/compressors/CompressorStreamFactory.java b/src/main/java/org/apache/commons/compress/compressors/CompressorStreamFactory.java index fc7c7fe5e..3f604cfcb 100644 --- a/src/main/java/org/apache/commons/compress/compressors/CompressorStreamFactory.java +++ b/src/main/java/org/apache/commons/compress/compressors/CompressorStreamFactory.java @@ -224,10 +224,29 @@ public class CompressorStreamFactory implements CompressorStreamProvider { * @since 1.14 */ public static String detect(final InputStream inputStream) throws CompressorException { + final Set<String> defaultCompressorNamesForDetection = Sets.newHashSet(BZIP2, GZIP, PACK200, SNAPPY_FRAMED, Z, DEFLATE, XZ, LZMA, LZ4_FRAMED, ZSTANDARD); + return detect(inputStream, defaultCompressorNamesForDetection); + } + + /** + * Try to detect the type of compressor stream while limiting the type to the provided set of compressor names. + * + * @param inputStream input stream + * @param compressorNames compressor names to limit autodetection + * @return type of compressor stream detected + * @throws CompressorException if no compressor stream type was detected + * or if something else went wrong + * @throws IllegalArgumentException if stream is null or does not support mark + */ + static String detect(final InputStream inputStream, final Set<String> compressorNames) throws CompressorException { if (inputStream == null) { throw new IllegalArgumentException("Stream must not be null."); } + if (compressorNames == null || compressorNames.isEmpty()) { + throw new IllegalArgumentException("Compressor names cannot be null or empty"); + } + if (!inputStream.markSupported()) { throw new IllegalArgumentException("Mark is not supported."); } @@ -242,43 +261,44 @@ public class CompressorStreamFactory implements CompressorStreamProvider { throw new CompressorException("IOException while reading signature.", e); } - if (BZip2CompressorInputStream.matches(signature, signatureLength)) { + if (compressorNames.contains(BZIP2) && BZip2CompressorInputStream.matches(signature, signatureLength)) { return BZIP2; } - if (GzipCompressorInputStream.matches(signature, signatureLength)) { + if (compressorNames.contains(GZIP) && GzipCompressorInputStream.matches(signature, signatureLength)) { return GZIP; } - if (Pack200CompressorInputStream.matches(signature, signatureLength)) { + if (compressorNames.contains(PACK200) && Pack200CompressorInputStream.matches(signature, signatureLength)) { return PACK200; } - if (FramedSnappyCompressorInputStream.matches(signature, signatureLength)) { + if (compressorNames.contains(SNAPPY_FRAMED) && + FramedSnappyCompressorInputStream.matches(signature, signatureLength)) { return SNAPPY_FRAMED; } - if (ZCompressorInputStream.matches(signature, signatureLength)) { + if (compressorNames.contains(Z) && ZCompressorInputStream.matches(signature, signatureLength)) { return Z; } - if (DeflateCompressorInputStream.matches(signature, signatureLength)) { + if (compressorNames.contains(DEFLATE) && DeflateCompressorInputStream.matches(signature, signatureLength)) { return DEFLATE; } - if (XZUtils.matches(signature, signatureLength)) { + if (compressorNames.contains(XZ) && XZUtils.matches(signature, signatureLength)) { return XZ; } - if (LZMAUtils.matches(signature, signatureLength)) { + if (compressorNames.contains(LZMA) && LZMAUtils.matches(signature, signatureLength)) { return LZMA; } - if (FramedLZ4CompressorInputStream.matches(signature, signatureLength)) { + if (compressorNames.contains(LZ4_FRAMED) && FramedLZ4CompressorInputStream.matches(signature, signatureLength)) { return LZ4_FRAMED; } - if (ZstdUtils.matches(signature, signatureLength)) { + if (compressorNames.contains(ZSTANDARD) && ZstdUtils.matches(signature, signatureLength)) { return ZSTANDARD; } @@ -502,6 +522,7 @@ public class CompressorStreamFactory implements CompressorStreamProvider { this.decompressConcatenated = decompressUntilEOF; this.memoryLimitInKb = memoryLimitInKb; } + /** * Create a compressor input stream from an input stream, auto-detecting the * compressor type from the first few bytes of the stream. The InputStream @@ -520,6 +541,27 @@ public class CompressorStreamFactory implements CompressorStreamProvider { return createCompressorInputStream(detect(in), in); } + /** + * Create a compressor input stream from an input stream, auto-detecting the + * compressor type from the first few bytes of the stream while limiting the detected type + * to the provided set of compressor names. The InputStream must support marks, like BufferedInputStream. + * + * @param in + * the input stream + * @param compressorNames + * compressor names to limit autodetection + * @return the compressor input stream + * @throws CompressorException + * if the autodetected compressor is not in the provided set of compressor names + * @throws IllegalArgumentException + * if the stream is null or does not support mark + * @since 1.26 + */ + public CompressorInputStream createCompressorInputStream(final InputStream in, final Set<String> compressorNames) + throws CompressorException { + return createCompressorInputStream(detect(in, compressorNames), in); + } + /** * Creates a compressor input stream from a compressor name and an input * stream. diff --git a/src/test/java/org/apache/commons/compress/compressors/DetectCompressorTest.java b/src/test/java/org/apache/commons/compress/compressors/DetectCompressorTest.java index 1c7e2ab81..c8970294a 100644 --- a/src/test/java/org/apache/commons/compress/compressors/DetectCompressorTest.java +++ b/src/test/java/org/apache/commons/compress/compressors/DetectCompressorTest.java @@ -30,6 +30,10 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.nio.file.Files; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Stream; import org.apache.commons.compress.MemoryLimitException; import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream; @@ -41,6 +45,9 @@ import org.apache.commons.compress.compressors.zstandard.ZstdCompressorInputStre import org.apache.commons.compress.utils.ByteUtils; import org.apache.commons.io.input.BrokenInputStream; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; @SuppressWarnings("deprecation") // deliberately tests setDecompressConcatenated public final class DetectCompressorTest { @@ -96,26 +103,25 @@ public final class DetectCompressorTest { }; @SuppressWarnings("resource") // Caller closes. - private CompressorInputStream createStreamFor(final String resource) - throws CompressorException, IOException { + private CompressorInputStream createStreamFor(final String resource) throws CompressorException, IOException { return factory.createCompressorInputStream( new BufferedInputStream(Files.newInputStream( getFile(resource).toPath()))); } @SuppressWarnings("resource") // Caller closes. - private CompressorInputStream createStreamFor(final String resource, final CompressorStreamFactory factory) - throws CompressorException, IOException { - return factory.createCompressorInputStream( - new BufferedInputStream(Files.newInputStream( - getFile(resource).toPath()))); + private CompressorInputStream createStreamFor(final String resource, final Set<String> compressorNames) throws CompressorException, IOException { + return factory.createCompressorInputStream(new BufferedInputStream(Files.newInputStream(getFile(resource).toPath())), compressorNames); + } + + @SuppressWarnings("resource") // Caller closes. + private CompressorInputStream createStreamFor(final String resource, final CompressorStreamFactory factory) throws CompressorException, IOException { + return factory.createCompressorInputStream(new BufferedInputStream(Files.newInputStream(getFile(resource).toPath()))); } private InputStream createStreamFor(final String fileName, final int memoryLimitInKb) throws Exception { - final CompressorStreamFactory fac = new CompressorStreamFactory(true, - memoryLimitInKb); - final InputStream is = new BufferedInputStream( - Files.newInputStream(getFile(fileName).toPath())); + final CompressorStreamFactory fac = new CompressorStreamFactory(true, memoryLimitInKb); + final InputStream is = new BufferedInputStream(Files.newInputStream(getFile(fileName).toPath())); try { return fac.createCompressorInputStream(is); } catch (final CompressorException e) { @@ -129,9 +135,12 @@ public final class DetectCompressorTest { } private String detect(final String testFileName) throws IOException, CompressorException { - try (InputStream is = new BufferedInputStream( - Files.newInputStream(getFile(testFileName).toPath()))) { - return CompressorStreamFactory.detect(is); + return detect(testFileName, null); + } + + private String detect(final String testFileName, final Set<String> compressorNames) throws IOException, CompressorException { + try (InputStream is = new BufferedInputStream(Files.newInputStream(getFile(testFileName).toPath()))) { + return compressorNames != null ? CompressorStreamFactory.detect(is, compressorNames) : CompressorStreamFactory.detect(is); } } @@ -154,17 +163,58 @@ public final class DetectCompressorTest { assertThrows(CompressorException.class, () -> CompressorStreamFactory.detect(new BufferedInputStream(new ByteArrayInputStream(ByteUtils.EMPTY_BYTE_ARRAY)))); - final IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> CompressorStreamFactory.detect(null), - "shouldn't be able to detect null stream"); + final IllegalArgumentException e = assertThrows(IllegalArgumentException.class, () -> CompressorStreamFactory.detect(null), "shouldn't be able to detect null stream"); assertEquals("Stream must not be null.", e.getMessage()); - final CompressorException ce = assertThrows(CompressorException.class, () -> CompressorStreamFactory.detect(new BufferedInputStream(new BrokenInputStream())), - "Expected IOException"); + final CompressorException ce = assertThrows(CompressorException.class, () -> CompressorStreamFactory.detect(new BufferedInputStream(new BrokenInputStream())), "Expected IOException"); assertEquals("IOException while reading signature.", ce.getMessage()); } @Test - public void testDetection() throws Exception { + public void testDetectNullOrEmptyCompressorNames() throws Exception { + assertThrows(IllegalArgumentException.class, () -> CompressorStreamFactory.detect(createStreamFor("bla.txt.bz2"), (Set<String>) null)); + assertThrows(IllegalArgumentException.class, () -> CompressorStreamFactory.detect(createStreamFor("bla.tgz"), new HashSet<>())); + } + + public static Stream<Arguments> limitedByNameData() { + return Stream.of( + Arguments.of("bla.txt.bz2", CompressorStreamFactory.BZIP2), + Arguments.of("bla.tgz", CompressorStreamFactory.GZIP), + Arguments.of("bla.pack", CompressorStreamFactory.PACK200), + Arguments.of("bla.tar.xz", CompressorStreamFactory.XZ), + Arguments.of("bla.tar.deflatez", CompressorStreamFactory.DEFLATE), + Arguments.of("bla.tar.lz4", CompressorStreamFactory.LZ4_FRAMED), + Arguments.of("bla.tar.lzma", CompressorStreamFactory.LZMA), + Arguments.of("bla.tar.sz", CompressorStreamFactory.SNAPPY_FRAMED), + Arguments.of("bla.tar.Z", CompressorStreamFactory.Z), + Arguments.of("bla.tar.zst", CompressorStreamFactory.ZSTANDARD) + ); + } + + @ParameterizedTest + @MethodSource("limitedByNameData") + public void testDetectLimitedByName(final String filename, final String compressorName) throws Exception { + assertEquals(compressorName, detect(filename, Collections.singleton(compressorName))); + } + + @Test + public void testDetectLimitedByNameNotFound() throws Exception { + Set<String> compressorNames = Collections.singleton(CompressorStreamFactory.DEFLATE); + + assertThrows(CompressorException.class, () -> detect("bla.txt.bz2", compressorNames)); + assertThrows(CompressorException.class, () -> detect("bla.tgz", compressorNames)); + assertThrows(CompressorException.class, () -> detect("bla.pack", compressorNames)); + assertThrows(CompressorException.class, () -> detect("bla.tar.xz", compressorNames)); + assertThrows(CompressorException.class, () -> detect("bla.tar.deflatez", Collections.singleton(CompressorStreamFactory.BZIP2))); + assertThrows(CompressorException.class, () -> detect("bla.tar.lz4", compressorNames)); + assertThrows(CompressorException.class, () -> detect("bla.tar.lzma", compressorNames)); + assertThrows(CompressorException.class, () -> detect("bla.tar.sz", compressorNames)); + assertThrows(CompressorException.class, () -> detect("bla.tar.Z", compressorNames)); + assertThrows(CompressorException.class, () -> detect("bla.tar.zst", compressorNames)); + } + + @Test + public void testCreateWithAutoDetection() throws Exception { try (CompressorInputStream bzip2 = createStreamFor("bla.txt.bz2")) { assertNotNull(bzip2); assertTrue(bzip2 instanceof BZip2CompressorInputStream); @@ -198,6 +248,49 @@ public final class DetectCompressorTest { assertThrows(CompressorException.class, () -> factory.createCompressorInputStream(new ByteArrayInputStream(ByteUtils.EMPTY_BYTE_ARRAY))); } + @Test + public void testCreateLimitedByName() throws Exception { + try (CompressorInputStream bzip2 = createStreamFor("bla.txt.bz2", Collections.singleton(CompressorStreamFactory.BZIP2))) { + assertNotNull(bzip2); + assertTrue(bzip2 instanceof BZip2CompressorInputStream); + } + + try (CompressorInputStream gzip = createStreamFor("bla.tgz", Collections.singleton(CompressorStreamFactory.GZIP))) { + assertNotNull(gzip); + assertTrue(gzip instanceof GzipCompressorInputStream); + } + + try (CompressorInputStream pack200 = createStreamFor("bla.pack", Collections.singleton(CompressorStreamFactory.PACK200))) { + assertNotNull(pack200); + assertTrue(pack200 instanceof Pack200CompressorInputStream); + } + + try (CompressorInputStream xz = createStreamFor("bla.tar.xz", Collections.singleton(CompressorStreamFactory.XZ))) { + assertNotNull(xz); + assertTrue(xz instanceof XZCompressorInputStream); + } + + try (CompressorInputStream zlib = createStreamFor("bla.tar.deflatez", Collections.singleton(CompressorStreamFactory.DEFLATE))) { + assertNotNull(zlib); + assertTrue(zlib instanceof DeflateCompressorInputStream); + } + + try (CompressorInputStream zstd = createStreamFor("bla.tar.zst", Collections.singleton(CompressorStreamFactory.ZSTANDARD))) { + assertNotNull(zstd); + assertTrue(zstd instanceof ZstdCompressorInputStream); + } + } + + @Test + public void testCreateLimitedByNameNotFound() throws Exception { + assertThrows(CompressorException.class, () -> createStreamFor("bla.txt.bz2", Collections.singleton(CompressorStreamFactory.BROTLI))); + assertThrows(CompressorException.class, () -> createStreamFor("bla.tgz", Collections.singleton(CompressorStreamFactory.Z))); + assertThrows(CompressorException.class, () -> createStreamFor("bla.pack", Collections.singleton(CompressorStreamFactory.SNAPPY_FRAMED))); + assertThrows(CompressorException.class, () -> createStreamFor("bla.tar.xz", Collections.singleton(CompressorStreamFactory.GZIP))); + assertThrows(CompressorException.class, () -> createStreamFor("bla.tar.deflatez", Collections.singleton(CompressorStreamFactory.PACK200))); + assertThrows(CompressorException.class, () -> createStreamFor("bla.tar.zst", Collections.singleton(CompressorStreamFactory.LZ4_FRAMED))); + } + @Test public void testLZMAMemoryLimit() throws Exception { assertThrows(MemoryLimitException.class, () -> createStreamFor("COMPRESS-382", 100));