This is an automated email from the ASF dual-hosted git repository. gangwu pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/parquet-java.git
The following commit(s) were added to refs/heads/master by this push: new dab5aae88 PARQUET-34: Add #contains FilterPredicate for Array columns (#1328) dab5aae88 is described below commit dab5aae880ff27734351c9bb576fe503a6d6bf4f Author: Claire McGinty <clai...@spotify.com> AuthorDate: Tue Jun 4 10:34:35 2024 -0400 PARQUET-34: Add #contains FilterPredicate for Array columns (#1328) --- .../parquet/filter2/compat/FilterCompat.java | 8 +- ...lInverseRewriter.java => ContainsRewriter.java} | 95 ++++++-- .../parquet/filter2/predicate/FilterApi.java | 5 + .../parquet/filter2/predicate/FilterPredicate.java | 5 + .../filter2/predicate/LogicalInverseRewriter.java | 6 + .../parquet/filter2/predicate/LogicalInverter.java | 6 + .../parquet/filter2/predicate/Operators.java | 139 ++++++++++- .../predicate/SchemaCompatibilityValidator.java | 21 +- .../IncrementallyUpdatedFilterPredicate.java | 2 +- .../column/columnindex/ColumnIndexBuilder.java | 6 + .../internal/column/columnindex/IndexIterator.java | 113 +++++++++ .../filter2/columnindex/ColumnIndexFilter.java | 6 + .../filter2/predicate/TestContainsRewriter.java | 89 ++++++++ .../filter2/predicate/TestFilterApiMethods.java | 28 +++ .../TestSchemaCompatibilityValidator.java | 26 ++- .../column/columnindex/TestColumnIndexBuilder.java | 57 +++++ .../column/columnindex/TestIndexIterator.java | 91 ++++++++ .../filter2/columnindex/TestColumnIndexFilter.java | 70 +++++- ...crementallyUpdatedFilterPredicateGenerator.java | 254 +++++++++++++++++++-- .../filter2/bloomfilterlevel/BloomFilterImpl.java | 5 + .../filter2/dictionarylevel/DictionaryFilter.java | 6 + .../filter2/statisticslevel/StatisticsFilter.java | 6 + .../dictionarylevel/DictionaryFilterTest.java | 56 ++++- .../filter2/recordlevel/PhoneBookWriter.java | 56 ++++- .../recordlevel/TestRecordLevelFilters.java | 107 ++++++++- .../statisticslevel/TestStatisticsFilter.java | 34 +++ .../apache/parquet/hadoop/TestBloomFiltering.java | 57 ++++- .../parquet/hadoop/TestParquetWriterError.java | 9 +- 28 files changed, 1303 insertions(+), 60 deletions(-) diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/compat/FilterCompat.java b/parquet-column/src/main/java/org/apache/parquet/filter2/compat/FilterCompat.java index 76d51af89..7e265bbb0 100644 --- a/parquet-column/src/main/java/org/apache/parquet/filter2/compat/FilterCompat.java +++ b/parquet-column/src/main/java/org/apache/parquet/filter2/compat/FilterCompat.java @@ -22,6 +22,7 @@ import static org.apache.parquet.Preconditions.checkArgument; import java.util.Objects; import org.apache.parquet.filter.UnboundRecordFilter; +import org.apache.parquet.filter2.predicate.ContainsRewriter; import org.apache.parquet.filter2.predicate.FilterPredicate; import org.apache.parquet.filter2.predicate.LogicalInverseRewriter; import org.slf4j.Logger; @@ -82,7 +83,12 @@ public class FilterCompat { LOG.info("Predicate has been collapsed to: {}", collapsedPredicate); } - return new FilterPredicateCompat(collapsedPredicate); + FilterPredicate rewrittenContainsPredicate = ContainsRewriter.rewrite(collapsedPredicate); + if (!collapsedPredicate.equals(rewrittenContainsPredicate)) { + LOG.info("Contains() Predicate has been rewritten to: {}", rewrittenContainsPredicate); + } + + return new FilterPredicateCompat(rewrittenContainsPredicate); } /** diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverseRewriter.java b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/ContainsRewriter.java similarity index 54% copy from parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverseRewriter.java copy to parquet-column/src/main/java/org/apache/parquet/filter2/predicate/ContainsRewriter.java index 175f9b4b7..ea2d70e8e 100644 --- a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverseRewriter.java +++ b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/ContainsRewriter.java @@ -18,12 +18,10 @@ */ package org.apache.parquet.filter2.predicate; -import static org.apache.parquet.filter2.predicate.FilterApi.and; -import static org.apache.parquet.filter2.predicate.FilterApi.or; - import java.util.Objects; import org.apache.parquet.filter2.predicate.FilterPredicate.Visitor; import org.apache.parquet.filter2.predicate.Operators.And; +import org.apache.parquet.filter2.predicate.Operators.Contains; import org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; @@ -38,25 +36,21 @@ import org.apache.parquet.filter2.predicate.Operators.Or; import org.apache.parquet.filter2.predicate.Operators.UserDefined; /** - * Recursively removes all use of the not() operator in a predicate - * by replacing all instances of not(x) with the inverse(x), - * eg: not(and(eq(), not(eq(y))) -> or(notEq(), eq(y)) - * <p> - * The returned predicate should have the same meaning as the original, but - * without the use of the not() operator. - * <p> - * See also {@link LogicalInverter}, which is used - * to do the inversion. + * Recursively rewrites Contains predicates composed using And or Or into a single Contains predicate + * containing all predicate assertions. + * + * This is a performance optimization, as all composed Contains sub-predicates must share the same column, and + * can therefore be applied efficiently as a single predicate pass. */ -public final class LogicalInverseRewriter implements Visitor<FilterPredicate> { - private static final LogicalInverseRewriter INSTANCE = new LogicalInverseRewriter(); +public final class ContainsRewriter implements Visitor<FilterPredicate> { + private static final ContainsRewriter INSTANCE = new ContainsRewriter(); public static FilterPredicate rewrite(FilterPredicate pred) { Objects.requireNonNull(pred, "pred cannot be null"); return pred.accept(INSTANCE); } - private LogicalInverseRewriter() {} + private ContainsRewriter() {} @Override public <T extends Comparable<T>> FilterPredicate visit(Eq<T> eq) { @@ -98,19 +92,84 @@ public final class LogicalInverseRewriter implements Visitor<FilterPredicate> { return notIn; } + @Override + public <T extends Comparable<T>> FilterPredicate visit(Contains<T> contains) { + return contains; + } + @Override public FilterPredicate visit(And and) { - return and(and.getLeft().accept(this), and.getRight().accept(this)); + final FilterPredicate left; + if (and.getLeft() instanceof And) { + left = visit((And) and.getLeft()); + } else if (and.getLeft() instanceof Or) { + left = visit((Or) and.getLeft()); + } else if (and.getLeft() instanceof Contains) { + left = and.getLeft(); + } else { + return and; + } + + final FilterPredicate right; + if (and.getRight() instanceof And) { + right = visit((And) and.getRight()); + } else if (and.getRight() instanceof Or) { + right = visit((Or) and.getRight()); + } else if (and.getRight() instanceof Contains) { + right = and.getRight(); + } else { + return and; + } + + if (left instanceof Contains) { + if (!(right instanceof Contains)) { + throw new UnsupportedOperationException( + "Contains predicates cannot be composed with non-Contains predicates"); + } + return ((Contains) left).and(right); + } else { + return and; + } } @Override public FilterPredicate visit(Or or) { - return or(or.getLeft().accept(this), or.getRight().accept(this)); + final FilterPredicate left; + if (or.getLeft() instanceof And) { + left = visit((And) or.getLeft()); + } else if (or.getLeft() instanceof Or) { + left = visit((Or) or.getLeft()); + } else if (or.getLeft() instanceof Contains) { + left = or.getLeft(); + } else { + return or; + } + + final FilterPredicate right; + if (or.getRight() instanceof And) { + right = visit((And) or.getRight()); + } else if (or.getRight() instanceof Or) { + right = visit((Or) or.getRight()); + } else if (or.getRight() instanceof Contains) { + right = or.getRight(); + } else { + return or; + } + + if (left instanceof Contains) { + if (!(right instanceof Contains)) { + throw new UnsupportedOperationException( + "Contains predicates cannot be composed with non-Contains predicates"); + } + return ((Contains) left).or(right); + } else { + return or; + } } @Override public FilterPredicate visit(Not not) { - return LogicalInverter.invert(not.getPredicate().accept(this)); + throw new IllegalStateException("Not predicate should be rewritten before being evaluated by ContainsRewriter"); } @Override diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterApi.java b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterApi.java index 841d68f2c..4126b73e5 100644 --- a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterApi.java +++ b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterApi.java @@ -24,6 +24,7 @@ import org.apache.parquet.filter2.predicate.Operators.And; import org.apache.parquet.filter2.predicate.Operators.BinaryColumn; import org.apache.parquet.filter2.predicate.Operators.BooleanColumn; import org.apache.parquet.filter2.predicate.Operators.Column; +import org.apache.parquet.filter2.predicate.Operators.Contains; import org.apache.parquet.filter2.predicate.Operators.DoubleColumn; import org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.FloatColumn; @@ -257,6 +258,10 @@ public final class FilterApi { return new NotIn<>(column, values); } + public static <T extends Comparable<T>> Contains<T> contains(Eq<T> pred) { + return Contains.of(pred); + } + /** * Keeps records that pass the provided {@link UserDefinedPredicate} * <p> diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterPredicate.java b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterPredicate.java index 2f4c534d9..a662bb0b1 100644 --- a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterPredicate.java +++ b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/FilterPredicate.java @@ -19,6 +19,7 @@ package org.apache.parquet.filter2.predicate; import org.apache.parquet.filter2.predicate.Operators.And; +import org.apache.parquet.filter2.predicate.Operators.Contains; import org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; @@ -84,6 +85,10 @@ public interface FilterPredicate { throw new UnsupportedOperationException("visit NotIn is not supported."); } + default <T extends Comparable<T>> R visit(Contains<T> contains) { + throw new UnsupportedOperationException("visit Contains is not supported."); + } + R visit(And and); R visit(Or or); diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverseRewriter.java b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverseRewriter.java index 175f9b4b7..d1d7f07e8 100644 --- a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverseRewriter.java +++ b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverseRewriter.java @@ -24,6 +24,7 @@ import static org.apache.parquet.filter2.predicate.FilterApi.or; import java.util.Objects; import org.apache.parquet.filter2.predicate.FilterPredicate.Visitor; import org.apache.parquet.filter2.predicate.Operators.And; +import org.apache.parquet.filter2.predicate.Operators.Contains; import org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; @@ -98,6 +99,11 @@ public final class LogicalInverseRewriter implements Visitor<FilterPredicate> { return notIn; } + @Override + public <T extends Comparable<T>> FilterPredicate visit(Contains<T> contains) { + return contains; + } + @Override public FilterPredicate visit(And and) { return and(and.getLeft().accept(this), and.getRight().accept(this)); diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverter.java b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverter.java index 93ebb34b7..d1d006ccf 100644 --- a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverter.java +++ b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/LogicalInverter.java @@ -21,6 +21,7 @@ package org.apache.parquet.filter2.predicate; import java.util.Objects; import org.apache.parquet.filter2.predicate.FilterPredicate.Visitor; import org.apache.parquet.filter2.predicate.Operators.And; +import org.apache.parquet.filter2.predicate.Operators.Contains; import org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; @@ -92,6 +93,11 @@ public final class LogicalInverter implements Visitor<FilterPredicate> { return new In<>(notIn.getColumn(), notIn.getValues()); } + @Override + public <T extends Comparable<T>> FilterPredicate visit(Contains<T> contains) { + throw new UnsupportedOperationException("Contains not supported yet"); + } + @Override public FilterPredicate visit(And and) { return new Or(and.getLeft().accept(this), and.getRight().accept(this)); diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/Operators.java b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/Operators.java index 1a9ea984f..b86a5ef09 100644 --- a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/Operators.java +++ b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/Operators.java @@ -24,6 +24,7 @@ import java.io.Serializable; import java.util.Locale; import java.util.Objects; import java.util.Set; +import java.util.function.BiFunction; import org.apache.parquet.hadoop.metadata.ColumnPath; import org.apache.parquet.io.api.Binary; @@ -84,6 +85,8 @@ public final class Operators { public static interface SupportsLtGt extends SupportsEqNotEq {} // marker for columns that can be used with lt(), ltEq(), gt(), gtEq() + public static interface SupportsContains {} + public static final class IntColumn extends Column<Integer> implements SupportsLtGt { IntColumn(ColumnPath columnPath) { super(columnPath, Integer.class); @@ -169,7 +172,7 @@ public final class Operators { } } - public static final class Eq<T extends Comparable<T>> extends ColumnFilterPredicate<T> { + public static final class Eq<T extends Comparable<T>> extends ColumnFilterPredicate<T> implements SupportsContains { // value can be null public Eq(Column<T> column, T value) { @@ -315,6 +318,140 @@ public final class Operators { } } + public abstract static class Contains<T extends Comparable<T>> implements FilterPredicate, Serializable { + private final Column<T> column; + + protected Contains(Column<T> column) { + this.column = Objects.requireNonNull(column, "column cannot be null"); + } + + static <ColumnT extends Comparable<ColumnT>, C extends ColumnFilterPredicate<ColumnT> & SupportsContains> + Contains<ColumnT> of(C pred) { + return new ContainsColumnPredicate<>(pred); + } + + public Column<T> getColumn() { + return column; + } + + @Override + public <R> R accept(Visitor<R> visitor) { + return visitor.visit(this); + } + + /** + * Applies a filtering Vistitor to the Contains predicate, traversing any composed And or Or clauses, + * and finally delegating to the underlying ColumnFilterPredicate. + */ + public abstract <R> R filter( + Visitor<R> visitor, BiFunction<R, R, R> andBehavior, BiFunction<R, R, R> orBehavior); + + Contains<T> and(FilterPredicate other) { + return new ContainsComposedPredicate<>(this, (Contains<T>) other, ContainsComposedPredicate.Combinator.AND); + } + + Contains<T> or(FilterPredicate other) { + return new ContainsComposedPredicate<>(this, (Contains<T>) other, ContainsComposedPredicate.Combinator.OR); + } + } + + private static class ContainsComposedPredicate<T extends Comparable<T>> extends Contains<T> { + private final Contains<T> left; + private final Contains<T> right; + + private final Combinator combinator; + + private enum Combinator { + AND, + OR + } + + ContainsComposedPredicate(Contains<T> left, Contains<T> right, Combinator combinator) { + super(Objects.requireNonNull(left, "left predicate cannot be null").getColumn()); + + if (!left.getColumn() + .columnPath + .equals(Objects.requireNonNull(right, "right predicate cannot be null") + .getColumn() + .columnPath)) { + throw new IllegalArgumentException("Composed Contains predicates must reference the same column name; " + + "found [" + left.getColumn().columnPath.toDotString() + ", " + + right.getColumn().columnPath.toDotString() + "]"); + } + + this.left = left; + this.right = right; + this.combinator = combinator; + } + + @Override + public <R> R filter(Visitor<R> visitor, BiFunction<R, R, R> andBehavior, BiFunction<R, R, R> orBehavior) { + final R filterLeft = left.filter(visitor, andBehavior, orBehavior); + final R filterRight = right.filter(visitor, andBehavior, orBehavior); + + if (combinator == Combinator.AND) { + return andBehavior.apply(filterLeft, filterRight); + } else { + return orBehavior.apply(filterLeft, filterRight); + } + } + + @Override + public String toString() { + return combinator.toString().toLowerCase() + "(" + left + ", " + right + ")"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ContainsComposedPredicate<T> that = (ContainsComposedPredicate<T>) o; + return left.equals(that.left) && right.equals(that.right); + } + + @Override + public int hashCode() { + return Objects.hash(getClass().getName(), left, right); + } + } + + private static class ContainsColumnPredicate<T extends Comparable<T>, U extends ColumnFilterPredicate<T>> + extends Contains<T> { + private final U underlying; + + ContainsColumnPredicate(U underlying) { + super(underlying.getColumn()); + if (underlying.getValue() == null) { + throw new IllegalArgumentException("Contains predicate does not support null element value"); + } + this.underlying = underlying; + } + + @Override + public String toString() { + String name = Contains.class.getSimpleName().toLowerCase(Locale.ENGLISH); + return name + "(" + underlying.toString() + ")"; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ContainsColumnPredicate<T, U> that = (ContainsColumnPredicate<T, U>) o; + return underlying.equals(that.underlying); + } + + @Override + public int hashCode() { + return Objects.hash(getClass().getName(), underlying); + } + + @Override + public <R> R filter(Visitor<R> visitor, BiFunction<R, R, R> andBehavior, BiFunction<R, R, R> orBehavior) { + return underlying.accept(visitor); + } + } + public static final class NotIn<T extends Comparable<T>> extends SetColumnFilterPredicate<T> { NotIn(Column<T> column, Set<T> values) { diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/SchemaCompatibilityValidator.java b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/SchemaCompatibilityValidator.java index c8997a9e2..b5708a4a0 100644 --- a/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/SchemaCompatibilityValidator.java +++ b/parquet-column/src/main/java/org/apache/parquet/filter2/predicate/SchemaCompatibilityValidator.java @@ -25,6 +25,7 @@ import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.filter2.predicate.Operators.And; import org.apache.parquet.filter2.predicate.Operators.Column; import org.apache.parquet.filter2.predicate.Operators.ColumnFilterPredicate; +import org.apache.parquet.filter2.predicate.Operators.Contains; import org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; @@ -128,6 +129,12 @@ public class SchemaCompatibilityValidator implements FilterPredicate.Visitor<Voi return null; } + @Override + public <T extends Comparable<T>> Void visit(Contains<T> pred) { + validateColumnFilterPredicate(pred); + return null; + } + @Override public Void visit(And and) { and.getLeft().accept(this); @@ -167,7 +174,15 @@ public class SchemaCompatibilityValidator implements FilterPredicate.Visitor<Voi validateColumn(pred.getColumn()); } + private <T extends Comparable<T>> void validateColumnFilterPredicate(Contains<T> pred) { + validateColumn(pred.getColumn(), true); + } + private <T extends Comparable<T>> void validateColumn(Column<T> column) { + validateColumn(column, false); + } + + private <T extends Comparable<T>> void validateColumn(Column<T> column, boolean shouldBeRepeated) { ColumnPath path = column.getColumnPath(); Class<?> alreadySeen = columnTypesEncountered.get(path); @@ -189,7 +204,11 @@ public class SchemaCompatibilityValidator implements FilterPredicate.Visitor<Voi return; } - if (descriptor.getMaxRepetitionLevel() > 0) { + if (shouldBeRepeated && descriptor.getMaxRepetitionLevel() == 0) { + throw new IllegalArgumentException( + "FilterPredicate for column " + path.toDotString() + " requires a repeated " + + "schema, but found max repetition level " + descriptor.getMaxRepetitionLevel()); + } else if (!shouldBeRepeated && descriptor.getMaxRepetitionLevel() > 0) { throw new IllegalArgumentException("FilterPredicates do not currently support repeated columns. " + "Column " + path.toDotString() + " is repeated."); } diff --git a/parquet-column/src/main/java/org/apache/parquet/filter2/recordlevel/IncrementallyUpdatedFilterPredicate.java b/parquet-column/src/main/java/org/apache/parquet/filter2/recordlevel/IncrementallyUpdatedFilterPredicate.java index f1d7774d9..3c28ba6af 100644 --- a/parquet-column/src/main/java/org/apache/parquet/filter2/recordlevel/IncrementallyUpdatedFilterPredicate.java +++ b/parquet-column/src/main/java/org/apache/parquet/filter2/recordlevel/IncrementallyUpdatedFilterPredicate.java @@ -98,7 +98,7 @@ public interface IncrementallyUpdatedFilterPredicate { /** * Reset to clear state and begin evaluating the next record. */ - public final void reset() { + public void reset() { isKnown = false; result = false; } diff --git a/parquet-column/src/main/java/org/apache/parquet/internal/column/columnindex/ColumnIndexBuilder.java b/parquet-column/src/main/java/org/apache/parquet/internal/column/columnindex/ColumnIndexBuilder.java index bc5c809e0..f4fe80ab9 100644 --- a/parquet-column/src/main/java/org/apache/parquet/internal/column/columnindex/ColumnIndexBuilder.java +++ b/parquet-column/src/main/java/org/apache/parquet/internal/column/columnindex/ColumnIndexBuilder.java @@ -41,6 +41,7 @@ import org.apache.parquet.column.MinMax; import org.apache.parquet.column.statistics.SizeStatistics; import org.apache.parquet.column.statistics.Statistics; import org.apache.parquet.filter2.predicate.Operators.And; +import org.apache.parquet.filter2.predicate.Operators.Contains; import org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; @@ -368,6 +369,11 @@ public abstract class ColumnIndexBuilder { return IndexIterator.all(getPageCount()); } + @Override + public <T extends Comparable<T>> PrimitiveIterator.OfInt visit(Contains<T> contains) { + return contains.filter(this, IndexIterator::intersection, IndexIterator::union); + } + @Override public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> PrimitiveIterator.OfInt visit( UserDefined<T, U> udp) { diff --git a/parquet-column/src/main/java/org/apache/parquet/internal/column/columnindex/IndexIterator.java b/parquet-column/src/main/java/org/apache/parquet/internal/column/columnindex/IndexIterator.java index bb0bdc849..86f93c4e9 100644 --- a/parquet-column/src/main/java/org/apache/parquet/internal/column/columnindex/IndexIterator.java +++ b/parquet-column/src/main/java/org/apache/parquet/internal/column/columnindex/IndexIterator.java @@ -55,6 +55,119 @@ public class IndexIterator implements PrimitiveIterator.OfInt { return new IndexIterator(from, to + 1, i -> true, translator); } + static PrimitiveIterator.OfInt intersection(PrimitiveIterator.OfInt lhs, PrimitiveIterator.OfInt rhs) { + return new PrimitiveIterator.OfInt() { + private int next = fetchNext(); + + @Override + public int nextInt() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + int result = next; + next = fetchNext(); + return result; + } + + @Override + public boolean hasNext() { + return next != -1; + } + + private int fetchNext() { + if (!lhs.hasNext() || !rhs.hasNext()) { + return -1; + } + + // Since we know both iterators are in sorted order, we can iterate linearly through until + // we find the next value that belongs to both iterators, or terminate if none exist + int nextL = lhs.next(); + int nextR = rhs.next(); + + while (true) { + // Try to iterate LHS and RHS to the next intersecting value + while (nextL < nextR && lhs.hasNext()) { + nextL = lhs.next(); + } + while (nextR < nextL && rhs.hasNext()) { + nextR = rhs.next(); + } + if (nextL == nextR) { + return nextL; + } + + // No intersection found; advance LHS to the next element and retry loop + if (nextL < nextR && lhs.hasNext()) { + nextL = lhs.next(); + } else { + break; + } + } + + return -1; + } + }; + } + + static PrimitiveIterator.OfInt union(PrimitiveIterator.OfInt lhs, PrimitiveIterator.OfInt rhs) { + return new PrimitiveIterator.OfInt() { + private int peekL = -1; + private int peekR = -1; + private int next = fetchNext(); + + @Override + public int nextInt() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + int result = next; + next = fetchNext(); + return result; + } + + @Override + public boolean hasNext() { + return next != -1; + } + + private int fetchNext() { + if ((peekL == -1 && peekR == -1) && (!lhs.hasNext() && !rhs.hasNext())) { + return -1; + } + + if (peekL == -1 && lhs.hasNext()) { + peekL = lhs.next(); + } + + if (peekR == -1 && rhs.hasNext()) { + peekR = rhs.next(); + } + + // Return the smaller of the two next iterator values + int result; + if (peekL != -1 && (peekL == peekR || peekR == -1)) { + // If RHS is exhausted or intersects with LHS, return l and throw away r to avoid duplicates + result = peekL; + peekL = -1; + peekR = -1; + } else if (peekL == -1 && peekR != -1) { + // If LHS is exhausted, return RHS + result = peekR; + peekR = -1; + } else if (peekL < peekR) { + // If LHS value is smaller than RHS value, return LHS + result = peekL; + peekL = -1; + } else { + // If RHS value is smaller than LHS value, return RHS + result = peekR; + peekR = -1; + } + return result; + } + }; + } + private IndexIterator(int startIndex, int endIndex, IntPredicate filter, IntUnaryOperator translator) { this.endIndex = endIndex; this.filter = filter; diff --git a/parquet-column/src/main/java/org/apache/parquet/internal/filter2/columnindex/ColumnIndexFilter.java b/parquet-column/src/main/java/org/apache/parquet/internal/filter2/columnindex/ColumnIndexFilter.java index 264212a24..e46673f01 100644 --- a/parquet-column/src/main/java/org/apache/parquet/internal/filter2/columnindex/ColumnIndexFilter.java +++ b/parquet-column/src/main/java/org/apache/parquet/internal/filter2/columnindex/ColumnIndexFilter.java @@ -29,6 +29,7 @@ import org.apache.parquet.filter2.predicate.FilterPredicate.Visitor; import org.apache.parquet.filter2.predicate.Operators; import org.apache.parquet.filter2.predicate.Operators.And; import org.apache.parquet.filter2.predicate.Operators.Column; +import org.apache.parquet.filter2.predicate.Operators.Contains; import org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; @@ -155,6 +156,11 @@ public class ColumnIndexFilter implements Visitor<RowRanges> { return applyPredicate(notIn.getColumn(), ci -> ci.visit(notIn), isNull ? RowRanges.EMPTY : allRows()); } + @Override + public <T extends Comparable<T>> RowRanges visit(Contains<T> contains) { + return contains.filter(this, RowRanges::intersection, RowRanges::union); + } + @Override public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> RowRanges visit(UserDefined<T, U> udp) { return applyPredicate( diff --git a/parquet-column/src/test/java/org/apache/parquet/filter2/predicate/TestContainsRewriter.java b/parquet-column/src/test/java/org/apache/parquet/filter2/predicate/TestContainsRewriter.java new file mode 100644 index 000000000..5daa5d79b --- /dev/null +++ b/parquet-column/src/test/java/org/apache/parquet/filter2/predicate/TestContainsRewriter.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.parquet.filter2.predicate; + +import static org.apache.parquet.filter2.predicate.ContainsRewriter.rewrite; +import static org.apache.parquet.filter2.predicate.FilterApi.and; +import static org.apache.parquet.filter2.predicate.FilterApi.contains; +import static org.apache.parquet.filter2.predicate.FilterApi.doubleColumn; +import static org.apache.parquet.filter2.predicate.FilterApi.eq; +import static org.apache.parquet.filter2.predicate.FilterApi.gt; +import static org.apache.parquet.filter2.predicate.FilterApi.gtEq; +import static org.apache.parquet.filter2.predicate.FilterApi.intColumn; +import static org.apache.parquet.filter2.predicate.FilterApi.lt; +import static org.apache.parquet.filter2.predicate.FilterApi.ltEq; +import static org.apache.parquet.filter2.predicate.FilterApi.notEq; +import static org.apache.parquet.filter2.predicate.FilterApi.or; +import static org.apache.parquet.filter2.predicate.FilterApi.userDefined; +import static org.junit.Assert.assertEquals; + +import org.apache.parquet.filter2.predicate.Operators.Contains; +import org.apache.parquet.filter2.predicate.Operators.DoubleColumn; +import org.apache.parquet.filter2.predicate.Operators.IntColumn; +import org.apache.parquet.filter2.predicate.Operators.UserDefined; +import org.junit.Test; + +public class TestContainsRewriter { + private static final IntColumn intColumn = intColumn("a.b.c"); + private static final DoubleColumn doubleColumn = doubleColumn("a.b.c"); + + private static void assertNoOp(FilterPredicate p) { + assertEquals(p, rewrite(p)); + } + + @Test + public void testBaseCases() { + UserDefined<Integer, DummyUdp> ud = userDefined(intColumn, DummyUdp.class); + + assertNoOp(eq(intColumn, 17)); + assertNoOp(notEq(intColumn, 17)); + assertNoOp(lt(intColumn, 17)); + assertNoOp(ltEq(intColumn, 17)); + assertNoOp(gt(intColumn, 17)); + assertNoOp(gtEq(intColumn, 17)); + assertNoOp(and(eq(intColumn, 17), eq(doubleColumn, 12.0))); + assertNoOp(or(eq(intColumn, 17), eq(doubleColumn, 12.0))); + assertNoOp(ud); + + Contains<Integer> containsLhs = contains(eq(intColumn, 17)); + Contains<Integer> containsRhs = contains(eq(intColumn, 7)); + + assertNoOp(containsLhs); + assertEquals(containsLhs.and(containsRhs), rewrite(and(containsLhs, containsRhs))); + assertEquals(containsLhs.or(containsRhs), rewrite(or(containsLhs, containsRhs))); + } + + @Test + public void testNested() { + Contains<Integer> contains1 = contains(eq(intColumn, 1)); + Contains<Integer> contains2 = contains(eq(intColumn, 2)); + Contains<Integer> contains3 = contains(eq(intColumn, 3)); + Contains<Integer> contains4 = contains(eq(intColumn, 4)); + + assertEquals(contains1.and(contains2.or(contains3)), rewrite(and(contains1, or(contains2, contains3)))); + assertEquals(contains1.and(contains2).or(contains3), rewrite(or(and(contains1, contains2), contains3))); + + assertEquals( + contains1.and(contains2).and(contains2.or(contains3)), + rewrite(and(and(contains1, contains2), or(contains2, contains3)))); + assertEquals( + contains1.and(contains2).or(contains3.or(contains4)), + rewrite(or(and(contains1, contains2), or(contains3, contains4)))); + } +} diff --git a/parquet-column/src/test/java/org/apache/parquet/filter2/predicate/TestFilterApiMethods.java b/parquet-column/src/test/java/org/apache/parquet/filter2/predicate/TestFilterApiMethods.java index f5050cac4..c2e1ef385 100644 --- a/parquet-column/src/test/java/org/apache/parquet/filter2/predicate/TestFilterApiMethods.java +++ b/parquet-column/src/test/java/org/apache/parquet/filter2/predicate/TestFilterApiMethods.java @@ -20,6 +20,7 @@ package org.apache.parquet.filter2.predicate; import static org.apache.parquet.filter2.predicate.FilterApi.and; import static org.apache.parquet.filter2.predicate.FilterApi.binaryColumn; +import static org.apache.parquet.filter2.predicate.FilterApi.contains; import static org.apache.parquet.filter2.predicate.FilterApi.doubleColumn; import static org.apache.parquet.filter2.predicate.FilterApi.eq; import static org.apache.parquet.filter2.predicate.FilterApi.gt; @@ -31,6 +32,7 @@ import static org.apache.parquet.filter2.predicate.FilterApi.or; import static org.apache.parquet.filter2.predicate.FilterApi.userDefined; import static org.apache.parquet.filter2.predicate.Operators.NotEq; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import java.io.ByteArrayInputStream; @@ -91,6 +93,23 @@ public class TestFilterApiMethods { assertEquals(ColumnPath.get("x", "y", "z"), ((Gt) gt).getColumn().getColumnPath()); } + @Test + public void testInvalidContainsCreation() { + assertThrows( + "Contains predicate does not support null element value", + IllegalArgumentException.class, + () -> contains(eq(binColumn, null))); + + assertThrows( + "Composed Contains predicates must reference the same column name; found [a.b.c, b.c.d]", + IllegalArgumentException.class, + () -> ContainsRewriter.rewrite(or( + contains(eq(binaryColumn("a.b.c"), Binary.fromString("foo"))), + and( + contains(eq(binaryColumn("b.c.d"), Binary.fromString("bar"))), + contains(eq(binaryColumn("b.c.d"), Binary.fromString("bar"))))))); + } + @Test public void testToString() { FilterPredicate pred = or(predicate, notEq(binColumn, Binary.fromString("foobarbaz"))); @@ -98,6 +117,15 @@ public class TestFilterApiMethods { "or(and(not(or(eq(a.b.c, 7), noteq(a.b.c, 17))), gt(x.y.z, 100.0)), " + "noteq(a.string.column, Binary{\"foobarbaz\"}))", pred.toString()); + + pred = ContainsRewriter.rewrite(or( + contains(eq(binColumn, Binary.fromString("foo"))), + and( + contains(eq(binColumn, Binary.fromString("bar"))), + contains(eq(binColumn, Binary.fromString("baz")))))); + assertEquals( + "or(contains(eq(a.string.column, Binary{\"foo\"})), and(contains(eq(a.string.column, Binary{\"bar\"})), contains(eq(a.string.column, Binary{\"baz\"}))))", + pred.toString()); } @Test diff --git a/parquet-column/src/test/java/org/apache/parquet/filter2/predicate/TestSchemaCompatibilityValidator.java b/parquet-column/src/test/java/org/apache/parquet/filter2/predicate/TestSchemaCompatibilityValidator.java index fc23ce7e2..47e9bdd5e 100644 --- a/parquet-column/src/test/java/org/apache/parquet/filter2/predicate/TestSchemaCompatibilityValidator.java +++ b/parquet-column/src/test/java/org/apache/parquet/filter2/predicate/TestSchemaCompatibilityValidator.java @@ -20,6 +20,7 @@ package org.apache.parquet.filter2.predicate; import static org.apache.parquet.filter2.predicate.FilterApi.and; import static org.apache.parquet.filter2.predicate.FilterApi.binaryColumn; +import static org.apache.parquet.filter2.predicate.FilterApi.contains; import static org.apache.parquet.filter2.predicate.FilterApi.eq; import static org.apache.parquet.filter2.predicate.FilterApi.gt; import static org.apache.parquet.filter2.predicate.FilterApi.intColumn; @@ -126,7 +127,7 @@ public class TestSchemaCompatibilityValidator { } @Test - public void testRepeatedNotSupported() { + public void testRepeatedNotSupportedForPrimitivePredicates() { try { validate(eq(lotsOfLongs, 10l), schema); fail("this should throw"); @@ -136,4 +137,27 @@ public class TestSchemaCompatibilityValidator { e.getMessage()); } } + + @Test + public void testRepeatedSupportedForContainsPredicates() { + try { + validate(contains(eq(lotsOfLongs, 10L)), schema); + validate(and(contains(eq(lotsOfLongs, 10L)), contains(eq(lotsOfLongs, 5l))), schema); + validate(or(contains(eq(lotsOfLongs, 10L)), contains(eq(lotsOfLongs, 5l))), schema); + } catch (IllegalArgumentException e) { + fail("Valid repeated column predicates should not throw exceptions"); + } + } + + @Test + public void testNonRepeatedNotSupportedForContainsPredicates() { + try { + validate(contains(eq(longBar, 10L)), schema); + fail("Non-repeated field " + longBar + " should fail to validate a containsEq() predicate"); + } catch (IllegalArgumentException e) { + assertEquals( + "FilterPredicate for column x.bar requires a repeated schema, but found max repetition level 0", + e.getMessage()); + } + } } diff --git a/parquet-column/src/test/java/org/apache/parquet/internal/column/columnindex/TestColumnIndexBuilder.java b/parquet-column/src/test/java/org/apache/parquet/internal/column/columnindex/TestColumnIndexBuilder.java index 0631300fc..58a899eef 100644 --- a/parquet-column/src/test/java/org/apache/parquet/internal/column/columnindex/TestColumnIndexBuilder.java +++ b/parquet-column/src/test/java/org/apache/parquet/internal/column/columnindex/TestColumnIndexBuilder.java @@ -19,8 +19,10 @@ package org.apache.parquet.internal.column.columnindex; import static java.util.Arrays.asList; +import static org.apache.parquet.filter2.predicate.FilterApi.and; import static org.apache.parquet.filter2.predicate.FilterApi.binaryColumn; import static org.apache.parquet.filter2.predicate.FilterApi.booleanColumn; +import static org.apache.parquet.filter2.predicate.FilterApi.contains; import static org.apache.parquet.filter2.predicate.FilterApi.doubleColumn; import static org.apache.parquet.filter2.predicate.FilterApi.eq; import static org.apache.parquet.filter2.predicate.FilterApi.floatColumn; @@ -33,6 +35,7 @@ import static org.apache.parquet.filter2.predicate.FilterApi.lt; import static org.apache.parquet.filter2.predicate.FilterApi.ltEq; import static org.apache.parquet.filter2.predicate.FilterApi.notEq; import static org.apache.parquet.filter2.predicate.FilterApi.notIn; +import static org.apache.parquet.filter2.predicate.FilterApi.or; import static org.apache.parquet.filter2.predicate.FilterApi.userDefined; import static org.apache.parquet.filter2.predicate.LogicalInverter.invert; import static org.apache.parquet.schema.OriginalType.DECIMAL; @@ -62,6 +65,7 @@ import java.util.List; import java.util.Set; import org.apache.parquet.bytes.BytesUtils; import org.apache.parquet.column.statistics.Statistics; +import org.apache.parquet.filter2.predicate.ContainsRewriter; import org.apache.parquet.filter2.predicate.FilterPredicate; import org.apache.parquet.filter2.predicate.Operators.BinaryColumn; import org.apache.parquet.filter2.predicate.Operators.BooleanColumn; @@ -236,6 +240,58 @@ public class TestColumnIndexBuilder { } } + @Test + public void testArrayContainsDouble() { + PrimitiveType type = Types.required(DOUBLE).named("test_double"); + ColumnIndexBuilder builder = ColumnIndexBuilder.getBuilder(type, Integer.MAX_VALUE); + assertThat(builder, instanceOf(DoubleColumnIndexBuilder.class)); + assertNull(builder.build()); + DoubleColumn col = doubleColumn("test_col"); + + StatsBuilder sb = new StatsBuilder(); + builder.add(sb.stats(type, -4.2, -4.1)); + builder.add(sb.stats(type, -11.7, 7.0, null)); + builder.add(sb.stats(type, 2.2, 2.2, null, null)); + builder.add(sb.stats(type, null, null, null)); + builder.add(sb.stats(type, 1.9, 2.32)); + builder.add(sb.stats(type, -21.0, 8.1)); + builder.add(sb.stats(type, 10.0, 25.0)); + assertEquals(7, builder.getPageCount()); + assertEquals(sb.getMinMaxSize(), builder.getMinMaxSize()); + ColumnIndex columnIndex = builder.build(); + assertEquals(BoundaryOrder.UNORDERED, columnIndex.getBoundaryOrder()); + assertCorrectNullCounts(columnIndex, 0, 1, 2, 3, 0, 0, 0); + assertCorrectNullPages(columnIndex, false, false, false, true, false, false, false); + assertCorrectValues(columnIndex.getMaxValues(), -4.1, 7.0, 2.2, null, 2.32, 8.1, 25.0); + assertCorrectValues(columnIndex.getMinValues(), -4.2, -11.7, 2.2, null, 1.9, -21.0, 10.0); + + // Validate that contains(eq()) matches eq() when not combined using or() and and() + assertCorrectFiltering(columnIndex, eq(col, 0.0), 1, 5); + assertCorrectFiltering(columnIndex, contains(eq(col, 0.0)), 1, 5); + + assertCorrectFiltering(columnIndex, eq(col, 2.2), 1, 2, 4, 5); + assertCorrectFiltering(columnIndex, contains(eq(col, 2.2)), 1, 2, 4, 5); + + assertCorrectFiltering(columnIndex, eq(col, 25.0), 6); + assertCorrectFiltering(columnIndex, contains(eq(col, 25.0)), 6); + + // Should equal intersection of [1, 5] and [1, 2, 4, 5] --> [1, 5] + assertCorrectFiltering( + columnIndex, ContainsRewriter.rewrite(and(contains(eq(col, 0.0)), contains(eq(col, 2.2)))), 1, 5); + + // Should equal intersection of [6] and [1, 5] --> [] + assertCorrectFiltering( + columnIndex, ContainsRewriter.rewrite(and(contains(eq(col, 25.0)), contains(eq(col, 0.0))))); + + // Should equal union of [6] and [1, 5] --> [1, 5, 6] + assertCorrectFiltering( + columnIndex, ContainsRewriter.rewrite(or(contains(eq(col, 25.0)), contains(eq(col, 0.0)))), 1, 5, 6); + + // Should equal de-duplicated union of [1, 5] and [1, 2, 4, 5] --> [1, 2, 4, 5] + assertCorrectFiltering( + columnIndex, ContainsRewriter.rewrite(or(contains(eq(col, 0.0)), contains(eq(col, 2.2)))), 1, 2, 4, 5); + } + @Test public void testBuildBinaryDecimal() { PrimitiveType type = @@ -286,6 +342,7 @@ public class TestColumnIndexBuilder { set1.add(Binary.fromString("0.0")); assertCorrectFiltering(columnIndex, in(col, set1), 1, 4); assertCorrectFiltering(columnIndex, notIn(col, set1), 0, 1, 2, 3, 4, 5, 6, 7); + set1.add(null); assertCorrectFiltering(columnIndex, in(col, set1), 0, 1, 2, 3, 4, 5, 6); assertCorrectFiltering(columnIndex, notIn(col, set1), 0, 1, 2, 3, 4, 5, 6, 7); diff --git a/parquet-column/src/test/java/org/apache/parquet/internal/column/columnindex/TestIndexIterator.java b/parquet-column/src/test/java/org/apache/parquet/internal/column/columnindex/TestIndexIterator.java index 8075b4165..cfbfeed1d 100644 --- a/parquet-column/src/test/java/org/apache/parquet/internal/column/columnindex/TestIndexIterator.java +++ b/parquet-column/src/test/java/org/apache/parquet/internal/column/columnindex/TestIndexIterator.java @@ -50,6 +50,97 @@ public class TestIndexIterator { assertEquals(IndexIterator.rangeTranslate(11, 18, i -> i - 10), 1, 2, 3, 4, 5, 6, 7, 8); } + @Test + public void testUnion() { + // Test deduplication of intersecting ranges + assertEquals( + IndexIterator.union( + IndexIterator.rangeTranslate(0, 7, i -> i), IndexIterator.rangeTranslate(4, 10, i -> i)), + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10); + + // Test inversion of LHS and RHS + assertEquals( + IndexIterator.union( + IndexIterator.rangeTranslate(4, 10, i -> i), IndexIterator.rangeTranslate(0, 7, i -> i)), + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10); + + // Test non-intersecting ranges + assertEquals( + IndexIterator.union( + IndexIterator.rangeTranslate(2, 5, i -> i), IndexIterator.rangeTranslate(8, 10, i -> i)), + 2, + 3, + 4, + 5, + 8, + 9, + 10); + } + + @Test + public void testIntersection() { + // Case 1: some overlap between LHS and RHS + // LHS: [0, 1, 2, 3, 4, 5, 6, 7], RHS: [4, 5, 6, 7, 8, 9, 10] + assertEquals( + IndexIterator.intersection( + IndexIterator.rangeTranslate(0, 7, i -> i), IndexIterator.rangeTranslate(4, 10, i -> i)), + 4, + 5, + 6, + 7); + + // Test inversion of LHS and RHS + assertEquals( + IndexIterator.intersection( + IndexIterator.rangeTranslate(4, 10, i -> i), IndexIterator.rangeTranslate(0, 7, i -> i)), + 4, + 5, + 6, + 7); + + // Case 2: Single point of overlap at end of iterator + // LHS: [1, 3, 5, 7], RHS: [0, 2, 4, 6, 7] + assertEquals( + IndexIterator.intersection( + IndexIterator.filter(8, i -> i % 2 == 1), IndexIterator.filter(8, i -> i % 2 == 0 || i == 7)), + 7); + + // Test inversion of LHS and RHS + assertEquals( + IndexIterator.intersection( + IndexIterator.filter(8, i -> i % 2 == 0 || i == 7), IndexIterator.filter(8, i -> i % 2 == 1)), + 7); + + // Test no intersection between ranges + // LHS: [2, 3, 4, 5], RHS: [8, 9, 10] + assertEquals(IndexIterator.intersection( + IndexIterator.rangeTranslate(2, 5, i -> i), IndexIterator.rangeTranslate(8, 10, i -> i))); + + // Test inversion of LHS and RHS + assertEquals(IndexIterator.intersection( + IndexIterator.rangeTranslate(8, 10, i -> i), IndexIterator.rangeTranslate(2, 5, i -> i))); + } + static void assertEquals(PrimitiveIterator.OfInt actualIt, int... expectedValues) { IntList actualList = new IntArrayList(); actualIt.forEachRemaining((int value) -> actualList.add(value)); diff --git a/parquet-column/src/test/java/org/apache/parquet/internal/filter2/columnindex/TestColumnIndexFilter.java b/parquet-column/src/test/java/org/apache/parquet/internal/filter2/columnindex/TestColumnIndexFilter.java index 5fb441521..1574ce247 100644 --- a/parquet-column/src/test/java/org/apache/parquet/internal/filter2/columnindex/TestColumnIndexFilter.java +++ b/parquet-column/src/test/java/org/apache/parquet/internal/filter2/columnindex/TestColumnIndexFilter.java @@ -22,6 +22,7 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static org.apache.parquet.filter2.predicate.FilterApi.and; import static org.apache.parquet.filter2.predicate.FilterApi.binaryColumn; import static org.apache.parquet.filter2.predicate.FilterApi.booleanColumn; +import static org.apache.parquet.filter2.predicate.FilterApi.contains; import static org.apache.parquet.filter2.predicate.FilterApi.doubleColumn; import static org.apache.parquet.filter2.predicate.FilterApi.eq; import static org.apache.parquet.filter2.predicate.FilterApi.gt; @@ -47,6 +48,7 @@ import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.DOUBLE; import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32; import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64; import static org.apache.parquet.schema.Types.optional; +import static org.apache.parquet.schema.Types.repeated; import static org.junit.Assert.assertArrayEquals; import it.unimi.dsi.fastutil.longs.LongArrayList; @@ -292,6 +294,26 @@ public class TestColumnIndexFilter { .build(); private static final OffsetIndex COLUMN5_OI = new OIBuilder().addPage(1).addPage(29).build(); + + private static final ColumnIndex COLUMN6_CI = new CIBuilder(repeated(INT32).named("column6"), ASCENDING) + .addPage(0, 1, 1) + .addPage(1, 2, 6) + .addPage(0, 7, 7) + .addPage(1, 7, 10) + .addPage(0, 11, 17) + .addPage(0, 18, 23) + .addPage(0, 24, 26) + .build(); + private static final OffsetIndex COLUMN6_OI = new OIBuilder() + .addPage(1) + .addPage(6) + .addPage(2) + .addPage(5) + .addPage(7) + .addPage(6) + .addPage(3) + .build(); + private static final ColumnIndexStore STORE = new ColumnIndexStore() { @Override public ColumnIndex getColumnIndex(ColumnPath column) { @@ -306,6 +328,8 @@ public class TestColumnIndexFilter { return COLUMN4_CI; case "column5": return COLUMN5_CI; + case "column6": + return COLUMN6_CI; default: return null; } @@ -324,6 +348,8 @@ public class TestColumnIndexFilter { return COLUMN4_OI; case "column5": return COLUMN5_OI; + case "column6": + return COLUMN6_OI; default: throw new MissingOffsetIndexException(column); } @@ -354,7 +380,7 @@ public class TestColumnIndexFilter { @Test public void testFiltering() { - Set<ColumnPath> paths = paths("column1", "column2", "column3", "column4"); + Set<ColumnPath> paths = paths("column1", "column2", "column3", "column4", "column6"); assertAllRows( calculateRowRanges( @@ -364,6 +390,48 @@ public class TestColumnIndexFilter { TOTAL_ROW_COUNT), TOTAL_ROW_COUNT); + assertRows( + calculateRowRanges( + FilterCompat.get(contains(eq(intColumn("column6"), 7))), STORE, paths, TOTAL_ROW_COUNT), + 7, + 8, + 9, + 10, + 11, + 12, + 13); + assertRows( + calculateRowRanges( + FilterCompat.get( + and(contains(eq(intColumn("column6"), 7)), contains(eq(intColumn("column6"), 10)))), + STORE, + paths, + TOTAL_ROW_COUNT), + 9, + 10, + 11, + 12, + 13); + assertRows( + calculateRowRanges( + FilterCompat.get( + or(contains(eq(intColumn("column6"), 7)), contains(eq(intColumn("column6"), 20)))), + STORE, + paths, + TOTAL_ROW_COUNT), + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 21, + 22, + 23, + 24, + 25, + 26); Set<Integer> set1 = new HashSet<>(); set1.add(7); assertRows( diff --git a/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java b/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java index 662be8c41..1a2f5e54e 100644 --- a/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java +++ b/parquet-generator/src/main/java/org/apache/parquet/filter2/IncrementallyUpdatedFilterPredicateGenerator.java @@ -70,6 +70,9 @@ public class IncrementallyUpdatedFilterPredicateGenerator { + "import java.util.Set;\n" + "\n" + "import org.apache.parquet.hadoop.metadata.ColumnPath;\n" + + "import org.apache.parquet.filter2.predicate.FilterPredicate;\n" + + "import org.apache.parquet.filter2.predicate.Operators;\n" + + "import org.apache.parquet.filter2.predicate.Operators.Contains;\n" + "import org.apache.parquet.filter2.predicate.Operators.Eq;\n" + "import org.apache.parquet.filter2.predicate.Operators.Gt;\n" + "import org.apache.parquet.filter2.predicate.Operators.GtEq;\n" @@ -100,13 +103,13 @@ public class IncrementallyUpdatedFilterPredicateGenerator { addVisitBegin("Eq"); for (TypeInfo info : TYPES) { - addEqNotEqCase(info, true); + addEqNotEqCase(info, true, false); } addVisitEnd(); addVisitBegin("NotEq"); for (TypeInfo info : TYPES) { - addEqNotEqCase(info, false); + addEqNotEqCase(info, false, false); } addVisitEnd(); @@ -122,6 +125,12 @@ public class IncrementallyUpdatedFilterPredicateGenerator { } addVisitEnd(); + addContainsBegin(); + addVisitBegin("Contains"); + addContainsCase(); + addContainsEnd(); + addVisitEnd(); + addVisitBegin("Lt"); for (TypeInfo info : TYPES) { addInequalityCase(info, "<"); @@ -186,21 +195,27 @@ public class IncrementallyUpdatedFilterPredicateGenerator { + " }\n\n"); } - private void addEqNotEqCase(TypeInfo info, boolean isEq) throws IOException { - add(" if (clazz.equals(" + info.className + ".class)) {\n" + " if (pred.getValue() == null) {\n" - + " valueInspector = new ValueInspector() {\n" - + " @Override\n" - + " public void updateNull() {\n" - + " setResult(" - + isEq + ");\n" + " }\n" - + "\n" - + " @Override\n" - + " public void update(" - + info.primitiveName + " value) {\n" + " setResult(" - + !isEq + ");\n" + " }\n" - + " };\n" - + " } else {\n" - + " final " + private void addEqNotEqCase(TypeInfo info, boolean isEq, boolean expectMultipleResults) throws IOException { + add(" if (clazz.equals(" + info.className + ".class)) {\n"); + + // Predicates for repeated fields don't need to support null values + if (!expectMultipleResults) { + add(" if (pred.getValue() == null) {\n" + + " valueInspector = new ValueInspector() {\n" + + " @Override\n" + + " public void updateNull() {\n" + + " setResult(" + + isEq + ");\n" + " }\n" + + "\n" + + " @Override\n" + + " public void update(" + + info.primitiveName + " value) {\n" + " setResult(" + + !isEq + ");\n" + " }\n" + + " };\n" + + " } else {\n"); + } + + add(" final " + info.primitiveName + " target = (" + info.className + ") (Object) pred.getValue();\n" + " final PrimitiveComparator<" + info.className + "> comparator = getComparator(columnPath);\n" + "\n" @@ -214,9 +229,20 @@ public class IncrementallyUpdatedFilterPredicateGenerator { + " public void update(" + info.primitiveName + " value) {\n"); - add(" setResult(" + compareEquality("value", "target", isEq) + ");\n"); + if (!expectMultipleResults) { + add(" setResult(" + compareEquality("value", "target", isEq) + ");\n"); + } else { + add(" if (!isKnown() && " + compareEquality("value", "target", isEq) + + ") { setResult(true); }\n"); + } + + add(" }\n };\n"); - add(" }\n" + " };\n" + " }\n" + " }\n\n"); + if (!expectMultipleResults) { + add(" }\n"); + } + + add(" }\n\n"); } private void addInequalityCase(TypeInfo info, String op) throws IOException { @@ -304,6 +330,196 @@ public class IncrementallyUpdatedFilterPredicateGenerator { + "\n"); } + private void addContainsUpdateCase(TypeInfo info, String... inspectors) throws IOException { + add(" @Override\n" + " public void update(" + info.primitiveName + " value) {\n"); + for (String inspector : inspectors) { + add(" " + inspector + ".update(value);\n"); + } + add(" checkSatisfied();\n" + " }\n"); + } + + private void addContainsInspectorVisitor(String op, boolean isSupported) throws IOException { + if (isSupported) { + add(" @Override\n" + + " public <T extends Comparable<T>> ValueInspector visit(" + op + "<T> pred) {\n" + + " ColumnPath columnPath = pred.getColumn().getColumnPath();\n" + + " Class<T> clazz = pred.getColumn().getColumnType();\n" + + " ValueInspector valueInspector = null;\n"); + + for (TypeInfo info : TYPES) { + switch (op) { + case "Eq": + addEqNotEqCase(info, true, true); + break; + default: + throw new UnsupportedOperationException("Op " + op + " not implemented for Contains filter"); + } + } + + add(" return valueInspector;" + " }\n"); + } else { + add(" @Override\n" + + " public <T extends Comparable<T>> ValueInspector visit(" + op + "<T> pred) {\n" + + " throw new UnsupportedOperationException(\"" + op + + " not supported for Contains predicate\");\n" + + " }\n" + + "\n"); + } + } + + private void addContainsBegin() throws IOException { + add(" private static class ContainsPredicate extends ValueInspector {\n" + + " private final ValueInspector inspector;\n" + + "\n" + + " private ContainsPredicate(ValueInspector inspector) {\n" + + " this.inspector = inspector;\n" + + " }\n" + + "\n" + + " private void checkSatisfied() {\n" + + " if (!isKnown() && inspector.isKnown() && inspector.getResult()) {\n" + + " setResult(true);\n" + + " }\n" + + " }\n" + + "\n" + + " @Override\n" + + " public void updateNull() {\n" + + " setResult(false);\n" + + " }\n"); + + for (TypeInfo info : TYPES) { + addContainsUpdateCase(info, "inspector"); + } + + add(" @Override\n" + " public void reset() {\n" + + " super.reset();\n" + + " inspector.reset();\n" + + " }\n" + + " }\n"); + + add(" private static class ContainsAndPredicate extends ValueInspector {\n" + + " private final ValueInspector left;\n" + + " private final ValueInspector right;\n" + + "\n" + + " private ContainsAndPredicate(ValueInspector left, ValueInspector right) {\n" + + " this.left = left;\n" + + " this.right = right;\n" + + " }\n" + + "\n" + + " private void checkSatisfied() {\n" + + " if (isKnown()) { return; }\n" + + " if (left.isKnown() && right.isKnown() && left.getResult() && right.getResult()) {\n" + + " setResult(true);\n" + + " }\n" + + " }\n" + + " \n" + + " @Override\n" + + " public void updateNull() {\n" + + " setResult(false);\n" + + " }\n\n"); + + for (TypeInfo info : TYPES) { + addContainsUpdateCase(info, "left", "right"); + } + + add(" @Override\n" + + " public void reset() {\n" + + " super.reset();\n" + + " left.reset();\n" + + " right.reset();\n" + + " }\n" + + " }\n"); + + add(" private static class ContainsOrPredicate extends ValueInspector {\n" + + " private final ValueInspector left;\n" + + " private final ValueInspector right;\n" + + "\n" + + " private ContainsOrPredicate(ValueInspector left, ValueInspector right) {\n" + + " this.left = left;\n" + + " this.right = right;\n" + + " }\n" + + "\n" + + " private void checkSatisfied() {\n" + + " if (isKnown()) { return; }\n" + + " if (left.isKnown() && left.getResult()) {\n" + + " setResult(true);\n" + + " return;\n" + + " }\n" + + " if (right.isKnown() && right.getResult()) {\n" + + " setResult(true);\n" + + " }\n" + + " }\n" + + " \n" + + " @Override\n" + + " public void updateNull() {\n" + + " setResult(false);\n" + + " }\n"); + + for (TypeInfo info : TYPES) { + addContainsUpdateCase(info, "left", "right"); + } + + add(" @Override\n" + + " public void reset() {\n" + + " super.reset();\n" + + " left.reset();\n" + + " right.reset();\n" + + " }\n" + + " }\n"); + + add(" private class ContainsInspectorVisitor implements FilterPredicate.Visitor<ValueInspector> {\n\n" + + " @Override\n" + + " public <T extends Comparable<T>> ValueInspector visit(Contains<T> contains) {\n" + + " return contains.filter(\n" + + " this,\n" + + " (l, r) -> new ContainsAndPredicate(l, r),\n" + + " (l, r) -> new ContainsOrPredicate(l, r)\n" + + " );\n" + + " }\n"); + + addContainsInspectorVisitor("Eq", true); + addContainsInspectorVisitor("NotEq", false); + addContainsInspectorVisitor("Lt", false); + addContainsInspectorVisitor("LtEq", false); + addContainsInspectorVisitor("Gt", false); + addContainsInspectorVisitor("GtEq", false); + + add(" @Override\n" + + " public ValueInspector visit(Operators.And pred) {\n" + + " throw new UnsupportedOperationException(\"Operators.And not supported for Contains predicate\");\n" + + " }\n" + + "\n" + + " @Override\n" + + " public ValueInspector visit(Operators.Or pred) {\n" + + " throw new UnsupportedOperationException(\"Operators.Or not supported for Contains predicate\");\n" + + " }\n" + + "\n" + + " @Override\n" + + " public ValueInspector visit(Operators.Not pred) {\n" + + " throw new UnsupportedOperationException(\"Operators.Not not supported for Contains predicate\");\n" + + " }" + + " @Override\n" + + " public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> ValueInspector visit(\n" + + " UserDefined<T, U> pred) {\n" + + " throw new UnsupportedOperationException(\"UserDefinedPredicate not supported for Contains predicate\");\n" + + " }\n" + + "\n" + + " @Override\n" + + " public <T extends Comparable<T>, U extends UserDefinedPredicate<T>> ValueInspector visit(\n" + + " LogicalNotUserDefined<T, U> pred) {\n" + + " throw new UnsupportedOperationException(\"LogicalNotUserDefined not supported for Contains predicate\");\n" + + " }\n" + + " }\n" + + "\n"); + } + + private void addContainsCase() throws IOException { + add(" valueInspector = new ContainsPredicate(new ContainsInspectorVisitor().visit(pred));\n"); + } + + private void addContainsEnd() { + // No-op + } + private void addUdpCase(TypeInfo info, boolean invert) throws IOException { add(" if (clazz.equals(" + info.className + ".class)) {\n" + " valueInspector = new ValueInspector() {\n" diff --git a/parquet-hadoop/src/main/java/org/apache/parquet/filter2/bloomfilterlevel/BloomFilterImpl.java b/parquet-hadoop/src/main/java/org/apache/parquet/filter2/bloomfilterlevel/BloomFilterImpl.java index f1ed6ea8c..16348e535 100644 --- a/parquet-hadoop/src/main/java/org/apache/parquet/filter2/bloomfilterlevel/BloomFilterImpl.java +++ b/parquet-hadoop/src/main/java/org/apache/parquet/filter2/bloomfilterlevel/BloomFilterImpl.java @@ -119,6 +119,11 @@ public class BloomFilterImpl implements FilterPredicate.Visitor<Boolean> { return BLOCK_MIGHT_MATCH; } + @Override + public <T extends Comparable<T>> Boolean visit(Operators.Contains<T> contains) { + return contains.filter(this, (l, r) -> l || r, (l, r) -> l && r); + } + @Override public <T extends Comparable<T>> Boolean visit(Operators.In<T> in) { Set<T> values = in.getValues(); diff --git a/parquet-hadoop/src/main/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilter.java b/parquet-hadoop/src/main/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilter.java index 665e39e9f..dbb38047e 100644 --- a/parquet-hadoop/src/main/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilter.java +++ b/parquet-hadoop/src/main/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilter.java @@ -36,6 +36,7 @@ import org.apache.parquet.column.page.DictionaryPageReadStore; import org.apache.parquet.filter2.predicate.FilterPredicate; import org.apache.parquet.filter2.predicate.Operators.And; import org.apache.parquet.filter2.predicate.Operators.Column; +import org.apache.parquet.filter2.predicate.Operators.Contains; import org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; @@ -487,6 +488,11 @@ public class DictionaryFilter implements FilterPredicate.Visitor<Boolean> { return BLOCK_MIGHT_MATCH; } + @Override + public <T extends Comparable<T>> Boolean visit(Contains<T> contains) { + return contains.filter(this, (l, r) -> l || r, (l, r) -> l && r); + } + @Override public Boolean visit(And and) { return and.getLeft().accept(this) || and.getRight().accept(this); diff --git a/parquet-hadoop/src/main/java/org/apache/parquet/filter2/statisticslevel/StatisticsFilter.java b/parquet-hadoop/src/main/java/org/apache/parquet/filter2/statisticslevel/StatisticsFilter.java index 76f9078ca..deb4706d5 100644 --- a/parquet-hadoop/src/main/java/org/apache/parquet/filter2/statisticslevel/StatisticsFilter.java +++ b/parquet-hadoop/src/main/java/org/apache/parquet/filter2/statisticslevel/StatisticsFilter.java @@ -28,6 +28,7 @@ import org.apache.parquet.column.statistics.Statistics; import org.apache.parquet.filter2.predicate.FilterPredicate; import org.apache.parquet.filter2.predicate.Operators.And; import org.apache.parquet.filter2.predicate.Operators.Column; +import org.apache.parquet.filter2.predicate.Operators.Contains; import org.apache.parquet.filter2.predicate.Operators.Eq; import org.apache.parquet.filter2.predicate.Operators.Gt; import org.apache.parquet.filter2.predicate.Operators.GtEq; @@ -211,6 +212,11 @@ public class StatisticsFilter implements FilterPredicate.Visitor<Boolean> { return BLOCK_MIGHT_MATCH; } + @Override + public <T extends Comparable<T>> Boolean visit(Contains<T> contains) { + return contains.filter(this, (l, r) -> l || r, (l, r) -> l && r); + } + @Override @SuppressWarnings("unchecked") public <T extends Comparable<T>> Boolean visit(NotEq<T> notEq) { diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilterTest.java b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilterTest.java index 525b603bb..5b9e638d6 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilterTest.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/dictionarylevel/DictionaryFilterTest.java @@ -24,6 +24,7 @@ import static org.apache.parquet.column.ParquetProperties.WriterVersion.PARQUET_ import static org.apache.parquet.filter2.dictionarylevel.DictionaryFilter.canDrop; import static org.apache.parquet.filter2.predicate.FilterApi.and; import static org.apache.parquet.filter2.predicate.FilterApi.binaryColumn; +import static org.apache.parquet.filter2.predicate.FilterApi.contains; import static org.apache.parquet.filter2.predicate.FilterApi.doubleColumn; import static org.apache.parquet.filter2.predicate.FilterApi.eq; import static org.apache.parquet.filter2.predicate.FilterApi.floatColumn; @@ -69,6 +70,7 @@ import org.apache.parquet.example.data.Group; import org.apache.parquet.example.data.simple.SimpleGroupFactory; import org.apache.parquet.filter2.predicate.FilterPredicate; import org.apache.parquet.filter2.predicate.LogicalInverseRewriter; +import org.apache.parquet.filter2.predicate.Operators; import org.apache.parquet.filter2.predicate.Operators.BinaryColumn; import org.apache.parquet.filter2.predicate.Operators.DoubleColumn; import org.apache.parquet.filter2.predicate.Operators.FloatColumn; @@ -112,6 +114,7 @@ public class DictionaryFilterTest { + "required int32 plain_int32_field; " + "required binary fallback_binary_field; " + "required int96 int96_field; " + + "repeated binary repeated_binary_field;" + "} "); private static final String ALPHABET = "abcdefghijklmnopqrstuvwxyz"; @@ -177,7 +180,16 @@ public class DictionaryFilterTest { i < (nElements / 2) ? ALPHABET.substring(index, index + 1) : UUID.randomUUID().toString()) - .append("int96_field", INT96_VALUES[i % INT96_VALUES.length]); + .append("int96_field", INT96_VALUES[i % INT96_VALUES.length]) + .append("repeated_binary_field", ALPHABET.substring(index, index + 1)); + + if (index + 1 < 26) { + group = group.append("repeated_binary_field", ALPHABET.substring(index + 1, index + 2)); + } + + if (index + 2 < 26) { + group = group.append("repeated_binary_field", ALPHABET.substring(index + 2, index + 3)); + } // 10% of the time, leave the field null if (index % 10 > 0) { @@ -282,7 +294,8 @@ public class DictionaryFilterTest { "int64_field", "double_field", "float_field", - "int96_field")); + "int96_field", + "repeated_binary_field")); for (ColumnChunkMetaData column : ccmd) { String name = column.getPath().toDotString(); if (dictionaryEncodedColumns.contains(name)) { @@ -319,7 +332,8 @@ public class DictionaryFilterTest { "int64_field", "double_field", "float_field", - "int96_field")); + "int96_field", + "repeated_binary_field")); for (ColumnChunkMetaData column : ccmd) { EncodingStats encStats = column.getEncodingStats(); String name = column.getPath().toDotString(); @@ -814,6 +828,42 @@ public class DictionaryFilterTest { canDrop(LogicalInverseRewriter.rewrite(not(userDefined(fake, nullRejecting))), ccmd, dictionaries)); } + @Test + public void testContainsAnd() throws Exception { + BinaryColumn col = binaryColumn("binary_field"); + + // both evaluate to false (no upper-case letters are in the dictionary) + Operators.Contains<Binary> B = contains(eq(col, Binary.fromString("B"))); + Operators.Contains<Binary> C = contains(eq(col, Binary.fromString("C"))); + + // both evaluate to true (all lower-case letters are in the dictionary) + Operators.Contains<Binary> x = contains(eq(col, Binary.fromString("x"))); + Operators.Contains<Binary> y = contains(eq(col, Binary.fromString("y"))); + + assertTrue("Should drop when either predicate must be false", canDrop(and(B, y), ccmd, dictionaries)); + assertTrue("Should drop when either predicate must be false", canDrop(and(x, C), ccmd, dictionaries)); + assertTrue("Should drop when either predicate must be false", canDrop(and(B, C), ccmd, dictionaries)); + assertFalse("Should not drop when either predicate could be true", canDrop(and(x, y), ccmd, dictionaries)); + } + + @Test + public void testContainsOr() throws Exception { + BinaryColumn col = binaryColumn("binary_field"); + + // both evaluate to false (no upper-case letters are in the dictionary) + Operators.Contains<Binary> B = contains(eq(col, Binary.fromString("B"))); + Operators.Contains<Binary> C = contains(eq(col, Binary.fromString("C"))); + + // both evaluate to true (all lower-case letters are in the dictionary) + Operators.Contains<Binary> x = contains(eq(col, Binary.fromString("x"))); + Operators.Contains<Binary> y = contains(eq(col, Binary.fromString("y"))); + + assertFalse("Should not drop when one predicate could be true", canDrop(or(B, y), ccmd, dictionaries)); + assertFalse("Should not drop when one predicate could be true", canDrop(or(x, C), ccmd, dictionaries)); + assertTrue("Should drop when both predicates must be false", canDrop(or(B, C), ccmd, dictionaries)); + assertFalse("Should not drop when one predicate could be true", canDrop(or(x, y), ccmd, dictionaries)); + } + private static final class InInt32UDP extends UserDefinedPredicate<Integer> implements Serializable { private final Set<Integer> ints; diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/PhoneBookWriter.java b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/PhoneBookWriter.java index d4a77879e..a0ecfa377 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/PhoneBookWriter.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/PhoneBookWriter.java @@ -23,7 +23,9 @@ import static org.junit.Assert.assertEquals; import java.io.File; import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.parquet.bytes.ByteBufferAllocator; @@ -54,6 +56,12 @@ public class PhoneBookWriter { + " optional binary kind (UTF8);\n" + " }\n" + " }\n" + + " optional group accounts (MAP) {\n" + + " repeated group key_value {\n" + + " required binary key;\n" + + " required double value;\n" + + " }\n" + + " }\n" + "}\n"; private static final MessageType schema = getSchema(); @@ -154,11 +162,19 @@ public class PhoneBookWriter { private final List<PhoneNumber> phoneNumbers; private final Location location; + private final Map<String, Double> accounts; + public User(long id, String name, List<PhoneNumber> phoneNumbers, Location location) { + this(id, name, phoneNumbers, location, null); + } + + public User( + long id, String name, List<PhoneNumber> phoneNumbers, Location location, Map<String, Double> accounts) { this.id = id; this.name = name; this.phoneNumbers = phoneNumbers; this.location = location; + this.accounts = accounts; } public long getId() { @@ -177,6 +193,10 @@ public class PhoneBookWriter { return location; } + public Map<String, Double> getAccounts() { + return accounts; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -189,6 +209,7 @@ public class PhoneBookWriter { if (name != null ? !name.equals(user.name) : user.name != null) return false; if (phoneNumbers != null ? !phoneNumbers.equals(user.phoneNumbers) : user.phoneNumbers != null) return false; + if (accounts != null ? !accounts.equals(user.accounts) : user.accounts != null) return false; return true; } @@ -199,17 +220,18 @@ public class PhoneBookWriter { result = 31 * result + (name != null ? name.hashCode() : 0); result = 31 * result + (phoneNumbers != null ? phoneNumbers.hashCode() : 0); result = 31 * result + (location != null ? location.hashCode() : 0); + result = 31 * result + (accounts != null ? accounts.hashCode() : 0); return result; } @Override public String toString() { return "User [id=" + id + ", name=" + name + ", phoneNumbers=" + phoneNumbers + ", location=" + location - + "]"; + + ", accounts=" + accounts + "]"; } public User cloneWithName(String name) { - return new User(id, name, phoneNumbers, location); + return new User(id, name, phoneNumbers, location, accounts); } } @@ -241,6 +263,16 @@ public class PhoneBookWriter { location.append("lat", user.getLocation().getLat()); } } + + if (user.getAccounts() != null) { + Group accounts = root.addGroup("accounts"); + for (Map.Entry<String, Double> account : user.getAccounts().entrySet()) { + Group kv = accounts.addGroup("key_value"); + kv.append("key", account.getKey()); + kv.append("value", account.getValue()); + } + } + return root; } @@ -249,7 +281,8 @@ public class PhoneBookWriter { getLong(root, "id"), getString(root, "name"), getPhoneNumbers(getGroup(root, "phoneNumbers")), - getLocation(getGroup(root, "location"))); + getLocation(getGroup(root, "location")), + getAccounts(getGroup(root, "accounts"))); } private static List<PhoneNumber> getPhoneNumbers(Group phoneNumbers) { @@ -271,6 +304,19 @@ public class PhoneBookWriter { return new Location(getDouble(location, "lon"), getDouble(location, "lat")); } + private static Map<String, Double> getAccounts(Group accounts) { + if (accounts == null) { + return null; + } + Map<String, Double> map = new HashMap<>(); + for (int i = 0, n = accounts.getFieldRepetitionCount("key_value"); i < n; ++i) { + Group kv = accounts.getGroup("key_value", i); + + map.put(getString(kv, "key"), getDouble(kv, "value")); + } + return map; + } + private static boolean isNull(Group group, String field) { // Use null value if the field is not in the group schema if (!group.getType().containsField(field)) { @@ -336,6 +382,10 @@ public class PhoneBookWriter { .withConf(conf) .withFilter(filter) .withAllocator(allocator) + .useBloomFilter(false) + .useDictionaryFilter(false) + .useStatsFilter(false) + .useColumnIndexFilter(false) .build(); } diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java index 0e81917c3..dedec409c 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/recordlevel/TestRecordLevelFilters.java @@ -20,6 +20,7 @@ package org.apache.parquet.filter2.recordlevel; import static org.apache.parquet.filter2.predicate.FilterApi.and; import static org.apache.parquet.filter2.predicate.FilterApi.binaryColumn; +import static org.apache.parquet.filter2.predicate.FilterApi.contains; import static org.apache.parquet.filter2.predicate.FilterApi.doubleColumn; import static org.apache.parquet.filter2.predicate.FilterApi.eq; import static org.apache.parquet.filter2.predicate.FilterApi.gt; @@ -31,6 +32,7 @@ import static org.apache.parquet.filter2.predicate.FilterApi.or; import static org.apache.parquet.filter2.predicate.FilterApi.userDefined; import static org.junit.Assert.assertEquals; +import com.google.common.collect.ImmutableMap; import java.io.File; import java.io.IOException; import java.io.Serializable; @@ -59,21 +61,40 @@ public class TestRecordLevelFilters { public static List<User> makeUsers() { List<User> users = new ArrayList<User>(); - users.add(new User(17, null, null, null)); + users.add(new User( + 17, + null, + null, + null, + ImmutableMap.of( + "business", 1000.0D, + "personal", 500.0D))); users.add(new User(18, "bob", null, null)); - users.add(new User(19, "alice", new ArrayList<PhoneNumber>(), null)); + users.add(new User( + 19, + "alice", + new ArrayList<PhoneNumber>(), + null, + ImmutableMap.of( + "business", 2000.0D, + "retirement", 1000.0D))); users.add(new User(20, "thing1", Arrays.asList(new PhoneNumber(5555555555L, null)), null)); - users.add(new User(27, "thing2", Arrays.asList(new PhoneNumber(1111111111L, "home")), null)); + users.add(new User( + 27, + "thing2", + Arrays.asList(new PhoneNumber(1111111111L, "home"), new PhoneNumber(2222222222L, "cell")), + null)); users.add(new User( 28, "popular", Arrays.asList( new PhoneNumber(1111111111L, "home"), + new PhoneNumber(1111111111L, "apartment"), new PhoneNumber(2222222222L, null), new PhoneNumber(3333333333L, "mobile")), null)); @@ -127,6 +148,15 @@ public class TestRecordLevelFilters { } } + private static void assertPredicate(FilterPredicate predicate, long... expectedIds) throws IOException { + List<Group> found = PhoneBookWriter.readFile(phonebookFile, FilterCompat.get(predicate)); + + assertEquals(expectedIds.length, found.size()); + for (int i = 0; i < expectedIds.length; i++) { + assertEquals(expectedIds[i], found.get(i).getLong("id", 0)); + } + } + @Test public void testNoFilter() throws Exception { List<Group> found = PhoneBookWriter.readFile(phonebookFile, FilterCompat.NOOP); @@ -181,6 +211,77 @@ public class TestRecordLevelFilters { assert (found.size() == 102); } + @Test + public void testArrayContains() throws Exception { + assertPredicate( + contains(eq(binaryColumn("phoneNumbers.phone.kind"), Binary.fromString("home"))), 27L, 28L, 30L); + } + + @Test + public void testArrayContainsSimpleAndFilter() throws Exception { + assertPredicate( + and( + contains(eq(longColumn("phoneNumbers.phone.number"), 1111111111L)), + contains(eq(longColumn("phoneNumbers.phone.number"), 3333333333L))), + 28L); + + assertPredicate( + and( + contains(eq(longColumn("phoneNumbers.phone.number"), 1111111111L)), + contains(eq(longColumn("phoneNumbers.phone.number"), -123L))) // Won't match + ); + } + + @Test + public void testArrayContainsNestedAndFilter() throws Exception { + assertPredicate( + and( + contains(eq(longColumn("phoneNumbers.phone.number"), 1111111111L)), + and( + contains(eq(longColumn("phoneNumbers.phone.number"), 2222222222L)), + contains(eq(longColumn("phoneNumbers.phone.number"), 3333333333L)))), + 28L); + } + + @Test + public void testArrayContainsSimpleOrFilter() throws Exception { + assertPredicate( + or( + contains(eq(longColumn("phoneNumbers.phone.number"), 5555555555L)), + contains(eq(longColumn("phoneNumbers.phone.number"), 2222222222L))), + 20L, + 27L, + 28L); + + assertPredicate( + or( + contains(eq(longColumn("phoneNumbers.phone.number"), 5555555555L)), + contains(eq(longColumn("phoneNumbers.phone.number"), -123L))), // Won't match + 20L); + } + + @Test + public void testArrayContainsNestedOrFilter() throws Exception { + assertPredicate( + or( + contains(eq(longColumn("phoneNumbers.phone.number"), 5555555555L)), + or( + contains(eq(longColumn("phoneNumbers.phone.number"), -10000000L)), // Won't be matched + contains(eq(longColumn("phoneNumbers.phone.number"), 2222222222L)))), + 20L, + 27L, + 28L); + } + + @Test + public void testMapContains() throws Exception { + // Test key predicate + assertPredicate(contains(eq(binaryColumn("accounts.key_value.key"), Binary.fromString("business"))), 17L, 19L); + + // Test value predicate + assertPredicate(contains(eq(doubleColumn("accounts.key_value.value"), 1000.0D)), 17L, 19L); + } + @Test public void testNameNotNull() throws Exception { BinaryColumn name = binaryColumn("name"); diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/statisticslevel/TestStatisticsFilter.java b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/statisticslevel/TestStatisticsFilter.java index 1615cf064..15d0a8ab1 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/filter2/statisticslevel/TestStatisticsFilter.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/filter2/statisticslevel/TestStatisticsFilter.java @@ -20,6 +20,7 @@ package org.apache.parquet.filter2.statisticslevel; import static org.apache.parquet.filter2.predicate.FilterApi.and; import static org.apache.parquet.filter2.predicate.FilterApi.binaryColumn; +import static org.apache.parquet.filter2.predicate.FilterApi.contains; import static org.apache.parquet.filter2.predicate.FilterApi.doubleColumn; import static org.apache.parquet.filter2.predicate.FilterApi.eq; import static org.apache.parquet.filter2.predicate.FilterApi.gt; @@ -49,6 +50,7 @@ import org.apache.parquet.column.statistics.DoubleStatistics; import org.apache.parquet.column.statistics.IntStatistics; import org.apache.parquet.filter2.predicate.FilterPredicate; import org.apache.parquet.filter2.predicate.LogicalInverseRewriter; +import org.apache.parquet.filter2.predicate.Operators; import org.apache.parquet.filter2.predicate.Operators.BinaryColumn; import org.apache.parquet.filter2.predicate.Operators.DoubleColumn; import org.apache.parquet.filter2.predicate.Operators.IntColumn; @@ -381,6 +383,38 @@ public class TestStatisticsFilter { Arrays.asList(getIntColumnMeta(statsSomeNulls, 177L), getDoubleColumnMeta(doubleStats, 177L)))); } + @Test + public void testContainsEqNonNull() { + assertTrue(canDrop(contains(eq(intColumn, 9)), columnMetas)); + assertFalse(canDrop(contains(eq(intColumn, 10)), columnMetas)); + assertFalse(canDrop(contains(eq(intColumn, 100)), columnMetas)); + assertTrue(canDrop(contains(eq(intColumn, 101)), columnMetas)); + + // drop columns of all nulls when looking for non-null value + assertTrue(canDrop(contains(eq(intColumn, 0)), nullColumnMetas)); + assertFalse(canDrop(contains(eq(intColumn, 50)), missingMinMaxColumnMetas)); + } + + @Test + public void testContainsAnd() { + Operators.Contains<Integer> yes = contains(eq(intColumn, 9)); + Operators.Contains<Double> no = contains(eq(doubleColumn, 50D)); + assertTrue(canDrop(and(yes, yes), columnMetas)); + assertTrue(canDrop(and(yes, no), columnMetas)); + assertTrue(canDrop(and(no, yes), columnMetas)); + assertFalse(canDrop(and(no, no), columnMetas)); + } + + @Test + public void testContainsOr() { + Operators.Contains<Integer> yes = contains(eq(intColumn, 9)); + Operators.Contains<Double> no = contains(eq(doubleColumn, 50D)); + assertTrue(canDrop(or(yes, yes), columnMetas)); + assertFalse(canDrop(or(yes, no), columnMetas)); + assertFalse(canDrop(or(no, yes), columnMetas)); + assertFalse(canDrop(or(no, no), columnMetas)); + } + @Test public void testAnd() { FilterPredicate yes = eq(intColumn, 9); diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestBloomFiltering.java b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestBloomFiltering.java index 29c5d6f58..651184a0d 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestBloomFiltering.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestBloomFiltering.java @@ -19,11 +19,14 @@ package org.apache.parquet.hadoop; +import static org.apache.parquet.filter2.predicate.FilterApi.and; import static org.apache.parquet.filter2.predicate.FilterApi.binaryColumn; +import static org.apache.parquet.filter2.predicate.FilterApi.contains; import static org.apache.parquet.filter2.predicate.FilterApi.doubleColumn; import static org.apache.parquet.filter2.predicate.FilterApi.eq; import static org.apache.parquet.filter2.predicate.FilterApi.in; import static org.apache.parquet.filter2.predicate.FilterApi.longColumn; +import static org.apache.parquet.filter2.predicate.FilterApi.or; import static org.apache.parquet.hadoop.ParquetFileWriter.Mode.OVERWRITE; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; @@ -40,6 +43,7 @@ import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Random; import java.util.Set; import java.util.function.Predicate; @@ -120,7 +124,8 @@ public class TestBloomFiltering { List<PhoneBookWriter.User> users = new ArrayList<>(); List<String> names = generateNames(rowCount); for (int i = 0; i < rowCount; ++i) { - users.add(new PhoneBookWriter.User(i, names.get(i), generatePhoneNumbers(), generateLocation(i, rowCount))); + users.add( + new PhoneBookWriter.User(i, names.get(i), generatePhoneNumbers(i), generateLocation(i, rowCount))); } return users; } @@ -171,12 +176,13 @@ public class TestBloomFiltering { names.add("len"); } for (int i = 0; i < rowCount; ++i) { - users.add(new PhoneBookWriter.User(i, names.get(i), generatePhoneNumbers(), generateLocation(i, rowCount))); + users.add( + new PhoneBookWriter.User(i, names.get(i), generatePhoneNumbers(i), generateLocation(i, rowCount))); } return users; } - private static List<PhoneBookWriter.PhoneNumber> generatePhoneNumbers() { + private static List<PhoneBookWriter.PhoneNumber> generatePhoneNumbers(int index) { int length = RANDOM.nextInt(5) - 1; if (length < 0) { return null; @@ -184,8 +190,8 @@ public class TestBloomFiltering { List<PhoneBookWriter.PhoneNumber> phoneNumbers = new ArrayList<>(length); for (int i = 0; i < length; ++i) { // 6 digits numbers - long number = Math.abs(RANDOM.nextLong() % 900000) + 100000; - phoneNumbers.add(new PhoneBookWriter.PhoneNumber(number, PHONE_KINDS[RANDOM.nextInt(PHONE_KINDS.length)])); + phoneNumbers.add( + new PhoneBookWriter.PhoneNumber(500L % index, PHONE_KINDS[RANDOM.nextInt(PHONE_KINDS.length)])); } return phoneNumbers; } @@ -318,12 +324,17 @@ public class TestBloomFiltering { .withBloomFilterEnabled("name", true) .withBloomFilterCandidateNumber("name", 10) .withBloomFilterEnabled("id", true) - .withBloomFilterCandidateNumber("id", 10); + .withBloomFilterCandidateNumber("id", 10) + .withDictionaryEncoding("phoneNumbers.phone.number", false) + .withBloomFilterEnabled("phoneNumbers.phone.number", true) + .withBloomFilterCandidateNumber("phoneNumbers.phone.number", 10); } else { writeBuilder .withBloomFilterNDV("location.lat", 10000L) .withBloomFilterNDV("name", 10000L) - .withBloomFilterNDV("id", 10000L); + .withBloomFilterNDV("id", 10000L) + .withDictionaryEncoding("phoneNumbers.phone.number", false) + .withBloomFilterNDV("phoneNumbers.phone.number", 10000L); } PhoneBookWriter.write(writeBuilder, DATA); } @@ -398,6 +409,38 @@ public class TestBloomFiltering { eq(doubleColumn("location.lat"), 99.9)); } + @Test + public void testContainsEqFiltering() throws IOException { + assertCorrectFiltering( + record -> Optional.ofNullable(record.getPhoneNumbers()) + .map(numbers -> numbers.stream().anyMatch(n -> n.getNumber() == 250L)) + .orElse(false), + contains(eq(longColumn("phoneNumbers.phone.number"), 250L))); + } + + @Test + public void testContainsOrFiltering() throws IOException { + assertCorrectFiltering( + record -> Optional.ofNullable(record.getPhoneNumbers()) + .map(numbers -> numbers.stream().anyMatch(n -> n.getNumber() == 250L || n.getNumber() == 50L)) + .orElse(false), + or( + contains(eq(longColumn("phoneNumbers.phone.number"), 250L)), + contains(eq(longColumn("phoneNumbers.phone.number"), 50L)))); + } + + @Test + public void testContainsAndFiltering() throws IOException { + assertCorrectFiltering( + record -> Optional.ofNullable(record.getPhoneNumbers()) + .map(numbers -> numbers.stream().anyMatch(n -> n.getNumber() == 10L) + && numbers.stream().anyMatch(n -> n.getNumber() == 5L)) + .orElse(false), + and( + contains(eq(longColumn("phoneNumbers.phone.number"), 10L)), + contains(eq(longColumn("phoneNumbers.phone.number"), 5L)))); + } + @Test public void checkBloomFilterSize() throws IOException { FileDecryptionProperties fileDecryptionProperties = getFileDecryptionProperties(); diff --git a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetWriterError.java b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetWriterError.java index 51fa90e1c..51f8a7dd6 100644 --- a/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetWriterError.java +++ b/parquet-hadoop/src/test/java/org/apache/parquet/hadoop/TestParquetWriterError.java @@ -23,7 +23,9 @@ import java.lang.reflect.Field; import java.nio.ByteBuffer; import java.nio.file.Paths; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Random; import org.apache.hadoop.conf.Configuration; import org.apache.parquet.bytes.DirectByteBufferAllocator; @@ -109,18 +111,23 @@ public class TestParquetWriterError { location = null; } List<PhoneBookWriter.PhoneNumber> phoneNumbers; + Map<String, Double> accounts; if (RANDOM.nextDouble() < .1) { phoneNumbers = null; + accounts = null; } else { int n = RANDOM.nextInt(4); phoneNumbers = new ArrayList<>(n); + accounts = new HashMap<>(); for (int i = 0; i < n; ++i) { String kind = RANDOM.nextDouble() < .1 ? null : "kind" + RANDOM.nextInt(5); phoneNumbers.add(new PhoneBookWriter.PhoneNumber(RANDOM.nextInt(), kind)); + accounts.put("Account " + i, (double) i); } } String name = RANDOM.nextDouble() < .1 ? null : "name" + RANDOM.nextLong(); - PhoneBookWriter.User user = new PhoneBookWriter.User(RANDOM.nextLong(), name, phoneNumbers, location); + PhoneBookWriter.User user = + new PhoneBookWriter.User(RANDOM.nextLong(), name, phoneNumbers, location, accounts); return PhoneBookWriter.groupFromUser(user); }