http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/main/java/org/apache/flink/runtime/util/ReusingKeyGroupedIterator.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/util/ReusingKeyGroupedIterator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/util/ReusingKeyGroupedIterator.java new file mode 100644 index 0000000..4dc9dd3 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/util/ReusingKeyGroupedIterator.java @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.util; + +import java.io.IOException; +import java.util.Iterator; +import java.util.NoSuchElementException; + +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.util.MutableObjectIterator; +import org.apache.flink.util.TraversableOnceException; + +/** + * The KeyValueIterator returns a key and all values that belong to the key (share the same key). 
+ * + */ +public final class ReusingKeyGroupedIterator<E> { + + private final MutableObjectIterator<E> iterator; + + private final TypeSerializer<E> serializer; + + private final TypeComparator<E> comparator; + + private E reuse; + + private E current; + + private E lookahead; + + private ValuesIterator valuesIterator; + + private boolean lookAheadHasNext; + + private boolean done; + + /** + * Initializes the KeyGroupedIterator. It requires an iterator which returns its result + * sorted by the key fields. + * + * @param iterator An iterator over records, which are sorted by the key fields, in any order. + * @param serializer The serializer for the data type iterated over. + * @param comparator The comparator for the data type iterated over. + */ + public ReusingKeyGroupedIterator(MutableObjectIterator<E> iterator, TypeSerializer<E> + serializer, TypeComparator<E> comparator) + { + if (iterator == null || serializer == null || comparator == null) { + throw new NullPointerException(); + } + + this.iterator = iterator; + this.serializer = serializer; + this.comparator = comparator; + this.reuse = this.serializer.createInstance(); + } + + /** + * Moves the iterator to the next key. This method may skip any values that have not yet been returned by the + * iterator created by the {@link #getValues()} method. Hence, if called multiple times it "removes" pairs. + * + * @return true if the input iterator has an other group of key-value pairs that share the same key. 
+ */ + public boolean nextKey() throws IOException + { + // first element (or empty) + if (this.current == null) { + if (this.done) { + this.valuesIterator = null; + return false; + } + this.current = this.reuse; + if ((this.current = this.iterator.next(this.current)) != null) { + this.comparator.setReference(this.current); + this.lookAheadHasNext = false; + this.valuesIterator = new ValuesIterator(); + this.valuesIterator.currentIsUnconsumed = true; + return true; + } else { + // empty input, set everything null + this.valuesIterator = null; + this.current = null; + this.done = true; + return false; + } + } + + this.valuesIterator.iteratorAvailable = true; + + // Whole value-iterator was read and a new key is available. + if (this.lookAheadHasNext) { + this.lookAheadHasNext = false; + this.current = this.lookahead; + this.lookahead = null; + this.comparator.setReference(this.current); + this.valuesIterator.currentIsUnconsumed = true; + return true; + } + + // try to move to next key. + // Required if user code / reduce() method did not read the whole value iterator. + while (true) { + if (!this.done && ((this.current = this.iterator.next(this.current)) != null)) { + if (!this.comparator.equalToReference(this.current)) { + // the keys do not match, so we have a new group. store the current keys + this.comparator.setReference(this.current); + this.lookAheadHasNext = false; + this.valuesIterator.currentIsUnconsumed = true; + return true; + } + } + else { + this.valuesIterator = null; + this.current = null; + this.done = true; + return false; + } + } + } + + public TypeComparator<E> getComparatorWithCurrentReference() { + return this.comparator; + } + + public E getCurrent() { + return this.current; + } + + /** + * Returns an iterator over all values that belong to the current key. The iterator is initially <code>null</code> + * (before the first call to {@link #nextKey()} and after all keys are consumed. 
In general, this method returns + * always a non-null value, if a previous call to {@link #nextKey()} return <code>true</code>. + * + * @return Iterator over all values that belong to the current key. + */ + public ValuesIterator getValues() { + return this.valuesIterator; + } + + // -------------------------------------------------------------------------------------------- + + public final class ValuesIterator implements Iterator<E>, Iterable<E> { + + private final TypeSerializer<E> serializer = ReusingKeyGroupedIterator.this.serializer; + private final TypeComparator<E> comparator = ReusingKeyGroupedIterator.this.comparator; + + private E staging = this.serializer.createInstance(); + private boolean currentIsUnconsumed = false; + + private boolean iteratorAvailable = true; + + private ValuesIterator() {} + + @Override + public boolean hasNext() { + if (ReusingKeyGroupedIterator.this.current == null || ReusingKeyGroupedIterator.this.lookAheadHasNext) { + return false; + } + if (this.currentIsUnconsumed) { + return true; + } + + try { + // read the next value into the staging record to make sure we keep the + // current as it is in case the key changed + E stagingStaging = ReusingKeyGroupedIterator.this.iterator.next(this.staging); + if (stagingStaging != null) { + this.staging = stagingStaging; + if (this.comparator.equalToReference(this.staging)) { + // same key, next value is in staging, so exchange staging with current + final E tmp = this.staging; + this.staging = ReusingKeyGroupedIterator.this.current; + ReusingKeyGroupedIterator.this.current = tmp; + this.currentIsUnconsumed = true; + return true; + } else { + // moved to the next key, no more values here + ReusingKeyGroupedIterator.this.lookAheadHasNext = true; + ReusingKeyGroupedIterator.this.lookahead = this.staging; + this.staging = ReusingKeyGroupedIterator.this.current; + return false; + } + } + else { + // backing iterator is consumed + ReusingKeyGroupedIterator.this.done = true; + return false; + } 
+ } + catch (IOException ioex) { + throw new RuntimeException("An error occurred while reading the next record: " + + ioex.getMessage(), ioex); + } + } + + /** + * Prior to call this method, call hasNext() once! + */ + @Override + public E next() { + if (this.currentIsUnconsumed || hasNext()) { + this.currentIsUnconsumed = false; + return ReusingKeyGroupedIterator.this.current; + } else { + throw new NoSuchElementException(); + } + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + @Override + public Iterator<E> iterator() { + if (iteratorAvailable) { + iteratorAvailable = false; + return this; + } + else { + throw new TraversableOnceException(); + } + } + } +}
http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/main/java/org/apache/flink/runtime/util/ReusingMutableToRegularIteratorWrapper.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/util/ReusingMutableToRegularIteratorWrapper.java b/flink-runtime/src/main/java/org/apache/flink/runtime/util/ReusingMutableToRegularIteratorWrapper.java new file mode 100644 index 0000000..b9e3e74 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/util/ReusingMutableToRegularIteratorWrapper.java @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.flink.runtime.util; + +import java.io.IOException; +import java.util.Iterator; +import java.util.NoSuchElementException; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.util.MutableObjectIterator; +import org.apache.flink.util.TraversableOnceException; + +/** + * This class wraps a {@link MutableObjectIterator} into a regular {@link Iterator}. + * Internally, it uses two record instances which it uses alternating. 
That way, + * whenever hasNext() returns (possibly with false), the previous obtained record is + * still valid and cannot have been overwritten internally. + */ +public class ReusingMutableToRegularIteratorWrapper<T> implements Iterator<T>, Iterable<T> { + + private final MutableObjectIterator<T> source; + + private T current, next; + + private boolean currentIsAvailable; + + private boolean iteratorAvailable = true; + + public ReusingMutableToRegularIteratorWrapper(MutableObjectIterator<T> source, + TypeSerializer<T> serializer) { + this.source = source; + this.current = serializer.createInstance(); + this.next = serializer.createInstance(); + } + + @Override + public boolean hasNext() { + if (currentIsAvailable) { + return true; + } else { + try { + // we always use two records such that whenever hasNext() returns (possibly with false), + // the previous record is always still valid. + if ((next = source.next(next)) != null) { + + T tmp = current; + current = next; + next = tmp; + + currentIsAvailable = true; + return true; + } else { + return false; + } + } catch (IOException ioex) { + throw new RuntimeException("Error reading next record: " + ioex.getMessage(), ioex); + } + } + } + + @Override + public T next() { + if (currentIsAvailable || hasNext()) { + currentIsAvailable = false; + return current; + } else { + throw new NoSuchElementException(); + } + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + + @Override + public Iterator<T> iterator() { + if (iteratorAvailable) { + iteratorAvailable = false; + return this; + } + else { + throw new TraversableOnceException(); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CachedMatchTaskTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CachedMatchTaskTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CachedMatchTaskTest.java index 262a4e9..a57287a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CachedMatchTaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CachedMatchTaskTest.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.FlatJoinFunction; import org.apache.flink.api.common.functions.RichFlatJoinFunction; import org.apache.flink.api.common.typeutils.record.RecordComparator; @@ -56,8 +57,8 @@ public class CachedMatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, private final List<Record> outList = new ArrayList<Record>(); - public CachedMatchTaskTest() { - super(HASH_MEM, 2, SORT_MEM); + public CachedMatchTaskTest(ExecutionConfig config) { + super(config, HASH_MEM, 2, SORT_MEM); } @Test http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CoGroupTaskExternalITCase.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CoGroupTaskExternalITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CoGroupTaskExternalITCase.java index 8e7d9d3..7f96954 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CoGroupTaskExternalITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CoGroupTaskExternalITCase.java @@ -18,6 +18,7 @@ package org.apache.flink.runtime.operators; +import org.apache.flink.api.common.ExecutionConfig; import org.junit.Assert; import org.apache.flink.api.common.functions.CoGroupFunction; import org.apache.flink.api.common.functions.RichCoGroupFunction; @@ -45,8 +46,8 @@ public class CoGroupTaskExternalITCase extends 
DriverTestBase<CoGroupFunction<Re private final CountingOutputCollector output = new CountingOutputCollector(); - public CoGroupTaskExternalITCase() { - super(0, 2, SORT_MEM); + public CoGroupTaskExternalITCase(ExecutionConfig config) { + super(config, 0, 2, SORT_MEM); } @Test http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CoGroupTaskTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CoGroupTaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CoGroupTaskTest.java index 46007ab..9c0f075 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CoGroupTaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CoGroupTaskTest.java @@ -20,6 +20,7 @@ package org.apache.flink.runtime.operators; import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.flink.api.common.ExecutionConfig; import org.junit.Assert; import org.apache.flink.api.common.functions.CoGroupFunction; import org.apache.flink.api.common.functions.RichCoGroupFunction; @@ -52,8 +53,8 @@ public class CoGroupTaskTest extends DriverTestBase<CoGroupFunction<Record, Reco private final CountingOutputCollector output = new CountingOutputCollector(); - public CoGroupTaskTest() { - super(0, 2, SORT_MEM); + public CoGroupTaskTest(ExecutionConfig config) { + super(config, 0, 2, SORT_MEM); } @Test http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskExternalITCase.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskExternalITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskExternalITCase.java index decf358..5e2f3ca 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskExternalITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskExternalITCase.java @@ -22,6 +22,7 @@ package org.apache.flink.runtime.operators; import java.util.ArrayList; import java.util.HashMap; +import org.apache.flink.api.common.ExecutionConfig; import org.junit.Assert; import org.apache.flink.api.common.functions.RichGroupReduceFunction; import org.apache.flink.api.common.typeutils.record.RecordComparator; @@ -46,8 +47,8 @@ public class CombineTaskExternalITCase extends DriverTestBase<RichGroupReduceFun private final RecordComparator comparator = new RecordComparator( new int[]{0}, (Class<? extends Key<?>>[])new Class[]{ IntValue.class }); - public CombineTaskExternalITCase() { - super(COMBINE_MEM, 0); + public CombineTaskExternalITCase(ExecutionConfig config) { + super(config, COMBINE_MEM, 0); combine_frac = (double)COMBINE_MEM/this.getMemoryManager().getMemorySize(); } http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskTest.java index 1f18917..515ce76 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskTest.java @@ -21,6 +21,7 @@ package org.apache.flink.runtime.operators; import java.util.ArrayList; import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.flink.api.common.ExecutionConfig; import org.junit.Assert; import org.apache.flink.api.common.functions.RichGroupReduceFunction; import org.apache.flink.api.common.typeutils.record.RecordComparator; @@ -49,8 +50,8 @@ public class 
CombineTaskTest extends DriverTestBase<RichGroupReduceFunction<Reco private final RecordComparator comparator = new RecordComparator( new int[]{0}, (Class<? extends Key<?>>[])new Class[]{ IntValue.class }); - public CombineTaskTest() { - super(COMBINE_MEM, 0); + public CombineTaskTest(ExecutionConfig config) { + super(config, COMBINE_MEM, 0); combine_frac = (double)COMBINE_MEM/this.getMemoryManager().getMemorySize(); } http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CrossTaskExternalITCase.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CrossTaskExternalITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CrossTaskExternalITCase.java index 1924489..b678ed1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CrossTaskExternalITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CrossTaskExternalITCase.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.operators; +import org.apache.flink.api.common.ExecutionConfig; import org.junit.Assert; import org.apache.flink.api.common.functions.CrossFunction; @@ -37,8 +38,8 @@ public class CrossTaskExternalITCase extends DriverTestBase<CrossFunction<Record private final CountingOutputCollector output = new CountingOutputCollector(); - public CrossTaskExternalITCase() { - super(CROSS_MEM, 0); + public CrossTaskExternalITCase(ExecutionConfig config) { + super(config, CROSS_MEM, 0); cross_frac = (double)CROSS_MEM/this.getMemoryManager().getMemorySize(); } http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CrossTaskTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CrossTaskTest.java 
b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CrossTaskTest.java index bca6896..4c27a68 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CrossTaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CrossTaskTest.java @@ -20,6 +20,7 @@ package org.apache.flink.runtime.operators; import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.flink.api.common.ExecutionConfig; import org.junit.Assert; import org.apache.flink.api.common.functions.CrossFunction; import org.apache.flink.runtime.operators.testutils.DelayingInfinitiveInputIterator; @@ -39,8 +40,8 @@ public class CrossTaskTest extends DriverTestBase<CrossFunction<Record, Record, private final CountingOutputCollector output = new CountingOutputCollector(); - public CrossTaskTest() { - super(CROSS_MEM, 0); + public CrossTaskTest(ExecutionConfig config) { + super(config, CROSS_MEM, 0); cross_frac = (double)CROSS_MEM/this.getMemoryManager().getMemorySize(); } http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/test/java/org/apache/flink/runtime/operators/MapTaskTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/MapTaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/MapTaskTest.java index 4997777..bfc6c44 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/MapTaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/MapTaskTest.java @@ -21,6 +21,7 @@ package org.apache.flink.runtime.operators; import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.flink.api.common.ExecutionConfig; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.flink.api.common.functions.GenericCollectorMap; @@ -45,8 +46,8 @@ public class MapTaskTest extends DriverTestBase<GenericCollectorMap<Record, Reco private final 
CountingOutputCollector output = new CountingOutputCollector(); - public MapTaskTest() { - super(0, 0); + public MapTaskTest(ExecutionConfig config) { + super(config, 0, 0); } @Test http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/test/java/org/apache/flink/runtime/operators/MatchTaskExternalITCase.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/MatchTaskExternalITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/MatchTaskExternalITCase.java index 1e0e882..29be8f8 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/MatchTaskExternalITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/MatchTaskExternalITCase.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.operators; +import org.apache.flink.api.common.ExecutionConfig; import org.junit.Assert; import org.apache.flink.api.common.functions.FlatJoinFunction; import org.apache.flink.api.common.typeutils.record.RecordComparator; @@ -55,8 +56,8 @@ public class MatchTaskExternalITCase extends DriverTestBase<FlatJoinFunction<Rec private final CountingOutputCollector output = new CountingOutputCollector(); - public MatchTaskExternalITCase() { - super(HASH_MEM, 2, SORT_MEM); + public MatchTaskExternalITCase(ExecutionConfig config) { + super(config, HASH_MEM, 2, SORT_MEM); bnljn_frac = (double)BNLJN_MEM/this.getMemoryManager().getMemorySize(); hash_frac = (double)HASH_MEM/this.getMemoryManager().getMemorySize(); } http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/test/java/org/apache/flink/runtime/operators/MatchTaskTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/MatchTaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/MatchTaskTest.java index 
584bc02..553212a 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/MatchTaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/MatchTaskTest.java @@ -23,6 +23,7 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.FlatJoinFunction; import org.apache.flink.api.common.typeutils.record.RecordComparator; import org.apache.flink.api.common.typeutils.record.RecordPairComparatorFactory; @@ -66,8 +67,8 @@ public class MatchTaskTest extends DriverTestBase<FlatJoinFunction<Record, Recor private final List<Record> outList = new ArrayList<Record>(); - public MatchTaskTest() { - super(HASH_MEM, NUM_SORTER, SORT_MEM); + public MatchTaskTest(ExecutionConfig config) { + super(config, HASH_MEM, NUM_SORTER, SORT_MEM); bnljn_frac = (double)BNLJN_MEM/this.getMemoryManager().getMemorySize(); hash_frac = (double)HASH_MEM/this.getMemoryManager().getMemorySize(); } http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskExternalITCase.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskExternalITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskExternalITCase.java index eb26f5b..d83e92e 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskExternalITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskExternalITCase.java @@ -21,6 +21,7 @@ package org.apache.flink.runtime.operators; import java.util.ArrayList; import java.util.List; +import org.apache.flink.api.common.ExecutionConfig; import org.junit.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -48,8 +49,8 @@ public class 
ReduceTaskExternalITCase extends DriverTestBase<RichGroupReduceFunc private final List<Record> outList = new ArrayList<Record>(); - public ReduceTaskExternalITCase() { - super(0, 1, 3*1024*1024); + public ReduceTaskExternalITCase(ExecutionConfig config) { + super(config, 0, 1, 3*1024*1024); } http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskTest.java index c5a6762..964f646 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/ReduceTaskTest.java @@ -22,6 +22,7 @@ import java.util.ArrayList; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.flink.api.common.ExecutionConfig; import org.junit.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -52,8 +53,8 @@ public class ReduceTaskTest extends DriverTestBase<RichGroupReduceFunction<Recor private final List<Record> outList = new ArrayList<Record>(); - public ReduceTaskTest() { - super(0, 1, 3*1024*1024); + public ReduceTaskTest(ExecutionConfig config) { + super(config, 0, 1, 3*1024*1024); } @Test http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/TestTaskContext.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/TestTaskContext.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/TestTaskContext.java index e5ece3f..b614709 100644 --- 
a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/TestTaskContext.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/drivers/TestTaskContext.java @@ -19,6 +19,7 @@ package org.apache.flink.runtime.operators.drivers; +import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.typeutils.TypeComparator; import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.TypeSerializerFactory; @@ -59,6 +60,8 @@ public class TestTaskContext<S, T> implements PactTaskContext<S, T> { private MemoryManager memoryManager; + private ExecutionConfig executionConfig = new ExecutionConfig(); + // -------------------------------------------------------------------------------------------- // Constructors // -------------------------------------------------------------------------------------------- @@ -133,6 +136,11 @@ public class TestTaskContext<S, T> implements PactTaskContext<S, T> { } @Override + public ExecutionConfig getExecutionConfig() { + return executionConfig; + } + + @Override public ClassLoader getUserCodeClassLoader() { return getClass().getClassLoader(); } http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/test/java/org/apache/flink/runtime/operators/hash/HashMatchIteratorITCase.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/hash/HashMatchIteratorITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/hash/HashMatchIteratorITCase.java deleted file mode 100644 index 21e686d..0000000 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/hash/HashMatchIteratorITCase.java +++ /dev/null @@ -1,778 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -package org.apache.flink.runtime.operators.hash; - -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; - -import org.apache.flink.api.common.functions.AbstractRichFunction; -import org.apache.flink.api.common.functions.FlatJoinFunction; -import org.apache.flink.api.common.typeutils.TypeComparator; -import org.apache.flink.api.common.typeutils.TypePairComparator; -import org.apache.flink.api.common.typeutils.TypeSerializer; -import org.apache.flink.api.common.typeutils.record.RecordComparator; -import org.apache.flink.api.common.typeutils.record.RecordPairComparator; -import org.apache.flink.api.common.typeutils.record.RecordSerializer; -import org.apache.flink.api.java.record.functions.JoinFunction; -import org.apache.flink.runtime.io.disk.iomanager.IOManager; -import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; -import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; -import org.apache.flink.runtime.memorymanager.DefaultMemoryManager; -import org.apache.flink.runtime.memorymanager.MemoryManager; -import org.apache.flink.runtime.operators.testutils.DiscardingOutputCollector; -import 
org.apache.flink.runtime.operators.testutils.DummyInvokable; -import org.apache.flink.runtime.operators.testutils.TestData; -import org.apache.flink.runtime.operators.testutils.UniformIntPairGenerator; -import org.apache.flink.runtime.operators.testutils.UnionIterator; -import org.apache.flink.runtime.operators.testutils.TestData.Generator; -import org.apache.flink.runtime.operators.testutils.TestData.Generator.KeyMode; -import org.apache.flink.runtime.operators.testutils.TestData.Generator.ValueMode; -import org.apache.flink.runtime.operators.testutils.types.IntPair; -import org.apache.flink.runtime.operators.testutils.types.IntPairComparator; -import org.apache.flink.runtime.operators.testutils.types.IntPairSerializer; -import org.apache.flink.types.IntValue; -import org.apache.flink.types.NullKeyFieldException; -import org.apache.flink.types.Record; -import org.apache.flink.types.Value; -import org.apache.flink.util.Collector; -import org.apache.flink.util.MutableObjectIterator; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -@SuppressWarnings({"serial", "deprecation"}) -public class HashMatchIteratorITCase { - - private static final int MEMORY_SIZE = 16000000; // total memory - - private static final int INPUT_1_SIZE = 20000; - private static final int INPUT_2_SIZE = 1000; - - private static final long SEED1 = 561349061987311L; - private static final long SEED2 = 231434613412342L; - - private final AbstractInvokable parentTask = new DummyInvokable(); - - private IOManager ioManager; - private MemoryManager memoryManager; - - private TypeSerializer<Record> recordSerializer; - private TypeComparator<Record> record1Comparator; - private TypeComparator<Record> record2Comparator; - private TypePairComparator<Record, Record> recordPairComparator; - - private TypeSerializer<IntPair> pairSerializer; - private TypeComparator<IntPair> pairComparator; - private TypePairComparator<IntPair, Record> 
pairRecordPairComparator; - private TypePairComparator<Record, IntPair> recordPairPairComparator; - - - @SuppressWarnings("unchecked") - @Before - public void beforeTest() { - this.recordSerializer = RecordSerializer.get(); - - this.record1Comparator = new RecordComparator(new int[] {0}, new Class[] {TestData.Key.class}); - this.record2Comparator = new RecordComparator(new int[] {0}, new Class[] {TestData.Key.class}); - - this.recordPairComparator = new RecordPairComparator(new int[] {0}, new int[] {0}, new Class[] {TestData.Key.class}); - - this.pairSerializer = new IntPairSerializer(); - this.pairComparator = new IntPairComparator(); - this.pairRecordPairComparator = new IntPairRecordPairComparator(); - this.recordPairPairComparator = new RecordIntPairPairComparator(); - - this.memoryManager = new DefaultMemoryManager(MEMORY_SIZE, 1); - this.ioManager = new IOManagerAsync(); - } - - @After - public void afterTest() { - if (this.ioManager != null) { - this.ioManager.shutdown(); - if (!this.ioManager.isProperlyShutDown()) { - Assert.fail("I/O manager failed to properly shut down."); - } - this.ioManager = null; - } - - if (this.memoryManager != null) { - Assert.assertTrue("Memory Leak: Not all memory has been returned to the memory manager.", - this.memoryManager.verifyEmpty()); - this.memoryManager.shutdown(); - this.memoryManager = null; - } - } - - - @Test - public void testBuildFirst() { - try { - Generator generator1 = new Generator(SEED1, 500, 4096, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); - Generator generator2 = new Generator(SEED2, 500, 2048, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); - - final TestData.GeneratorIterator input1 = new TestData.GeneratorIterator(generator1, INPUT_1_SIZE); - final TestData.GeneratorIterator input2 = new TestData.GeneratorIterator(generator2, INPUT_2_SIZE); - - // collect expected data - final Map<TestData.Key, Collection<RecordMatch>> expectedMatchesMap = matchRecordValues( - collectRecordData(input1), - 
collectRecordData(input2)); - - final JoinFunction matcher = new RecordMatchRemovingJoin(expectedMatchesMap); - final Collector<Record> collector = new DiscardingOutputCollector<Record>(); - - // reset the generators - generator1.reset(); - generator2.reset(); - input1.reset(); - input2.reset(); - - // compare with iterator values - BuildFirstHashMatchIterator<Record, Record, Record> iterator = - new BuildFirstHashMatchIterator<Record, Record, Record>( - input1, input2, this.recordSerializer, this.record1Comparator, - this.recordSerializer, this.record2Comparator, this.recordPairComparator, - this.memoryManager, ioManager, this.parentTask, 1.0); - - iterator.open(); - - while (iterator.callWithNextKey(matcher, collector)); - - iterator.close(); - - // assert that each expected match was seen - for (Entry<TestData.Key, Collection<RecordMatch>> entry : expectedMatchesMap.entrySet()) { - if (!entry.getValue().isEmpty()) { - Assert.fail("Collection for key " + entry.getKey() + " is not empty"); - } - } - } - catch (Exception e) { - e.printStackTrace(); - Assert.fail("An exception occurred during the test: " + e.getMessage()); - } - } - - @Test - public void testBuildFirstWithHighNumberOfCommonKeys() - { - // the size of the left and right inputs - final int INPUT_1_SIZE = 200; - final int INPUT_2_SIZE = 100; - - final int INPUT_1_DUPLICATES = 10; - final int INPUT_2_DUPLICATES = 2000; - final int DUPLICATE_KEY = 13; - - try { - Generator generator1 = new Generator(SEED1, 500, 4096, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); - Generator generator2 = new Generator(SEED2, 500, 2048, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); - - final TestData.GeneratorIterator gen1Iter = new TestData.GeneratorIterator(generator1, INPUT_1_SIZE); - final TestData.GeneratorIterator gen2Iter = new TestData.GeneratorIterator(generator2, INPUT_2_SIZE); - - final TestData.ConstantValueIterator const1Iter = new TestData.ConstantValueIterator(DUPLICATE_KEY, "LEFT String for Duplicate Keys", 
INPUT_1_DUPLICATES); - final TestData.ConstantValueIterator const2Iter = new TestData.ConstantValueIterator(DUPLICATE_KEY, "RIGHT String for Duplicate Keys", INPUT_2_DUPLICATES); - - final List<MutableObjectIterator<Record>> inList1 = new ArrayList<MutableObjectIterator<Record>>(); - inList1.add(gen1Iter); - inList1.add(const1Iter); - - final List<MutableObjectIterator<Record>> inList2 = new ArrayList<MutableObjectIterator<Record>>(); - inList2.add(gen2Iter); - inList2.add(const2Iter); - - MutableObjectIterator<Record> input1 = new UnionIterator<Record>(inList1); - MutableObjectIterator<Record> input2 = new UnionIterator<Record>(inList2); - - - // collect expected data - final Map<TestData.Key, Collection<RecordMatch>> expectedMatchesMap = matchRecordValues( - collectRecordData(input1), - collectRecordData(input2)); - - // re-create the whole thing for actual processing - - // reset the generators and iterators - generator1.reset(); - generator2.reset(); - const1Iter.reset(); - const2Iter.reset(); - gen1Iter.reset(); - gen2Iter.reset(); - - inList1.clear(); - inList1.add(gen1Iter); - inList1.add(const1Iter); - - inList2.clear(); - inList2.add(gen2Iter); - inList2.add(const2Iter); - - input1 = new UnionIterator<Record>(inList1); - input2 = new UnionIterator<Record>(inList2); - - final JoinFunction matcher = new RecordMatchRemovingJoin(expectedMatchesMap); - final Collector<Record> collector = new DiscardingOutputCollector<Record>(); - - BuildFirstHashMatchIterator<Record, Record, Record> iterator = - new BuildFirstHashMatchIterator<Record, Record, Record>( - input1, input2, this.recordSerializer, this.record1Comparator, - this.recordSerializer, this.record2Comparator, this.recordPairComparator, - this.memoryManager, ioManager, this.parentTask, 1.0); - - iterator.open(); - - while (iterator.callWithNextKey(matcher, collector)); - - iterator.close(); - - // assert that each expected match was seen - for (Entry<TestData.Key, Collection<RecordMatch>> entry : 
expectedMatchesMap.entrySet()) { - if (!entry.getValue().isEmpty()) { - Assert.fail("Collection for key " + entry.getKey() + " is not empty"); - } - } - } - catch (Exception e) { - e.printStackTrace(); - Assert.fail("An exception occurred during the test: " + e.getMessage()); - } - } - - @Test - public void testBuildSecond() { - try { - Generator generator1 = new Generator(SEED1, 500, 4096, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); - Generator generator2 = new Generator(SEED2, 500, 2048, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); - - final TestData.GeneratorIterator input1 = new TestData.GeneratorIterator(generator1, INPUT_1_SIZE); - final TestData.GeneratorIterator input2 = new TestData.GeneratorIterator(generator2, INPUT_2_SIZE); - - // collect expected data - final Map<TestData.Key, Collection<RecordMatch>> expectedMatchesMap = matchRecordValues( - collectRecordData(input1), - collectRecordData(input2)); - - final JoinFunction matcher = new RecordMatchRemovingJoin(expectedMatchesMap); - final Collector<Record> collector = new DiscardingOutputCollector<Record>(); - - // reset the generators - generator1.reset(); - generator2.reset(); - input1.reset(); - input2.reset(); - - // compare with iterator values - BuildSecondHashMatchIterator<Record, Record, Record> iterator = - new BuildSecondHashMatchIterator<Record, Record, Record>( - input1, input2, this.recordSerializer, this.record1Comparator, - this.recordSerializer, this.record2Comparator, this.recordPairComparator, - this.memoryManager, ioManager, this.parentTask, 1.0); - - iterator.open(); - - while (iterator.callWithNextKey(matcher, collector)); - - iterator.close(); - - // assert that each expected match was seen - for (Entry<TestData.Key, Collection<RecordMatch>> entry : expectedMatchesMap.entrySet()) { - if (!entry.getValue().isEmpty()) { - Assert.fail("Collection for key " + entry.getKey() + " is not empty"); - } - } - } - catch (Exception e) { - e.printStackTrace(); - Assert.fail("An exception occurred 
during the test: " + e.getMessage()); - } - } - - @Test - public void testBuildSecondWithHighNumberOfCommonKeys() - { - // the size of the left and right inputs - final int INPUT_1_SIZE = 200; - final int INPUT_2_SIZE = 100; - - final int INPUT_1_DUPLICATES = 10; - final int INPUT_2_DUPLICATES = 2000; - final int DUPLICATE_KEY = 13; - - try { - Generator generator1 = new Generator(SEED1, 500, 4096, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); - Generator generator2 = new Generator(SEED2, 500, 2048, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); - - final TestData.GeneratorIterator gen1Iter = new TestData.GeneratorIterator(generator1, INPUT_1_SIZE); - final TestData.GeneratorIterator gen2Iter = new TestData.GeneratorIterator(generator2, INPUT_2_SIZE); - - final TestData.ConstantValueIterator const1Iter = new TestData.ConstantValueIterator(DUPLICATE_KEY, "LEFT String for Duplicate Keys", INPUT_1_DUPLICATES); - final TestData.ConstantValueIterator const2Iter = new TestData.ConstantValueIterator(DUPLICATE_KEY, "RIGHT String for Duplicate Keys", INPUT_2_DUPLICATES); - - final List<MutableObjectIterator<Record>> inList1 = new ArrayList<MutableObjectIterator<Record>>(); - inList1.add(gen1Iter); - inList1.add(const1Iter); - - final List<MutableObjectIterator<Record>> inList2 = new ArrayList<MutableObjectIterator<Record>>(); - inList2.add(gen2Iter); - inList2.add(const2Iter); - - MutableObjectIterator<Record> input1 = new UnionIterator<Record>(inList1); - MutableObjectIterator<Record> input2 = new UnionIterator<Record>(inList2); - - - // collect expected data - final Map<TestData.Key, Collection<RecordMatch>> expectedMatchesMap = matchRecordValues( - collectRecordData(input1), - collectRecordData(input2)); - - // re-create the whole thing for actual processing - - // reset the generators and iterators - generator1.reset(); - generator2.reset(); - const1Iter.reset(); - const2Iter.reset(); - gen1Iter.reset(); - gen2Iter.reset(); - - inList1.clear(); - inList1.add(gen1Iter); - 
inList1.add(const1Iter); - - inList2.clear(); - inList2.add(gen2Iter); - inList2.add(const2Iter); - - input1 = new UnionIterator<Record>(inList1); - input2 = new UnionIterator<Record>(inList2); - - final JoinFunction matcher = new RecordMatchRemovingJoin(expectedMatchesMap); - final Collector<Record> collector = new DiscardingOutputCollector<Record>(); - - BuildSecondHashMatchIterator<Record, Record, Record> iterator = - new BuildSecondHashMatchIterator<Record, Record, Record>( - input1, input2, this.recordSerializer, this.record1Comparator, - this.recordSerializer, this.record2Comparator, this.recordPairComparator, - this.memoryManager, ioManager, this.parentTask, 1.0); - - iterator.open(); - - while (iterator.callWithNextKey(matcher, collector)); - - iterator.close(); - - // assert that each expected match was seen - for (Entry<TestData.Key, Collection<RecordMatch>> entry : expectedMatchesMap.entrySet()) { - if (!entry.getValue().isEmpty()) { - Assert.fail("Collection for key " + entry.getKey() + " is not empty"); - } - } - } - catch (Exception e) { - e.printStackTrace(); - Assert.fail("An exception occurred during the test: " + e.getMessage()); - } - } - - @Test - public void testBuildFirstWithMixedDataTypes() { - try { - MutableObjectIterator<IntPair> input1 = new UniformIntPairGenerator(500, 40, false); - - final Generator generator2 = new Generator(SEED2, 500, 2048, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); - final TestData.GeneratorIterator input2 = new TestData.GeneratorIterator(generator2, INPUT_2_SIZE); - - // collect expected data - final Map<TestData.Key, Collection<RecordIntPairMatch>> expectedMatchesMap = matchRecordIntPairValues( - collectIntPairData(input1), - collectRecordData(input2)); - - final FlatJoinFunction<IntPair, Record, Record> matcher = new RecordIntPairMatchRemovingMatcher(expectedMatchesMap); - final Collector<Record> collector = new DiscardingOutputCollector<Record>(); - - // reset the generators - input1 = new 
UniformIntPairGenerator(500, 40, false); - generator2.reset(); - input2.reset(); - - // compare with iterator values - BuildSecondHashMatchIterator<IntPair, Record, Record> iterator = - new BuildSecondHashMatchIterator<IntPair, Record, Record>( - input1, input2, this.pairSerializer, this.pairComparator, - this.recordSerializer, this.record2Comparator, this.pairRecordPairComparator, - this.memoryManager, this.ioManager, this.parentTask, 1.0); - - iterator.open(); - - while (iterator.callWithNextKey(matcher, collector)); - - iterator.close(); - - // assert that each expected match was seen - for (Entry<TestData.Key, Collection<RecordIntPairMatch>> entry : expectedMatchesMap.entrySet()) { - if (!entry.getValue().isEmpty()) { - Assert.fail("Collection for key " + entry.getKey() + " is not empty"); - } - } - } - catch (Exception e) { - e.printStackTrace(); - Assert.fail("An exception occurred during the test: " + e.getMessage()); - } - } - - @Test - public void testBuildSecondWithMixedDataTypes() { - try { - MutableObjectIterator<IntPair> input1 = new UniformIntPairGenerator(500, 40, false); - - final Generator generator2 = new Generator(SEED2, 500, 2048, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); - final TestData.GeneratorIterator input2 = new TestData.GeneratorIterator(generator2, INPUT_2_SIZE); - - // collect expected data - final Map<TestData.Key, Collection<RecordIntPairMatch>> expectedMatchesMap = matchRecordIntPairValues( - collectIntPairData(input1), - collectRecordData(input2)); - - final FlatJoinFunction<IntPair, Record, Record> matcher = new RecordIntPairMatchRemovingMatcher(expectedMatchesMap); - final Collector<Record> collector = new DiscardingOutputCollector<Record>(); - - // reset the generators - input1 = new UniformIntPairGenerator(500, 40, false); - generator2.reset(); - input2.reset(); - - // compare with iterator values - BuildFirstHashMatchIterator<IntPair, Record, Record> iterator = - new BuildFirstHashMatchIterator<IntPair, Record, Record>( - 
input1, input2, this.pairSerializer, this.pairComparator, - this.recordSerializer, this.record2Comparator, this.recordPairPairComparator, - this.memoryManager, this.ioManager, this.parentTask, 1.0); - - iterator.open(); - - while (iterator.callWithNextKey(matcher, collector)); - - iterator.close(); - - // assert that each expected match was seen - for (Entry<TestData.Key, Collection<RecordIntPairMatch>> entry : expectedMatchesMap.entrySet()) { - if (!entry.getValue().isEmpty()) { - Assert.fail("Collection for key " + entry.getKey() + " is not empty"); - } - } - } - catch (Exception e) { - e.printStackTrace(); - Assert.fail("An exception occurred during the test: " + e.getMessage()); - } - } - - // -------------------------------------------------------------------------------------------- - // Utilities - // -------------------------------------------------------------------------------------------- - - - - static Map<TestData.Key, Collection<RecordMatch>> matchRecordValues( - Map<TestData.Key, Collection<TestData.Value>> leftMap, - Map<TestData.Key, Collection<TestData.Value>> rightMap) - { - Map<TestData.Key, Collection<RecordMatch>> map = new HashMap<TestData.Key, Collection<RecordMatch>>(); - - for (TestData.Key key : leftMap.keySet()) { - Collection<TestData.Value> leftValues = leftMap.get(key); - Collection<TestData.Value> rightValues = rightMap.get(key); - - if (rightValues == null) { - continue; - } - - if (!map.containsKey(key)) { - map.put(key, new ArrayList<RecordMatch>()); - } - - Collection<RecordMatch> matchedValues = map.get(key); - - for (TestData.Value leftValue : leftValues) { - for (TestData.Value rightValue : rightValues) { - matchedValues.add(new RecordMatch(leftValue, rightValue)); - } - } - } - - return map; - } - - static Map<TestData.Key, Collection<RecordIntPairMatch>> matchRecordIntPairValues( - Map<Integer, Collection<Integer>> leftMap, - Map<TestData.Key, Collection<TestData.Value>> rightMap) - { - final Map<TestData.Key, 
Collection<RecordIntPairMatch>> map = new HashMap<TestData.Key, Collection<RecordIntPairMatch>>(); - - for (Integer i : leftMap.keySet()) { - - final TestData.Key key = new TestData.Key(i.intValue()); - - final Collection<Integer> leftValues = leftMap.get(i); - final Collection<TestData.Value> rightValues = rightMap.get(key); - - if (rightValues == null) { - continue; - } - - if (!map.containsKey(key)) { - map.put(key, new ArrayList<RecordIntPairMatch>()); - } - - final Collection<RecordIntPairMatch> matchedValues = map.get(key); - - for (Integer v : leftValues) { - for (TestData.Value val : rightValues) { - matchedValues.add(new RecordIntPairMatch(v, val)); - } - } - } - - return map; - } - - - static Map<TestData.Key, Collection<TestData.Value>> collectRecordData(MutableObjectIterator<Record> iter) - throws Exception - { - Map<TestData.Key, Collection<TestData.Value>> map = new HashMap<TestData.Key, Collection<TestData.Value>>(); - Record pair = new Record(); - - while ((pair = iter.next(pair)) != null) { - - TestData.Key key = pair.getField(0, TestData.Key.class); - if (!map.containsKey(key)) { - map.put(new TestData.Key(key.getKey()), new ArrayList<TestData.Value>()); - } - - Collection<TestData.Value> values = map.get(key); - values.add(new TestData.Value(pair.getField(1, TestData.Value.class).getValue())); - } - - return map; - } - - static Map<Integer, Collection<Integer>> collectIntPairData(MutableObjectIterator<IntPair> iter) - throws Exception - { - Map<Integer, Collection<Integer>> map = new HashMap<Integer, Collection<Integer>>(); - IntPair pair = new IntPair(); - - while ((pair = iter.next(pair)) != null) { - - final int key = pair.getKey(); - final int value = pair.getValue(); - if (!map.containsKey(key)) { - map.put(key, new ArrayList<Integer>()); - } - - Collection<Integer> values = map.get(key); - values.add(value); - } - - return map; - } - - /** - * Private class used for storage of the expected matches in a hash-map. 
- */ - static class RecordMatch { - - private final Value left; - private final Value right; - - public RecordMatch(Value left, Value right) { - this.left = left; - this.right = right; - } - - @Override - public boolean equals(Object obj) { - RecordMatch o = (RecordMatch) obj; - return this.left.equals(o.left) && this.right.equals(o.right); - } - - @Override - public int hashCode() { - return this.left.hashCode() ^ this.right.hashCode(); - } - - @Override - public String toString() { - return left + ", " + right; - } - } - - /** - * Private class used for storage of the expected matches in a hash-map. - */ - static class RecordIntPairMatch - { - private final int left; - private final Value right; - - public RecordIntPairMatch(int left, Value right) { - this.left = left; - this.right = right; - } - - @Override - public boolean equals(Object obj) { - RecordIntPairMatch o = (RecordIntPairMatch) obj; - return this.left == o.left && this.right.equals(o.right); - } - - @Override - public int hashCode() { - return this.left ^ this.right.hashCode(); - } - - @Override - public String toString() { - return left + ", " + right; - } - } - - static final class RecordMatchRemovingJoin extends JoinFunction - { - private final Map<TestData.Key, Collection<RecordMatch>> toRemoveFrom; - - protected RecordMatchRemovingJoin(Map<TestData.Key, Collection<RecordMatch>> map) { - this.toRemoveFrom = map; - } - - @Override - public void join(Record rec1, Record rec2, Collector<Record> out) throws Exception - { - TestData.Key key = rec1.getField(0, TestData.Key.class); - TestData.Value value1 = rec1.getField(1, TestData.Value.class); - TestData.Value value2 = rec2.getField(1, TestData.Value.class); - //System.err.println("rec1 key = "+key+" rec2 key= "+rec2.getField(0, TestData.Key.class)); - Collection<RecordMatch> matches = this.toRemoveFrom.get(key); - if (matches == null) { - Assert.fail("Match " + key + " - " + value1 + ":" + value2 + " is unexpected."); - } - - 
Assert.assertTrue("Produced match was not contained: " + key + " - " + value1 + ":" + value2, - matches.remove(new RecordMatch(value1, value2))); - - if (matches.isEmpty()) { - this.toRemoveFrom.remove(key); - } - } - } - - static final class RecordIntPairMatchRemovingMatcher extends AbstractRichFunction implements FlatJoinFunction<IntPair, Record, Record> - { - private final Map<TestData.Key, Collection<RecordIntPairMatch>> toRemoveFrom; - - protected RecordIntPairMatchRemovingMatcher(Map<TestData.Key, Collection<RecordIntPairMatch>> map) { - this.toRemoveFrom = map; - } - - @Override - public void join(IntPair rec1, Record rec2, Collector<Record> out) throws Exception - { - final int k = rec1.getKey(); - final int v = rec1.getValue(); - - final TestData.Key key = rec2.getField(0, TestData.Key.class); - final TestData.Value value = rec2.getField(1, TestData.Value.class); - - Assert.assertTrue("Key does not match for matching IntPair Record combination.", k == key.getKey()); - - Collection<RecordIntPairMatch> matches = this.toRemoveFrom.get(key); - if (matches == null) { - Assert.fail("Match " + key + " - " + v + ":" + value + " is unexpected."); - } - - Assert.assertTrue("Produced match was not contained: " + key + " - " + v + ":" + value, - matches.remove(new RecordIntPairMatch(v, value))); - - if (matches.isEmpty()) { - this.toRemoveFrom.remove(key); - } - } - } - - static final class IntPairRecordPairComparator extends TypePairComparator<IntPair, Record> - { - private int reference; - - @Override - public void setReference(IntPair reference) { - this.reference = reference.getKey(); - } - - @Override - public boolean equalToReference(Record candidate) { - try { - final IntValue i = candidate.getField(0, IntValue.class); - return i.getValue() == this.reference; - } catch (NullPointerException npex) { - throw new NullKeyFieldException(); - } - } - - @Override - public int compareToReference(Record candidate) { - try { - final IntValue i = candidate.getField(0, 
IntValue.class); - return i.getValue() - this.reference; - } catch (NullPointerException npex) { - throw new NullKeyFieldException(); - } - } - } - - static final class RecordIntPairPairComparator extends TypePairComparator<Record, IntPair> - { - private int reference; - - @Override - public void setReference(Record reference) { - this.reference = reference.getField(0, IntValue.class).getValue(); - } - - @Override - public boolean equalToReference(IntPair candidate) { - return this.reference == candidate.getKey(); - } - - @Override - public int compareToReference(IntPair candidate) { - return candidate.getKey() - this.reference; - } - } -} http://git-wip-us.apache.org/repos/asf/flink/blob/b7b32a05/flink-runtime/src/test/java/org/apache/flink/runtime/operators/hash/NonReusingHashMatchIteratorITCase.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/hash/NonReusingHashMatchIteratorITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/hash/NonReusingHashMatchIteratorITCase.java new file mode 100644 index 0000000..f4d2251 --- /dev/null +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/hash/NonReusingHashMatchIteratorITCase.java @@ -0,0 +1,778 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.flink.runtime.operators.hash; + +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.functions.FlatJoinFunction; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypePairComparator; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.record.RecordComparator; +import org.apache.flink.api.common.typeutils.record.RecordPairComparator; +import org.apache.flink.api.common.typeutils.record.RecordSerializer; +import org.apache.flink.api.java.record.functions.JoinFunction; +import org.apache.flink.runtime.io.disk.iomanager.IOManager; +import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync; +import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; +import org.apache.flink.runtime.memorymanager.DefaultMemoryManager; +import org.apache.flink.runtime.memorymanager.MemoryManager; +import org.apache.flink.runtime.operators.testutils.DiscardingOutputCollector; +import org.apache.flink.runtime.operators.testutils.DummyInvokable; +import org.apache.flink.runtime.operators.testutils.TestData; +import org.apache.flink.runtime.operators.testutils.TestData.Generator; +import org.apache.flink.runtime.operators.testutils.TestData.Generator.KeyMode; +import org.apache.flink.runtime.operators.testutils.TestData.Generator.ValueMode; +import org.apache.flink.runtime.operators.testutils.UniformIntPairGenerator; +import 
org.apache.flink.runtime.operators.testutils.UnionIterator; +import org.apache.flink.runtime.operators.testutils.types.IntPair; +import org.apache.flink.runtime.operators.testutils.types.IntPairComparator; +import org.apache.flink.runtime.operators.testutils.types.IntPairSerializer; +import org.apache.flink.types.IntValue; +import org.apache.flink.types.NullKeyFieldException; +import org.apache.flink.types.Record; +import org.apache.flink.types.Value; +import org.apache.flink.util.Collector; +import org.apache.flink.util.MutableObjectIterator; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +@SuppressWarnings({"serial", "deprecation"}) +public class NonReusingHashMatchIteratorITCase { + + private static final int MEMORY_SIZE = 16000000; // total memory + + private static final int INPUT_1_SIZE = 20000; + private static final int INPUT_2_SIZE = 1000; + + private static final long SEED1 = 561349061987311L; + private static final long SEED2 = 231434613412342L; + + private final AbstractInvokable parentTask = new DummyInvokable(); + + private IOManager ioManager; + private MemoryManager memoryManager; + + private TypeSerializer<Record> recordSerializer; + private TypeComparator<Record> record1Comparator; + private TypeComparator<Record> record2Comparator; + private TypePairComparator<Record, Record> recordPairComparator; + + private TypeSerializer<IntPair> pairSerializer; + private TypeComparator<IntPair> pairComparator; + private TypePairComparator<IntPair, Record> pairRecordPairComparator; + private TypePairComparator<Record, IntPair> recordPairPairComparator; + + + @SuppressWarnings("unchecked") + @Before + public void beforeTest() { + this.recordSerializer = RecordSerializer.get(); + + this.record1Comparator = new RecordComparator(new int[] {0}, new 
Class[] {TestData.Key.class}); + this.record2Comparator = new RecordComparator(new int[] {0}, new Class[] {TestData.Key.class}); + + this.recordPairComparator = new RecordPairComparator(new int[] {0}, new int[] {0}, new Class[] {TestData.Key.class}); + + this.pairSerializer = new IntPairSerializer(); + this.pairComparator = new IntPairComparator(); + this.pairRecordPairComparator = new IntPairRecordPairComparator(); + this.recordPairPairComparator = new RecordIntPairPairComparator(); + + this.memoryManager = new DefaultMemoryManager(MEMORY_SIZE, 1); + this.ioManager = new IOManagerAsync(); + } + + @After + public void afterTest() { + if (this.ioManager != null) { + this.ioManager.shutdown(); + if (!this.ioManager.isProperlyShutDown()) { + Assert.fail("I/O manager failed to properly shut down."); + } + this.ioManager = null; + } + + if (this.memoryManager != null) { + Assert.assertTrue("Memory Leak: Not all memory has been returned to the memory manager.", + this.memoryManager.verifyEmpty()); + this.memoryManager.shutdown(); + this.memoryManager = null; + } + } + + + @Test + public void testBuildFirst() { + try { + Generator generator1 = new Generator(SEED1, 500, 4096, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); + Generator generator2 = new Generator(SEED2, 500, 2048, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); + + final TestData.GeneratorIterator input1 = new TestData.GeneratorIterator(generator1, INPUT_1_SIZE); + final TestData.GeneratorIterator input2 = new TestData.GeneratorIterator(generator2, INPUT_2_SIZE); + + // collect expected data + final Map<TestData.Key, Collection<RecordMatch>> expectedMatchesMap = matchRecordValues( + collectRecordData(input1), + collectRecordData(input2)); + + final JoinFunction matcher = new RecordMatchRemovingJoin(expectedMatchesMap); + final Collector<Record> collector = new DiscardingOutputCollector<Record>(); + + // reset the generators + generator1.reset(); + generator2.reset(); + input1.reset(); + input2.reset(); + + // compare 
with iterator values + NonReusingBuildFirstHashMatchIterator<Record, Record, Record> iterator = + new NonReusingBuildFirstHashMatchIterator<Record, Record, Record>( + input1, input2, this.recordSerializer, this.record1Comparator, + this.recordSerializer, this.record2Comparator, this.recordPairComparator, + this.memoryManager, ioManager, this.parentTask, 1.0); + + iterator.open(); + + while (iterator.callWithNextKey(matcher, collector)); + + iterator.close(); + + // assert that each expected match was seen + for (Entry<TestData.Key, Collection<RecordMatch>> entry : expectedMatchesMap.entrySet()) { + if (!entry.getValue().isEmpty()) { + Assert.fail("Collection for key " + entry.getKey() + " is not empty"); + } + } + } + catch (Exception e) { + e.printStackTrace(); + Assert.fail("An exception occurred during the test: " + e.getMessage()); + } + } + + @Test + public void testBuildFirstWithHighNumberOfCommonKeys() + { + // the size of the left and right inputs + final int INPUT_1_SIZE = 200; + final int INPUT_2_SIZE = 100; + + final int INPUT_1_DUPLICATES = 10; + final int INPUT_2_DUPLICATES = 2000; + final int DUPLICATE_KEY = 13; + + try { + Generator generator1 = new Generator(SEED1, 500, 4096, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); + Generator generator2 = new Generator(SEED2, 500, 2048, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); + + final TestData.GeneratorIterator gen1Iter = new TestData.GeneratorIterator(generator1, INPUT_1_SIZE); + final TestData.GeneratorIterator gen2Iter = new TestData.GeneratorIterator(generator2, INPUT_2_SIZE); + + final TestData.ConstantValueIterator const1Iter = new TestData.ConstantValueIterator(DUPLICATE_KEY, "LEFT String for Duplicate Keys", INPUT_1_DUPLICATES); + final TestData.ConstantValueIterator const2Iter = new TestData.ConstantValueIterator(DUPLICATE_KEY, "RIGHT String for Duplicate Keys", INPUT_2_DUPLICATES); + + final List<MutableObjectIterator<Record>> inList1 = new ArrayList<MutableObjectIterator<Record>>(); + 
inList1.add(gen1Iter); + inList1.add(const1Iter); + + final List<MutableObjectIterator<Record>> inList2 = new ArrayList<MutableObjectIterator<Record>>(); + inList2.add(gen2Iter); + inList2.add(const2Iter); + + MutableObjectIterator<Record> input1 = new UnionIterator<Record>(inList1); + MutableObjectIterator<Record> input2 = new UnionIterator<Record>(inList2); + + + // collect expected data + final Map<TestData.Key, Collection<RecordMatch>> expectedMatchesMap = matchRecordValues( + collectRecordData(input1), + collectRecordData(input2)); + + // re-create the whole thing for actual processing + + // reset the generators and iterators + generator1.reset(); + generator2.reset(); + const1Iter.reset(); + const2Iter.reset(); + gen1Iter.reset(); + gen2Iter.reset(); + + inList1.clear(); + inList1.add(gen1Iter); + inList1.add(const1Iter); + + inList2.clear(); + inList2.add(gen2Iter); + inList2.add(const2Iter); + + input1 = new UnionIterator<Record>(inList1); + input2 = new UnionIterator<Record>(inList2); + + final JoinFunction matcher = new RecordMatchRemovingJoin(expectedMatchesMap); + final Collector<Record> collector = new DiscardingOutputCollector<Record>(); + + NonReusingBuildFirstHashMatchIterator<Record, Record, Record> iterator = + new NonReusingBuildFirstHashMatchIterator<Record, Record, Record>( + input1, input2, this.recordSerializer, this.record1Comparator, + this.recordSerializer, this.record2Comparator, this.recordPairComparator, + this.memoryManager, ioManager, this.parentTask, 1.0); + + iterator.open(); + + while (iterator.callWithNextKey(matcher, collector)); + + iterator.close(); + + // assert that each expected match was seen + for (Entry<TestData.Key, Collection<RecordMatch>> entry : expectedMatchesMap.entrySet()) { + if (!entry.getValue().isEmpty()) { + Assert.fail("Collection for key " + entry.getKey() + " is not empty"); + } + } + } + catch (Exception e) { + e.printStackTrace(); + Assert.fail("An exception occurred during the test: " + e.getMessage()); 
+ } + } + + @Test + public void testBuildSecond() { + try { + Generator generator1 = new Generator(SEED1, 500, 4096, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); + Generator generator2 = new Generator(SEED2, 500, 2048, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); + + final TestData.GeneratorIterator input1 = new TestData.GeneratorIterator(generator1, INPUT_1_SIZE); + final TestData.GeneratorIterator input2 = new TestData.GeneratorIterator(generator2, INPUT_2_SIZE); + + // collect expected data + final Map<TestData.Key, Collection<RecordMatch>> expectedMatchesMap = matchRecordValues( + collectRecordData(input1), + collectRecordData(input2)); + + final JoinFunction matcher = new RecordMatchRemovingJoin(expectedMatchesMap); + final Collector<Record> collector = new DiscardingOutputCollector<Record>(); + + // reset the generators + generator1.reset(); + generator2.reset(); + input1.reset(); + input2.reset(); + + // compare with iterator values + NonReusingBuildSecondHashMatchIterator<Record, Record, Record> iterator = + new NonReusingBuildSecondHashMatchIterator<Record, Record, Record>( + input1, input2, this.recordSerializer, this.record1Comparator, + this.recordSerializer, this.record2Comparator, this.recordPairComparator, + this.memoryManager, ioManager, this.parentTask, 1.0); + + iterator.open(); + + while (iterator.callWithNextKey(matcher, collector)); + + iterator.close(); + + // assert that each expected match was seen + for (Entry<TestData.Key, Collection<RecordMatch>> entry : expectedMatchesMap.entrySet()) { + if (!entry.getValue().isEmpty()) { + Assert.fail("Collection for key " + entry.getKey() + " is not empty"); + } + } + } + catch (Exception e) { + e.printStackTrace(); + Assert.fail("An exception occurred during the test: " + e.getMessage()); + } + } + + @Test + public void testBuildSecondWithHighNumberOfCommonKeys() + { + // the size of the left and right inputs + final int INPUT_1_SIZE = 200; + final int INPUT_2_SIZE = 100; + + final int INPUT_1_DUPLICATES = 
10; + final int INPUT_2_DUPLICATES = 2000; + final int DUPLICATE_KEY = 13; + + try { + Generator generator1 = new Generator(SEED1, 500, 4096, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); + Generator generator2 = new Generator(SEED2, 500, 2048, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); + + final TestData.GeneratorIterator gen1Iter = new TestData.GeneratorIterator(generator1, INPUT_1_SIZE); + final TestData.GeneratorIterator gen2Iter = new TestData.GeneratorIterator(generator2, INPUT_2_SIZE); + + final TestData.ConstantValueIterator const1Iter = new TestData.ConstantValueIterator(DUPLICATE_KEY, "LEFT String for Duplicate Keys", INPUT_1_DUPLICATES); + final TestData.ConstantValueIterator const2Iter = new TestData.ConstantValueIterator(DUPLICATE_KEY, "RIGHT String for Duplicate Keys", INPUT_2_DUPLICATES); + + final List<MutableObjectIterator<Record>> inList1 = new ArrayList<MutableObjectIterator<Record>>(); + inList1.add(gen1Iter); + inList1.add(const1Iter); + + final List<MutableObjectIterator<Record>> inList2 = new ArrayList<MutableObjectIterator<Record>>(); + inList2.add(gen2Iter); + inList2.add(const2Iter); + + MutableObjectIterator<Record> input1 = new UnionIterator<Record>(inList1); + MutableObjectIterator<Record> input2 = new UnionIterator<Record>(inList2); + + + // collect expected data + final Map<TestData.Key, Collection<RecordMatch>> expectedMatchesMap = matchRecordValues( + collectRecordData(input1), + collectRecordData(input2)); + + // re-create the whole thing for actual processing + + // reset the generators and iterators + generator1.reset(); + generator2.reset(); + const1Iter.reset(); + const2Iter.reset(); + gen1Iter.reset(); + gen2Iter.reset(); + + inList1.clear(); + inList1.add(gen1Iter); + inList1.add(const1Iter); + + inList2.clear(); + inList2.add(gen2Iter); + inList2.add(const2Iter); + + input1 = new UnionIterator<Record>(inList1); + input2 = new UnionIterator<Record>(inList2); + + final JoinFunction matcher = new 
RecordMatchRemovingJoin(expectedMatchesMap); + final Collector<Record> collector = new DiscardingOutputCollector<Record>(); + + NonReusingBuildSecondHashMatchIterator<Record, Record, Record> iterator = + new NonReusingBuildSecondHashMatchIterator<Record, Record, Record>( + input1, input2, this.recordSerializer, this.record1Comparator, + this.recordSerializer, this.record2Comparator, this.recordPairComparator, + this.memoryManager, ioManager, this.parentTask, 1.0); + + iterator.open(); + + while (iterator.callWithNextKey(matcher, collector)); + + iterator.close(); + + // assert that each expected match was seen + for (Entry<TestData.Key, Collection<RecordMatch>> entry : expectedMatchesMap.entrySet()) { + if (!entry.getValue().isEmpty()) { + Assert.fail("Collection for key " + entry.getKey() + " is not empty"); + } + } + } + catch (Exception e) { + e.printStackTrace(); + Assert.fail("An exception occurred during the test: " + e.getMessage()); + } + } + + @Test + public void testBuildFirstWithMixedDataTypes() { + try { + MutableObjectIterator<IntPair> input1 = new UniformIntPairGenerator(500, 40, false); + + final Generator generator2 = new Generator(SEED2, 500, 2048, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); + final TestData.GeneratorIterator input2 = new TestData.GeneratorIterator(generator2, INPUT_2_SIZE); + + // collect expected data + final Map<TestData.Key, Collection<RecordIntPairMatch>> expectedMatchesMap = matchRecordIntPairValues( + collectIntPairData(input1), + collectRecordData(input2)); + + final FlatJoinFunction<IntPair, Record, Record> matcher = new RecordIntPairMatchRemovingMatcher(expectedMatchesMap); + final Collector<Record> collector = new DiscardingOutputCollector<Record>(); + + // reset the generators + input1 = new UniformIntPairGenerator(500, 40, false); + generator2.reset(); + input2.reset(); + + // compare with iterator values + NonReusingBuildSecondHashMatchIterator<IntPair, Record, Record> iterator = + new 
NonReusingBuildSecondHashMatchIterator<IntPair, Record, Record>( + input1, input2, this.pairSerializer, this.pairComparator, + this.recordSerializer, this.record2Comparator, this.pairRecordPairComparator, + this.memoryManager, this.ioManager, this.parentTask, 1.0); + + iterator.open(); + + while (iterator.callWithNextKey(matcher, collector)); + + iterator.close(); + + // assert that each expected match was seen + for (Entry<TestData.Key, Collection<RecordIntPairMatch>> entry : expectedMatchesMap.entrySet()) { + if (!entry.getValue().isEmpty()) { + Assert.fail("Collection for key " + entry.getKey() + " is not empty"); + } + } + } + catch (Exception e) { + e.printStackTrace(); + Assert.fail("An exception occurred during the test: " + e.getMessage()); + } + } + + @Test + public void testBuildSecondWithMixedDataTypes() { + try { + MutableObjectIterator<IntPair> input1 = new UniformIntPairGenerator(500, 40, false); + + final Generator generator2 = new Generator(SEED2, 500, 2048, KeyMode.RANDOM, ValueMode.RANDOM_LENGTH); + final TestData.GeneratorIterator input2 = new TestData.GeneratorIterator(generator2, INPUT_2_SIZE); + + // collect expected data + final Map<TestData.Key, Collection<RecordIntPairMatch>> expectedMatchesMap = matchRecordIntPairValues( + collectIntPairData(input1), + collectRecordData(input2)); + + final FlatJoinFunction<IntPair, Record, Record> matcher = new RecordIntPairMatchRemovingMatcher(expectedMatchesMap); + final Collector<Record> collector = new DiscardingOutputCollector<Record>(); + + // reset the generators + input1 = new UniformIntPairGenerator(500, 40, false); + generator2.reset(); + input2.reset(); + + // compare with iterator values + NonReusingBuildFirstHashMatchIterator<IntPair, Record, Record> iterator = + new NonReusingBuildFirstHashMatchIterator<IntPair, Record, Record>( + input1, input2, this.pairSerializer, this.pairComparator, + this.recordSerializer, this.record2Comparator, this.recordPairPairComparator, + this.memoryManager, 
this.ioManager, this.parentTask, 1.0); + + iterator.open(); + + while (iterator.callWithNextKey(matcher, collector)); + + iterator.close(); + + // assert that each expected match was seen + for (Entry<TestData.Key, Collection<RecordIntPairMatch>> entry : expectedMatchesMap.entrySet()) { + if (!entry.getValue().isEmpty()) { + Assert.fail("Collection for key " + entry.getKey() + " is not empty"); + } + } + } + catch (Exception e) { + e.printStackTrace(); + Assert.fail("An exception occurred during the test: " + e.getMessage()); + } + } + + // -------------------------------------------------------------------------------------------- + // Utilities + // -------------------------------------------------------------------------------------------- + + + + static Map<TestData.Key, Collection<RecordMatch>> matchRecordValues( + Map<TestData.Key, Collection<TestData.Value>> leftMap, + Map<TestData.Key, Collection<TestData.Value>> rightMap) + { + Map<TestData.Key, Collection<RecordMatch>> map = new HashMap<TestData.Key, Collection<RecordMatch>>(); + + for (TestData.Key key : leftMap.keySet()) { + Collection<TestData.Value> leftValues = leftMap.get(key); + Collection<TestData.Value> rightValues = rightMap.get(key); + + if (rightValues == null) { + continue; + } + + if (!map.containsKey(key)) { + map.put(key, new ArrayList<RecordMatch>()); + } + + Collection<RecordMatch> matchedValues = map.get(key); + + for (TestData.Value leftValue : leftValues) { + for (TestData.Value rightValue : rightValues) { + matchedValues.add(new RecordMatch(leftValue, rightValue)); + } + } + } + + return map; + } + + static Map<TestData.Key, Collection<RecordIntPairMatch>> matchRecordIntPairValues( + Map<Integer, Collection<Integer>> leftMap, + Map<TestData.Key, Collection<TestData.Value>> rightMap) + { + final Map<TestData.Key, Collection<RecordIntPairMatch>> map = new HashMap<TestData.Key, Collection<RecordIntPairMatch>>(); + + for (Integer i : leftMap.keySet()) { + + final TestData.Key key = new 
TestData.Key(i.intValue()); + + final Collection<Integer> leftValues = leftMap.get(i); + final Collection<TestData.Value> rightValues = rightMap.get(key); + + if (rightValues == null) { + continue; + } + + if (!map.containsKey(key)) { + map.put(key, new ArrayList<RecordIntPairMatch>()); + } + + final Collection<RecordIntPairMatch> matchedValues = map.get(key); + + for (Integer v : leftValues) { + for (TestData.Value val : rightValues) { + matchedValues.add(new RecordIntPairMatch(v, val)); + } + } + } + + return map; + } + + + static Map<TestData.Key, Collection<TestData.Value>> collectRecordData(MutableObjectIterator<Record> iter) + throws Exception + { + Map<TestData.Key, Collection<TestData.Value>> map = new HashMap<TestData.Key, Collection<TestData.Value>>(); + Record pair = new Record(); + + while ((pair = iter.next(pair)) != null) { + + TestData.Key key = pair.getField(0, TestData.Key.class); + if (!map.containsKey(key)) { + map.put(new TestData.Key(key.getKey()), new ArrayList<TestData.Value>()); + } + + Collection<TestData.Value> values = map.get(key); + values.add(new TestData.Value(pair.getField(1, TestData.Value.class).getValue())); + } + + return map; + } + + static Map<Integer, Collection<Integer>> collectIntPairData(MutableObjectIterator<IntPair> iter) + throws Exception + { + Map<Integer, Collection<Integer>> map = new HashMap<Integer, Collection<Integer>>(); + IntPair pair = new IntPair(); + + while ((pair = iter.next(pair)) != null) { + + final int key = pair.getKey(); + final int value = pair.getValue(); + if (!map.containsKey(key)) { + map.put(key, new ArrayList<Integer>()); + } + + Collection<Integer> values = map.get(key); + values.add(value); + } + + return map; + } + + /** + * Private class used for storage of the expected matches in a hash-map. 
*/
	static class RecordMatch {

		private final Value left;
		private final Value right;

		public RecordMatch(Value left, Value right) {
			this.left = left;
			this.right = right;
		}

		/**
		 * Two matches are equal if both their left and their right values are equal.
		 * Returns {@code false} (instead of throwing) for {@code null} and foreign types.
		 */
		@Override
		public boolean equals(Object obj) {
			if (this == obj) {
				return true;
			}
			if (!(obj instanceof RecordMatch)) {
				return false;
			}
			final RecordMatch o = (RecordMatch) obj;
			return this.left.equals(o.left) && this.right.equals(o.right);
		}

		@Override
		public int hashCode() {
			return this.left.hashCode() ^ this.right.hashCode();
		}

		@Override
		public String toString() {
			return left + ", " + right;
		}
	}

	/**
	 * Helper class used for storage of the expected matches in a hash-map.
	 * The left side is the value of an {@link IntPair}, the right side the value of a record.
	 */
	static class RecordIntPairMatch
	{
		private final int left;
		private final Value right;

		public RecordIntPairMatch(int left, Value right) {
			this.left = left;
			this.right = right;
		}

		/**
		 * Two matches are equal if both their left and their right values are equal.
		 * Returns {@code false} (instead of throwing) for {@code null} and foreign types.
		 */
		@Override
		public boolean equals(Object obj) {
			if (this == obj) {
				return true;
			}
			if (!(obj instanceof RecordIntPairMatch)) {
				return false;
			}
			final RecordIntPairMatch o = (RecordIntPairMatch) obj;
			return this.left == o.left && this.right.equals(o.right);
		}

		@Override
		public int hashCode() {
			return this.left ^ this.right.hashCode();
		}

		@Override
		public String toString() {
			return left + ", " + right;
		}
	}

	/**
	 * Join function that removes every pair it is called with from the map of expected matches.
	 * It fails the test if a pair is not contained in that map.
	 */
	static final class RecordMatchRemovingJoin extends JoinFunction
	{
		/** The expected matches by key; entries are removed as the matches are produced. */
		private final Map<TestData.Key, Collection<RecordMatch>> toRemoveFrom;

		protected RecordMatchRemovingJoin(Map<TestData.Key, Collection<RecordMatch>> map) {
			this.toRemoveFrom = map;
		}

		@Override
		public void join(Record rec1, Record rec2, Collector<Record> out) throws Exception
		{
			TestData.Key key = rec1.getField(0, TestData.Key.class);
			TestData.Value value1 = rec1.getField(1, TestData.Value.class);
			TestData.Value value2 = rec2.getField(1, TestData.Value.class);

			Collection<RecordMatch> matches = this.toRemoveFrom.get(key);
			if (matches == null) {
				Assert.fail("Match " + key + " - " + value1 + ":" + value2 + " is unexpected.");
			}

			Assert.assertTrue("Produced match was not contained: " + key + " - " + value1 + ":" + value2,
				matches.remove(new RecordMatch(value1, value2)));

			// drop exhausted keys so that a later match for the same key is reported as unexpected
			if (matches.isEmpty()) {
				this.toRemoveFrom.remove(key);
			}
		}
	}

	/**
	 * Join function for (IntPair, Record) inputs that removes every pair it is called with from
	 * the map of expected matches. It fails the test if the keys of the two sides differ or if
	 * the pair is not contained in that map.
	 */
	static final class RecordIntPairMatchRemovingMatcher extends AbstractRichFunction implements FlatJoinFunction<IntPair, Record, Record>
	{
		/** The expected matches by key; entries are removed as the matches are produced. */
		private final Map<TestData.Key, Collection<RecordIntPairMatch>> toRemoveFrom;

		protected RecordIntPairMatchRemovingMatcher(Map<TestData.Key, Collection<RecordIntPairMatch>> map) {
			this.toRemoveFrom = map;
		}

		@Override
		public void join(IntPair rec1, Record rec2, Collector<Record> out) throws Exception
		{
			final int k = rec1.getKey();
			final int v = rec1.getValue();

			final TestData.Key key = rec2.getField(0, TestData.Key.class);
			final TestData.Value value = rec2.getField(1, TestData.Value.class);

			Assert.assertTrue("Key does not match for matching IntPair Record combination.", k == key.getKey());

			Collection<RecordIntPairMatch> matches = this.toRemoveFrom.get(key);
			if (matches == null) {
				Assert.fail("Match " + key + " - " + v + ":" + value + " is unexpected.");
			}

			Assert.assertTrue("Produced match was not contained: " + key + " - " + v + ":" + value,
				matches.remove(new RecordIntPairMatch(v, value)));

			// drop exhausted keys so that a later match for the same key is reported as unexpected
			if (matches.isEmpty()) {
				this.toRemoveFrom.remove(key);
			}
		}
	}

	/**
	 * Pair comparator that takes its reference key from an {@link IntPair} and compares it
	 * against field 0 of a {@link Record} candidate, read as an {@link IntValue}.
	 */
	static final class IntPairRecordPairComparator extends TypePairComparator<IntPair, Record>
	{
		private int reference;

		@Override
		public void setReference(IntPair reference) {
			this.reference = reference.getKey();
		}

		@Override
		public boolean equalToReference(Record candidate) {
			try {
				final IntValue i = candidate.getField(0, IntValue.class);
				return i.getValue() == this.reference;
			} catch (NullPointerException npex) {
				throw new NullKeyFieldException();
			}
		}

		/**
		 * Compares the candidate's key with the reference key.
		 *
		 * @return A negative value, zero, or a positive value if the candidate's key is smaller
		 *         than, equal to, or larger than the reference key.
		 */
		@Override
		public int compareToReference(Record candidate) {
			try {
				final int candidateKey = candidate.getField(0, IntValue.class).getValue();
				// explicit comparison: a plain subtraction overflows for keys that are far apart
				return (candidateKey < this.reference) ? -1 : ((candidateKey == this.reference) ? 0 : 1);
			} catch (NullPointerException npex) {
				throw new NullKeyFieldException();
			}
		}
	}

	/**
	 * Pair comparator that takes its reference key from field 0 of a {@link Record}, read as an
	 * {@link IntValue}, and compares it against the key of an {@link IntPair} candidate.
	 */
	static final class RecordIntPairPairComparator extends TypePairComparator<Record, IntPair>
	{
		private int reference;

		@Override
		public void setReference(Record reference) {
			this.reference = reference.getField(0, IntValue.class).getValue();
		}

		@Override
		public boolean equalToReference(IntPair candidate) {
			return this.reference == candidate.getKey();
		}

		/**
		 * Compares the candidate's key with the reference key.
		 *
		 * @return A negative value, zero, or a positive value if the candidate's key is smaller
		 *         than, equal to, or larger than the reference key.
		 */
		@Override
		public int compareToReference(IntPair candidate) {
			final int candidateKey = candidate.getKey();
			// explicit comparison: a plain subtraction overflows for keys that are far apart
			return (candidateKey < this.reference) ? -1 : ((candidateKey == this.reference) ? 0 : 1);
		}
	}
}
