Repository: flink Updated Branches: refs/heads/master 4a49a73a7 -> e93e0cb86
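This commit generalizes FlatCombineFunction from one type parameter to two (FlatCombineFunction<IN, OUT>) and wires up a dedicated GroupCombine operator around it: combineGroup in the Java and Scala APIs, the new AllGroupCombineDriver and GroupCombineChainedDriver in the runtime, and a new ALL_GROUP_COMBINE driver strategy. A minimal usage sketch of the new operator follows; it is illustrative only, and the environment, data, and function bodies are assumed rather than taken from this patch:

// Illustrative sketch (not part of the patch): word counting with combineGroup.
// The combine step may run several times and emit partial sums per partition;
// the grouped reduceGroup after the shuffle produces the final counts.
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

DataSet<Tuple2<String, Integer>> words = env.fromElements(
		new Tuple2<String, Integer>("flink", 1),
		new Tuple2<String, Integer>("combine", 1),
		new Tuple2<String, Integer>("flink", 1));

DataSet<Tuple2<String, Integer>> counts = words
		.groupBy(0)
		// partial, per-partition pre-aggregation; no data exchange happens here
		.combineGroup(new FlatCombineFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>() {
			@Override
			public void combine(Iterable<Tuple2<String, Integer>> values, Collector<Tuple2<String, Integer>> out) {
				String key = null;
				int sum = 0;
				for (Tuple2<String, Integer> value : values) {
					key = value.f0;
					sum += value.f1;
				}
				out.collect(new Tuple2<String, Integer>(key, sum));
			}
		})
		.groupBy(0)
		// full group reduce after the shuffle finalizes the partial sums
		.reduceGroup(new GroupReduceFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>() {
			@Override
			public void reduce(Iterable<Tuple2<String, Integer>> values, Collector<Tuple2<String, Integer>> out) {
				String key = null;
				int sum = 0;
				for (Tuple2<String, Integer> value : values) {
					key = value.f0;
					sum += value.f1;
				}
				out.collect(new Tuple2<String, Integer>(key, sum));
			}
		});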
http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-java/src/test/java/org/apache/flink/api/java/record/ReduceWrappingFunctionTest.java ---------------------------------------------------------------------- diff --git a/flink-java/src/test/java/org/apache/flink/api/java/record/ReduceWrappingFunctionTest.java b/flink-java/src/test/java/org/apache/flink/api/java/record/ReduceWrappingFunctionTest.java index f413b81..2216217 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/record/ReduceWrappingFunctionTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/record/ReduceWrappingFunctionTest.java @@ -86,7 +86,7 @@ public class ReduceWrappingFunctionTest { target.clear(); // test combine - ((FlatCombineFunction<Record>) reducer).combine(source, collector); + ((FlatCombineFunction<Record, Record>) reducer).combine(source, collector); assertEquals(2, target.size()); assertEquals(new IntValue(42), target.get(0).getField(0, IntValue.class)); assertEquals(new LongValue(11), target.get(0).getField(1, LongValue.class)); @@ -138,7 +138,7 @@ public class ReduceWrappingFunctionTest { target.clear(); // test combine - ((FlatCombineFunction<Record>) reducer).combine(source, collector); + ((FlatCombineFunction<Record, Record>) reducer).combine(source, collector); assertEquals(2, target.size()); assertEquals(new IntValue(42), target.get(0).getField(0, IntValue.class)); assertEquals(new LongValue(11), target.get(0).getField(1, LongValue.class)); http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-runtime/src/main/java/org/apache/flink/runtime/operators/AllGroupCombineDriver.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/AllGroupCombineDriver.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/AllGroupCombineDriver.java new file mode 100644 index 0000000..7d87a6b --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/AllGroupCombineDriver.java @@ -0,0 +1,127 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + + +package org.apache.flink.runtime.operators; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.functions.FlatCombineFunction; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerFactory; +import org.apache.flink.runtime.util.NonReusingMutableToRegularIteratorWrapper; +import org.apache.flink.runtime.util.ReusingMutableToRegularIteratorWrapper; +import org.apache.flink.util.Collector; +import org.apache.flink.util.MutableObjectIterator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** +* Non-chained driver for the partial group reduce operator that acts like a combiner with a custom output type OUT. +* Like {@link org.apache.flink.runtime.operators.GroupReduceCombineDriver} but without grouping and sorting. May emit partially +* reduced results. +* +* @see org.apache.flink.api.common.functions.FlatCombineFunction +*/ +public class AllGroupCombineDriver<IN, OUT> implements PactDriver<FlatCombineFunction<IN, OUT>, OUT> { + + private static final Logger LOG = LoggerFactory.getLogger(AllGroupCombineDriver.class); + + private PactTaskContext<FlatCombineFunction<IN, OUT>, OUT> taskContext; + + private boolean objectReuseEnabled = false; + + // ------------------------------------------------------------------------ + + @Override + public void setup(PactTaskContext<FlatCombineFunction<IN, OUT>, OUT> context) { + this.taskContext = context; + } + + @Override + public int getNumberOfInputs() { + return 1; + } + + @Override + public Class<FlatCombineFunction<IN, OUT>> getStubType() { + @SuppressWarnings("unchecked") + final Class<FlatCombineFunction<IN, OUT>> clazz = (Class<FlatCombineFunction<IN, OUT>>) (Class<?>) FlatCombineFunction.class; + return clazz; + } + + @Override + public int getNumberOfDriverComparators() { + return 0; + } + + @Override + public void prepare() throws Exception { + final DriverStrategy driverStrategy = this.taskContext.getTaskConfig().getDriverStrategy(); + if(driverStrategy != DriverStrategy.ALL_GROUP_COMBINE){ + throw new Exception("Invalid strategy " + driverStrategy + " for " + + "GroupCombine."); + } + + ExecutionConfig executionConfig = taskContext.getExecutionConfig(); + this.objectReuseEnabled = executionConfig.isObjectReuseEnabled(); + + if (LOG.isDebugEnabled()) { + LOG.debug("AllGroupCombineDriver object reuse: " + (this.objectReuseEnabled ?
"ENABLED" : "DISABLED") + "."); + } + } + + @Override + public void run() throws Exception { + if (LOG.isDebugEnabled()) { + LOG.debug("AllGroupCombine starting."); + } + + final TypeSerializerFactory<IN> serializerFactory = this.taskContext.getInputSerializer(0); + TypeSerializer<IN> serializer = serializerFactory.getSerializer(); + + final MutableObjectIterator<IN> in = this.taskContext.getInput(0); + final FlatCombineFunction<IN, OUT> reducer = this.taskContext.getStub(); + final Collector<OUT> output = this.taskContext.getOutputCollector(); + + if (objectReuseEnabled) { + final ReusingMutableToRegularIteratorWrapper<IN> inIter = new ReusingMutableToRegularIteratorWrapper<IN>(in, serializer); + + if (inIter.hasNext()) { + reducer.combine(inIter, output); + + } + + } else { + final NonReusingMutableToRegularIteratorWrapper<IN> inIter = new NonReusingMutableToRegularIteratorWrapper<IN>(in, serializer); + + if (inIter.hasNext()) { + reducer.combine(inIter, output); + } + } + + } + + @Override + public void cleanup() throws Exception { + } + + @Override + public void cancel() { + } +} + http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-runtime/src/main/java/org/apache/flink/runtime/operators/AllGroupReduceDriver.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/AllGroupReduceDriver.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/AllGroupReduceDriver.java index 854dbd5..9b9b5f6 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/AllGroupReduceDriver.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/AllGroupReduceDriver.java @@ -32,9 +32,13 @@ import org.apache.flink.util.Collector; import org.apache.flink.util.MutableObjectIterator; /** - * GroupReduce task which is executed by a Nephele task manager. The task has a + * GroupReduceDriver task which is executed by a Nephele task manager. The task has a * single input and one or multiple outputs. It is provided with a GroupReduceFunction - * implementation. + * implementation or a RichGroupFunction. This Driver performs + * multiple tasks depending on the DriverStrategy. In case of a ALL_GROUP_REDUCE_COMBINE + * it uses the combine function of the supplied user function. In case + * of the ALL_GROUP_REDUCE, it uses the reduce function of the supplied user function to + * process all elements. In either case, the function is executed on all elements. * <p> * The GroupReduceTask creates a iterator over all records from its input. The iterator returns all records grouped by their * key. The iterator is handed to the <code>reduce()</code> method of the GroupReduceFunction. 
@@ -85,15 +89,19 @@ public class AllGroupReduceDriver<IT, OT> implements PactDriver<GroupReduceFunct public void prepare() throws Exception { final TaskConfig config = this.taskContext.getTaskConfig(); this.strategy = config.getDriverStrategy(); - - if (strategy == DriverStrategy.ALL_GROUP_COMBINE) { - if (!(this.taskContext.getStub() instanceof FlatCombineFunction)) { - throw new Exception("Using combiner on a UDF that does not implement the combiner interface " + FlatCombineFunction.class.getName()); - } - } - else if (strategy != DriverStrategy.ALL_GROUP_REDUCE) { - throw new Exception("Unrecognized driver strategy for AllGroupReduce driver: " + config.getDriverStrategy().name()); + + switch (this.strategy) { + case ALL_GROUP_REDUCE_COMBINE: + if (!(this.taskContext.getStub() instanceof FlatCombineFunction)) { + throw new Exception("Using combiner on a UDF that does not implement the combiner interface " + FlatCombineFunction.class.getName()); + } + case ALL_GROUP_REDUCE: + case ALL_GROUP_COMBINE: + break; + default: + throw new Exception("Unrecognized driver strategy for AllGroupReduce driver: " + this.strategy.name()); } + this.serializer = this.taskContext.<IT>getInputSerializer(0).getSerializer(); this.input = this.taskContext.getInput(0); @@ -108,7 +116,7 @@ public class AllGroupReduceDriver<IT, OT> implements PactDriver<GroupReduceFunct @Override public void run() throws Exception { if (LOG.isDebugEnabled()) { - LOG.debug(this.taskContext.formatLogString("AllGroupReduce preprocessing done. Running Reducer code.")); + LOG.debug(this.taskContext.formatLogString("AllGroupReduceDriver preprocessing done. Running Reducer code.")); } if (objectReuseEnabled) { @@ -120,10 +128,12 @@ public class AllGroupReduceDriver<IT, OT> implements PactDriver<GroupReduceFunct final GroupReduceFunction<IT, OT> reducer = this.taskContext.getStub(); final Collector<OT> output = this.taskContext.getOutputCollector(); reducer.reduce(inIter, output); - } else { - @SuppressWarnings("unchecked") final FlatCombineFunction<IT> combiner = (FlatCombineFunction<IT>) this.taskContext.getStub(); - @SuppressWarnings("unchecked") final Collector<IT> output = (Collector<IT>) this.taskContext.getOutputCollector(); + } else if (strategy == DriverStrategy.ALL_GROUP_REDUCE_COMBINE || strategy == DriverStrategy.ALL_GROUP_COMBINE) { + @SuppressWarnings("unchecked") final FlatCombineFunction<IT, OT> combiner = (FlatCombineFunction<IT, OT>) this.taskContext.getStub(); + final Collector<OT> output = this.taskContext.getOutputCollector(); combiner.combine(inIter, output); + } else { + throw new Exception("The strategy " + strategy + " is unknown to this driver."); } } @@ -136,10 +146,12 @@ public class AllGroupReduceDriver<IT, OT> implements PactDriver<GroupReduceFunct final GroupReduceFunction<IT, OT> reducer = this.taskContext.getStub(); final Collector<OT> output = this.taskContext.getOutputCollector(); reducer.reduce(inIter, output); - } else { - @SuppressWarnings("unchecked") final FlatCombineFunction<IT> combiner = (FlatCombineFunction<IT>) this.taskContext.getStub(); - @SuppressWarnings("unchecked") final Collector<IT> output = (Collector<IT>) this.taskContext.getOutputCollector(); + } else if (strategy == DriverStrategy.ALL_GROUP_REDUCE_COMBINE || strategy == DriverStrategy.ALL_GROUP_COMBINE) { + @SuppressWarnings("unchecked") final FlatCombineFunction<IT, OT> combiner = (FlatCombineFunction<IT, OT>) this.taskContext.getStub(); + final Collector<OT> output = this.taskContext.getOutputCollector(); combiner.combine(inIter, 
output); + } else { + throw new Exception("The strategy " + strategy + " is unknown to this driver."); } } } http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DriverStrategy.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DriverStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DriverStrategy.java index ae9b474..d5b131e 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DriverStrategy.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/DriverStrategy.java @@ -55,7 +55,7 @@ public enum DriverStrategy { // group everything together into one group and apply the GroupReduce function ALL_GROUP_REDUCE(AllGroupReduceDriver.class, null, PIPELINED, 0), // group everything together into one group and apply the GroupReduce's combine function - ALL_GROUP_COMBINE(AllGroupReduceDriver.class, null, PIPELINED, 0), + ALL_GROUP_REDUCE_COMBINE(AllGroupReduceDriver.class, null, PIPELINED, 0), // grouping the inputs and apply the Reduce Function SORTED_REDUCE(ReduceDriver.class, null, PIPELINED, 1), @@ -67,6 +67,9 @@ public enum DriverStrategy { // partially grouping inputs (best effort resulting possibly in duplicates --> combiner) SORTED_GROUP_COMBINE(GroupReduceCombineDriver.class, SynchronousChainedCombineDriver.class, MATERIALIZING, 2), + // group combine on all inputs within a partition (without grouping) + ALL_GROUP_COMBINE(AllGroupCombineDriver.class, null, PIPELINED, 0), + // both inputs are merged, but materialized to the side for block-nested-loop-join among values with equal key MERGE(MatchDriver.class, null, MATERIALIZING, MATERIALIZING, 2), http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-runtime/src/main/java/org/apache/flink/runtime/operators/GroupReduceCombineDriver.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/GroupReduceCombineDriver.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/GroupReduceCombineDriver.java index be0c9c4..dacd436 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/GroupReduceCombineDriver.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/GroupReduceCombineDriver.java @@ -41,36 +41,43 @@ import java.io.IOException; import java.util.List; /** - * Combine operator, standalone (not chained) - * <p> + * Non-chained combine driver which is used for a CombineGroup transformation or a GroupReduce transformation where + * the user supplied a RichGroupReduceFunction with a combine method. The combining is performed in memory with a + * lazy approach which only combines elements which currently fit in the sorter. This may lead to a partial result. + * In the case of the RichGroupReduceFunction, this partial result is transformed into a proper deterministic result. + * The CombineGroup transformation uses the FlatCombineFunction interface, which allows combining values of type <IN> + * into any output type <OUT>. In contrast, the RichGroupReduceFunction requires the combine method to have the same input and + * output type, so that the elements can still be reduced from <IN> to <OUT> after the combine step. + * * The CombineTask uses a combining iterator over its input. The output of the iterator is emitted. * - * @param <T> The data type consumed and produced by the combiner.
+ * @param <IN> The data type consumed by the combiner. + * @param <OUT> The data type produced by the combiner. */ -public class GroupReduceCombineDriver<T> implements PactDriver<FlatCombineFunction<T>, T> { +public class GroupReduceCombineDriver<IN, OUT> implements PactDriver<FlatCombineFunction<IN, OUT>, OUT> { private static final Logger LOG = LoggerFactory.getLogger(GroupReduceCombineDriver.class); /** Fix length records with a length below this threshold will be in-place sorted, if possible. */ private static final int THRESHOLD_FOR_IN_PLACE_SORTING = 32; - private PactTaskContext<FlatCombineFunction<T>, T> taskContext; + private PactTaskContext<FlatCombineFunction<IN, OUT>, OUT> taskContext; - private InMemorySorter<T> sorter; + private InMemorySorter<IN> sorter; - private FlatCombineFunction<T> combiner; + private FlatCombineFunction<IN, OUT> combiner; - private TypeSerializer<T> serializer; + private TypeSerializer<IN> serializer; - private TypeComparator<T> sortingComparator; + private TypeComparator<IN> sortingComparator; - private TypeComparator<T> groupingComparator; + private TypeComparator<IN> groupingComparator; private QuickSort sortAlgo = new QuickSort(); private MemoryManager memManager; - private Collector<T> output; + private Collector<OUT> output; private volatile boolean running = true; @@ -79,7 +86,7 @@ public class GroupReduceCombineDriver<T> implements PactDriver<FlatCombineFuncti // ------------------------------------------------------------------------ @Override - public void setup(PactTaskContext<FlatCombineFunction<T>, T> context) { + public void setup(PactTaskContext<FlatCombineFunction<IN, OUT>, OUT> context) { this.taskContext = context; this.running = true; } @@ -90,9 +97,9 @@ public class GroupReduceCombineDriver<T> implements PactDriver<FlatCombineFuncti } @Override - public Class<FlatCombineFunction<T>> getStubType() { + public Class<FlatCombineFunction<IN, OUT>> getStubType() { @SuppressWarnings("unchecked") - final Class<FlatCombineFunction<T>> clazz = (Class<FlatCombineFunction<T>>) (Class<?>) FlatCombineFunction.class; + final Class<FlatCombineFunction<IN, OUT>> clazz = (Class<FlatCombineFunction<IN, OUT>>) (Class<?>) FlatCombineFunction.class; return clazz; } @@ -103,15 +110,16 @@ public class GroupReduceCombineDriver<T> implements PactDriver<FlatCombineFuncti @Override public void prepare() throws Exception { - if(this.taskContext.getTaskConfig().getDriverStrategy() != DriverStrategy.SORTED_GROUP_COMBINE){ - throw new Exception("Invalid strategy " + this.taskContext.getTaskConfig().getDriverStrategy() + " for " + + final DriverStrategy driverStrategy = this.taskContext.getTaskConfig().getDriverStrategy(); + if(driverStrategy != DriverStrategy.SORTED_GROUP_COMBINE){ + throw new Exception("Invalid strategy " + driverStrategy + " for " + "group reduce combiner."); } this.memManager = this.taskContext.getMemoryManager(); final int numMemoryPages = memManager.computeNumberOfPages(this.taskContext.getTaskConfig().getRelativeMemoryDriver()); - final TypeSerializerFactory<T> serializerFactory = this.taskContext.getInputSerializer(0); + final TypeSerializerFactory<IN> serializerFactory = this.taskContext.getInputSerializer(0); this.serializer = serializerFactory.getSerializer(); this.sortingComparator = this.taskContext.getDriverComparator(0); this.groupingComparator = this.taskContext.getDriverComparator(1); @@ -125,9 +133,9 @@ public class GroupReduceCombineDriver<T> implements PactDriver<FlatCombineFuncti if
(this.sortingComparator.supportsSerializationWithKeyNormalization() && this.serializer.getLength() > 0 && this.serializer.getLength() <= THRESHOLD_FOR_IN_PLACE_SORTING) { - this.sorter = new FixedLengthRecordSorter<T>(this.serializer, this.sortingComparator, memory); + this.sorter = new FixedLengthRecordSorter<IN>(this.serializer, this.sortingComparator, memory); } else { - this.sorter = new NormalizedKeySorter<T>(this.serializer, this.sortingComparator.duplicate(), memory); + this.sorter = new NormalizedKeySorter<IN>(this.serializer, this.sortingComparator.duplicate(), memory); } ExecutionConfig executionConfig = taskContext.getExecutionConfig(); @@ -144,10 +152,10 @@ public class GroupReduceCombineDriver<T> implements PactDriver<FlatCombineFuncti LOG.debug("Combiner starting."); } - final MutableObjectIterator<T> in = this.taskContext.getInput(0); - final TypeSerializer<T> serializer = this.serializer; + final MutableObjectIterator<IN> in = this.taskContext.getInput(0); + final TypeSerializer<IN> serializer = this.serializer; - T value = serializer.createInstance(); + IN value = serializer.createInstance(); while (running && (value = in.next(value)) != null) { @@ -171,17 +179,17 @@ public class GroupReduceCombineDriver<T> implements PactDriver<FlatCombineFuncti } private void sortAndCombine() throws Exception { - final InMemorySorter<T> sorter = this.sorter; + final InMemorySorter<IN> sorter = this.sorter; if (objectReuseEnabled) { if (!sorter.isEmpty()) { this.sortAlgo.sort(sorter); - final ReusingKeyGroupedIterator<T> keyIter = - new ReusingKeyGroupedIterator<T>(sorter.getIterator(), this.serializer, this.groupingComparator); + final ReusingKeyGroupedIterator<IN> keyIter = + new ReusingKeyGroupedIterator<IN>(sorter.getIterator(), this.serializer, this.groupingComparator); - final FlatCombineFunction<T> combiner = this.combiner; - final Collector<T> output = this.output; + final FlatCombineFunction<IN, OUT> combiner = this.combiner; + final Collector<OUT> output = this.output; // iterate over key groups while (this.running && keyIter.nextKey()) { @@ -192,11 +200,11 @@ public class GroupReduceCombineDriver<T> implements PactDriver<FlatCombineFuncti if (!sorter.isEmpty()) { this.sortAlgo.sort(sorter); - final NonReusingKeyGroupedIterator<T> keyIter = - new NonReusingKeyGroupedIterator<T>(sorter.getIterator(), this.groupingComparator); + final NonReusingKeyGroupedIterator<IN> keyIter = + new NonReusingKeyGroupedIterator<IN>(sorter.getIterator(), this.groupingComparator); - final FlatCombineFunction<T> combiner = this.combiner; - final Collector<T> output = this.output; + final FlatCombineFunction<IN, OUT> combiner = this.combiner; + final Collector<OUT> output = this.output; // iterate over key groups while (this.running && keyIter.nextKey()) { http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-runtime/src/main/java/org/apache/flink/runtime/operators/PactTaskContext.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/PactTaskContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/PactTaskContext.java index ffe09cb..b296506 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/PactTaskContext.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/PactTaskContext.java @@ -33,7 +33,7 @@ import org.apache.flink.util.MutableObjectIterator; /** * A runtime task is the task that is executed by the flink engine inside a 
task vertex. * It typically has a {@link PactDriver}, and optionally multiple chained drivers. In addition, it - * deals with the runtime setup and teardown and the control-flow logic. The later appears especially + * deals with the runtime setup and teardown and the control-flow logic. The latter appears especially * in the case of iterations. * * @param <S> The UDF type. http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-runtime/src/main/java/org/apache/flink/runtime/operators/chaining/GroupCombineChainedDriver.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/chaining/GroupCombineChainedDriver.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/chaining/GroupCombineChainedDriver.java new file mode 100644 index 0000000..ff397c2 --- /dev/null +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/chaining/GroupCombineChainedDriver.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.runtime.operators.chaining; + +import org.apache.flink.api.common.functions.Function; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.util.FunctionUtils; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypeComparatorFactory; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerFactory; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.memory.MemorySegment; +import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable; +import org.apache.flink.runtime.memorymanager.MemoryManager; +import org.apache.flink.runtime.operators.RegularPactTask; +import org.apache.flink.runtime.operators.sort.FixedLengthRecordSorter; +import org.apache.flink.runtime.operators.sort.InMemorySorter; +import org.apache.flink.runtime.operators.sort.NormalizedKeySorter; +import org.apache.flink.runtime.operators.sort.QuickSort; +import org.apache.flink.runtime.util.NonReusingKeyGroupedIterator; +import org.apache.flink.runtime.util.ReusingKeyGroupedIterator; +import org.apache.flink.util.Collector; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.List; + +/** + * Chained variant of the GroupCombineDriver. + * + * Acts like a combiner with a custom output type OUT. + * + * Sorting and reducing of the elements is performed individually for each partition without data exchange. This may + * lead to a partial group reduce.
+ * + * @param <IN> The data type consumed + * @param <OUT> The data type produced + */ +public class GroupCombineChainedDriver<IN, OUT> extends ChainedDriver<IN, OUT> { + + private static final Logger LOG = LoggerFactory.getLogger(GroupCombineChainedDriver.class); + + /** + * Fix length records with a length below this threshold will be in-place sorted, if possible. + */ + private static final int THRESHOLD_FOR_IN_PLACE_SORTING = 32; + + // -------------------------------------------------------------------------------------------- + + private InMemorySorter<IN> sorter; + + private GroupReduceFunction<IN, OUT> reducer; + + private TypeSerializer<IN> serializer; + + private TypeComparator<IN> sortingComparator; + + private TypeComparator<IN> groupingComparator; + + private AbstractInvokable parent; + + private QuickSort sortAlgo = new QuickSort(); + + private MemoryManager memManager; + + private volatile boolean running = true; + + // -------------------------------------------------------------------------------------------- + + @Override + public void setup(AbstractInvokable parent) { + this.parent = parent; + + @SuppressWarnings("unchecked") + final GroupReduceFunction<IN, OUT> combiner = + RegularPactTask.instantiateUserCode(this.config, userCodeClassLoader, GroupReduceFunction.class); + this.reducer = combiner; + FunctionUtils.setFunctionRuntimeContext(combiner, getUdfRuntimeContext()); + } + + @Override + public void openTask() throws Exception { + // open the stub first + final Configuration stubConfig = this.config.getStubParameters(); + RegularPactTask.openUserCode(this.reducer, stubConfig); + + // ----------------- Set up the asynchronous sorter ------------------------- + + this.memManager = this.parent.getEnvironment().getMemoryManager(); + final int numMemoryPages = memManager.computeNumberOfPages(this.config.getRelativeMemoryDriver()); + + // instantiate the serializer / comparator + final TypeSerializerFactory<IN> serializerFactory = this.config.getInputSerializer(0, this.userCodeClassLoader); + final TypeComparatorFactory<IN> sortingComparatorFactory = this.config.getDriverComparator(0, this.userCodeClassLoader); + final TypeComparatorFactory<IN> groupingComparatorFactory = this.config.getDriverComparator(1, this.userCodeClassLoader); + this.serializer = serializerFactory.getSerializer(); + this.sortingComparator = sortingComparatorFactory.createComparator(); + this.groupingComparator = groupingComparatorFactory.createComparator(); + + final List<MemorySegment> memory = this.memManager.allocatePages(this.parent, numMemoryPages); + + // instantiate a fix-length in-place sorter, if possible, otherwise the out-of-place sorter + if (this.sortingComparator.supportsSerializationWithKeyNormalization() && + this.serializer.getLength() > 0 && this.serializer.getLength() <= THRESHOLD_FOR_IN_PLACE_SORTING) + { + this.sorter = new FixedLengthRecordSorter<IN>(this.serializer, this.sortingComparator, memory); + } else { + this.sorter = new NormalizedKeySorter<IN>(this.serializer, this.sortingComparator.duplicate(), memory); + } + + if (LOG.isDebugEnabled()) { + LOG.debug("GroupCombineChainedDriver object reuse: " + (this.objectReuseEnabled ?
"ENABLED" : "DISABLED") + "."); + } + } + + @Override + public void closeTask() throws Exception { + this.memManager.release(this.sorter.dispose()); + + if (!this.running) { + return; + } + + RegularPactTask.closeUserCode(this.reducer); + } + + @Override + public void cancelTask() { + this.running = false; + this.memManager.release(this.sorter.dispose()); + } + + // -------------------------------------------------------------------------------------------- + + public Function getStub() { + return this.reducer; + } + + public String getTaskName() { + return this.taskName; + } + + @Override + public void collect(IN record) { + // try writing to the sorter first + try { + if (this.sorter.write(record)) { + return; + } + } catch (IOException e) { + throw new ExceptionInChainedStubException(this.taskName, e); + } + + // do the actual sorting + try { + sortAndReduce(); + } catch (Exception e) { + throw new ExceptionInChainedStubException(this.taskName, e); + } + this.sorter.reset(); + + try { + if (!this.sorter.write(record)) { + throw new IOException("Cannot write record to fresh sort buffer. Record too large."); + } + } catch (IOException e) { + throw new ExceptionInChainedStubException(this.taskName, e); + } + } + + // -------------------------------------------------------------------------------------------- + + @Override + public void close() { + try { + sortAndReduce(); + } catch (Exception e) { + throw new ExceptionInChainedStubException(this.taskName, e); + } + + this.outputCollector.close(); + } + + private void sortAndReduce() throws Exception { + final InMemorySorter<IN> sorter = this.sorter; + + if (objectReuseEnabled) { + if (!sorter.isEmpty()) { + this.sortAlgo.sort(sorter); + // run the reducer + final ReusingKeyGroupedIterator<IN> keyIter = new ReusingKeyGroupedIterator<IN>(sorter.getIterator(), this.serializer, this.groupingComparator); + + + // cache references on the stack + final GroupReduceFunction<IN, OUT> stub = this.reducer; + final Collector<OUT> output = this.outputCollector; + + // run stub implementation + while (this.running && keyIter.nextKey()) { + stub.reduce(keyIter.getValues(), output); + } + } + } else { + if (!sorter.isEmpty()) { + this.sortAlgo.sort(sorter); + // run the reducer + final NonReusingKeyGroupedIterator<IN> keyIter = new NonReusingKeyGroupedIterator<IN>(sorter.getIterator(), this.groupingComparator); + + + // cache references on the stack + final GroupReduceFunction<IN, OUT> stub = this.reducer; + final Collector<OUT> output = this.outputCollector; + + // run stub implementation + while (this.running && keyIter.nextKey()) { + stub.reduce(keyIter.getValues(), output); + } + } + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-runtime/src/main/java/org/apache/flink/runtime/operators/chaining/SynchronousChainedCombineDriver.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/chaining/SynchronousChainedCombineDriver.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/chaining/SynchronousChainedCombineDriver.java index dde6fe6..7e36b49 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/chaining/SynchronousChainedCombineDriver.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/chaining/SynchronousChainedCombineDriver.java @@ -43,7 +43,17 @@ import org.apache.flink.util.Collector; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -public class 
SynchronousChainedCombineDriver<T> extends ChainedDriver<T, T> { +/** + * The chained variant of the combine driver; the same combine logic is also implemented in the GroupReduceCombineDriver. In contrast to the + * GroupReduceCombineDriver, this driver's purpose is only to combine the values received in the chain. It is used by + * the GroupReduce and the CombineGroup transformations. + * + * @see org.apache.flink.runtime.operators.GroupReduceCombineDriver + * @param <IN> The data type consumed by the combiner. + * @param <OUT> The data type produced by the combiner. + */ + +public class SynchronousChainedCombineDriver<IN, OUT> extends ChainedDriver<IN, OUT> { private static final Logger LOG = LoggerFactory.getLogger(SynchronousChainedCombineDriver.class); @@ -55,15 +65,15 @@ public class SynchronousChainedCombineDriver<T> extends ChainedDriver<T, T> { // -------------------------------------------------------------------------------------------- - private InMemorySorter<T> sorter; + private InMemorySorter<IN> sorter; - private FlatCombineFunction<T> combiner; + private FlatCombineFunction<IN, OUT> combiner; - private TypeSerializer<T> serializer; + private TypeSerializer<IN> serializer; - private TypeComparator<T> sortingComparator; + private TypeComparator<IN> sortingComparator; - private TypeComparator<T> groupingComparator; + private TypeComparator<IN> groupingComparator; private AbstractInvokable parent; @@ -80,7 +90,7 @@ public class SynchronousChainedCombineDriver<T> extends ChainedDriver<T, T> { this.parent = parent; @SuppressWarnings("unchecked") - final FlatCombineFunction<T> combiner = + final FlatCombineFunction<IN, OUT> combiner = RegularPactTask.instantiateUserCode(this.config, userCodeClassLoader, FlatCombineFunction.class); this.combiner = combiner; FunctionUtils.setFunctionRuntimeContext(combiner, getUdfRuntimeContext()); @@ -98,9 +108,9 @@ public class SynchronousChainedCombineDriver<T> extends ChainedDriver<T, T> { final int numMemoryPages = memManager.computeNumberOfPages(this.config.getRelativeMemoryDriver()); // instantiate the serializer / comparator - final TypeSerializerFactory<T> serializerFactory = this.config.getInputSerializer(0, this.userCodeClassLoader); - final TypeComparatorFactory<T> sortingComparatorFactory = this.config.getDriverComparator(0, this.userCodeClassLoader); - final TypeComparatorFactory<T> groupingComparatorFactory = this.config.getDriverComparator(1, this.userCodeClassLoader); + final TypeSerializerFactory<IN> serializerFactory = this.config.getInputSerializer(0, this.userCodeClassLoader); + final TypeComparatorFactory<IN> sortingComparatorFactory = this.config.getDriverComparator(0, this.userCodeClassLoader); + final TypeComparatorFactory<IN> groupingComparatorFactory = this.config.getDriverComparator(1, this.userCodeClassLoader); this.serializer = serializerFactory.getSerializer(); this.sortingComparator = sortingComparatorFactory.createComparator(); this.groupingComparator = groupingComparatorFactory.createComparator(); @@ -111,9 +121,9 @@ public class SynchronousChainedCombineDriver<T> extends ChainedDriver<T, T> { if (this.sortingComparator.supportsSerializationWithKeyNormalization() && this.serializer.getLength() > 0 && this.serializer.getLength() <= THRESHOLD_FOR_IN_PLACE_SORTING) { - this.sorter = new FixedLengthRecordSorter<T>(this.serializer, this.sortingComparator, memory); + this.sorter = new FixedLengthRecordSorter<IN>(this.serializer, this.sortingComparator, memory); } else { - this.sorter = new NormalizedKeySorter<T>(this.serializer,
this.sortingComparator.duplicate(), memory); + this.sorter = new NormalizedKeySorter<IN>(this.serializer, this.sortingComparator.duplicate(), memory); } if (LOG.isDebugEnabled()) { @@ -149,7 +159,7 @@ public class SynchronousChainedCombineDriver<T> extends ChainedDriver<T, T> { } @Override - public void collect(T record) { + public void collect(IN record) { // try writing to the sorter first try { if (this.sorter.write(record)) { @@ -190,18 +200,18 @@ public class SynchronousChainedCombineDriver<T> extends ChainedDriver<T, T> { } private void sortAndCombine() throws Exception { - final InMemorySorter<T> sorter = this.sorter; + final InMemorySorter<IN> sorter = this.sorter; if (objectReuseEnabled) { if (!sorter.isEmpty()) { this.sortAlgo.sort(sorter); // run the combiner - final ReusingKeyGroupedIterator<T> keyIter = new ReusingKeyGroupedIterator<T>(sorter.getIterator(), this.serializer, this.groupingComparator); + final ReusingKeyGroupedIterator<IN> keyIter = new ReusingKeyGroupedIterator<IN>(sorter.getIterator(), this.serializer, this.groupingComparator); // cache references on the stack - final FlatCombineFunction<T> stub = this.combiner; - final Collector<T> output = this.outputCollector; + final FlatCombineFunction<IN, OUT> stub = this.combiner; + final Collector<OUT> output = this.outputCollector; // run stub implementation while (this.running && keyIter.nextKey()) { @@ -212,12 +222,12 @@ public class SynchronousChainedCombineDriver<T> extends ChainedDriver<T, T> { if (!sorter.isEmpty()) { this.sortAlgo.sort(sorter); // run the combiner - final NonReusingKeyGroupedIterator<T> keyIter = new NonReusingKeyGroupedIterator<T>(sorter.getIterator(), this.groupingComparator); + final NonReusingKeyGroupedIterator<IN> keyIter = new NonReusingKeyGroupedIterator<IN>(sorter.getIterator(), this.groupingComparator); // cache references on the stack - final FlatCombineFunction<T> stub = this.combiner; - final Collector<T> output = this.outputCollector; + final FlatCombineFunction<IN, OUT> stub = this.combiner; + final Collector<OUT> output = this.outputCollector; // run stub implementation while (this.running && keyIter.nextKey()) { http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-runtime/src/main/java/org/apache/flink/runtime/operators/sort/CombiningUnilateralSortMerger.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/sort/CombiningUnilateralSortMerger.java b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/sort/CombiningUnilateralSortMerger.java index d05bd9a..9282fd4 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/sort/CombiningUnilateralSortMerger.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/sort/CombiningUnilateralSortMerger.java @@ -70,7 +70,7 @@ public class CombiningUnilateralSortMerger<E> extends UnilateralSortMerger<E> { */ private static final Logger LOG = LoggerFactory.getLogger(CombiningUnilateralSortMerger.class); - private final FlatCombineFunction<E> combineStub; // the user code stub that does the combining + private final FlatCombineFunction<E, E> combineStub; // the user code stub that does the combining private Configuration udfConfig; @@ -100,7 +100,7 @@ public class CombiningUnilateralSortMerger<E> extends UnilateralSortMerger<E> { * @throws MemoryAllocationException Thrown, if not enough memory can be obtained from the memory manager to * perform the sort. 
*/ - public CombiningUnilateralSortMerger(FlatCombineFunction<E> combineStub, MemoryManager memoryManager, IOManager ioManager, + public CombiningUnilateralSortMerger(FlatCombineFunction<E, E> combineStub, MemoryManager memoryManager, IOManager ioManager, MutableObjectIterator<E> input, AbstractInvokable parentTask, TypeSerializerFactory<E> serializerFactory, TypeComparator<E> comparator, double memoryFraction, int maxNumFileHandles, float startSpillingFraction) @@ -132,7 +132,7 @@ public class CombiningUnilateralSortMerger<E> extends UnilateralSortMerger<E> { * @throws MemoryAllocationException Thrown, if not enough memory can be obtained from the memory manager to * perform the sort. */ - public CombiningUnilateralSortMerger(FlatCombineFunction<E> combineStub, MemoryManager memoryManager, IOManager ioManager, + public CombiningUnilateralSortMerger(FlatCombineFunction<E, E> combineStub, MemoryManager memoryManager, IOManager ioManager, MutableObjectIterator<E> input, AbstractInvokable parentTask, TypeSerializerFactory<E> serializerFactory, TypeComparator<E> comparator, double memoryFraction, int numSortBuffers, int maxNumFileHandles, @@ -253,7 +253,7 @@ public class CombiningUnilateralSortMerger<E> extends UnilateralSortMerger<E> { // ------------------- Spilling Phase ------------------------ - final FlatCombineFunction<E> combineStub = CombiningUnilateralSortMerger.this.combineStub; + final FlatCombineFunction<E, E> combineStub = CombiningUnilateralSortMerger.this.combineStub; // now that we are actually spilling, take the combiner, and open it try { @@ -463,7 +463,7 @@ public class CombiningUnilateralSortMerger<E> extends UnilateralSortMerger<E> { this.memManager.getPageSize()); final WriterCollector<E> collector = new WriterCollector<E>(output, this.serializer); - final FlatCombineFunction<E> combineStub = CombiningUnilateralSortMerger.this.combineStub; + final FlatCombineFunction<E, E> combineStub = CombiningUnilateralSortMerger.this.combineStub; // combine and write to disk try { http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskExternalITCase.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskExternalITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskExternalITCase.java index 5e2f3ca..d957fa1 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskExternalITCase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskExternalITCase.java @@ -68,7 +68,7 @@ public class CombineTaskExternalITCase extends DriverTestBase<RichGroupReduceFun getTaskConfig().setRelativeMemoryDriver(combine_frac); getTaskConfig().setFilehandlesDriver(2); - final GroupReduceCombineDriver<Record> testTask = new GroupReduceCombineDriver<Record>(); + final GroupReduceCombineDriver<Record, Record> testTask = new GroupReduceCombineDriver<Record, Record>(); try { testDriver(testTask, MockCombiningReduceStub.class); @@ -122,7 +122,7 @@ public class CombineTaskExternalITCase extends DriverTestBase<RichGroupReduceFun getTaskConfig().setRelativeMemoryDriver(combine_frac); getTaskConfig().setFilehandlesDriver(2); - final GroupReduceCombineDriver<Record> testTask = new GroupReduceCombineDriver<Record>(); + final GroupReduceCombineDriver<Record, Record> testTask = new GroupReduceCombineDriver<Record, Record>(); try { 
testDriver(testTask, MockCombiningReduceStub.class); http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskTest.java index 515ce76..3d9e991 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/CombineTaskTest.java @@ -70,7 +70,7 @@ public class CombineTaskTest extends DriverTestBase<RichGroupReduceFunction<Reco getTaskConfig().setRelativeMemoryDriver(combine_frac); getTaskConfig().setFilehandlesDriver(2); - final GroupReduceCombineDriver<Record> testTask = new GroupReduceCombineDriver<Record>(); + final GroupReduceCombineDriver<Record, Record> testTask = new GroupReduceCombineDriver<Record, Record>(); try { testDriver(testTask, MockCombiningReduceStub.class); @@ -107,7 +107,7 @@ public class CombineTaskTest extends DriverTestBase<RichGroupReduceFunction<Reco getTaskConfig().setRelativeMemoryDriver(combine_frac); getTaskConfig().setFilehandlesDriver(2); - final GroupReduceCombineDriver<Record> testTask = new GroupReduceCombineDriver<Record>(); + final GroupReduceCombineDriver<Record, Record> testTask = new GroupReduceCombineDriver<Record, Record>(); try { testDriver(testTask, MockFailingCombiningReduceStub.class); @@ -132,7 +132,7 @@ public class CombineTaskTest extends DriverTestBase<RichGroupReduceFunction<Reco getTaskConfig().setRelativeMemoryDriver(combine_frac); getTaskConfig().setFilehandlesDriver(2); - final GroupReduceCombineDriver<Record> testTask = new GroupReduceCombineDriver<Record>(); + final GroupReduceCombineDriver<Record, Record> testTask = new GroupReduceCombineDriver<Record, Record>(); final AtomicBoolean success = new AtomicBoolean(false); http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala index 61bfbfa..00761ec 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala @@ -627,6 +627,62 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { } /** + * Applies a CombineFunction on a [[DataSet]]. A + * CombineFunction is similar to a GroupReduceFunction but does not + * perform a full data exchange. Instead, the CombineFunction calls + * the combine method once per partition for combining a group of + * results. This operator is suitable for combining values into an + * intermediate format before doing a proper groupReduce where the + * data is shuffled across the nodes for further reduction. The + * GroupReduce operator can also be supplied with a combiner by + * implementing the RichGroupReduceFunction. The combine method of + * the RichGroupReduceFunction requires the input and output types to be + * the same. The CombineFunction, on the other hand, can have an + * arbitrary output type.
+ */ + def combineGroup[R: TypeInformation: ClassTag]( + combiner: FlatCombineFunction[T, R]): DataSet[R] = { + if (combiner == null) { + throw new NullPointerException("Combine function must not be null.") + } + wrap(new GroupCombineOperator[T, R](javaSet, + implicitly[TypeInformation[R]], + combiner, + getCallLocationName())) + } + + /** + * Applies a CombineFunction on a [[DataSet]]. A + * CombineFunction is similar to a GroupReduceFunction but does not + * perform a full data exchange. Instead, the CombineFunction calls + * the combine method once per partition for combining a group of + * results. This operator is suitable for combining values into an + * intermediate format before doing a proper groupReduce where the + * data is shuffled across the nodes for further reduction. The + * GroupReduce operator can also be supplied with a combiner by + * implementing the RichGroupReduceFunction. The combine method of + * the RichGroupReduceFunction requires the input and output types to be + * the same. The CombineFunction, on the other hand, can have an + * arbitrary output type. + */ + def combineGroup[R: TypeInformation: ClassTag]( + fun: (Iterator[T], Collector[R]) => Unit): DataSet[R] = { + if (fun == null) { + throw new NullPointerException("Combine function must not be null.") + } + val combiner = new FlatCombineFunction[T, R] { + val cleanFun = clean(fun) + def combine(in: java.lang.Iterable[T], out: Collector[R]) { + cleanFun(in.iterator().asScala, out) + } + } + wrap(new GroupCombineOperator[T, R](javaSet, + implicitly[TypeInformation[R]], + combiner, + getCallLocationName())) + } + + /** * Creates a new DataSet containing the first `n` elements of this DataSet. */ def first(n: Int): DataSet[T] = { http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala index 7ac8dcd..eca4563 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/GroupedDataSet.scala @@ -22,7 +22,7 @@ import org.apache.flink.api.java.functions.{KeySelector, FirstReducer} import org.apache.flink.api.scala.operators.ScalaAggregateOperator import scala.collection.JavaConverters._ import org.apache.commons.lang3.Validate -import org.apache.flink.api.common.functions.{GroupReduceFunction, ReduceFunction} +import org.apache.flink.api.common.functions.{FlatCombineFunction, GroupReduceFunction, ReduceFunction, Partitioner} import org.apache.flink.api.common.operators.Order import org.apache.flink.api.java.aggregation.Aggregations import org.apache.flink.api.java.operators._ @@ -30,7 +30,6 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.util.Collector import scala.collection.mutable import scala.reflect.ClassTag -import org.apache.flink.api.common.functions.Partitioner import com.google.common.base.Preconditions /** @@ -355,6 +354,56 @@ class GroupedDataSet[T: ClassTag]( } /** + * Applies a CombineFunction on a grouped [[DataSet]]. A + * CombineFunction is similar to a GroupReduceFunction but does not + * perform a full data exchange. Instead, the CombineFunction calls + * the combine method once per partition for combining a group of + * results.
This operator is suitable for combining values into an + * intermediate format before doing a proper groupReduce where the + * data is shuffled across the nodes for further reduction. The + * GroupReduce operator can also be supplied with a combiner by + * implementing the RichGroupReduceFunction. The combine method of + * the RichGroupReduceFunction requires the input and output types to be + * the same. The CombineFunction, on the other hand, can have an + * arbitrary output type. + */ + def combineGroup[R: TypeInformation: ClassTag]( + fun: (Iterator[T], Collector[R]) => Unit): DataSet[R] = { + Validate.notNull(fun, "GroupCombine function must not be null.") + val combiner = new FlatCombineFunction[T, R] { + val cleanFun = set.clean(fun) + def combine(in: java.lang.Iterable[T], out: Collector[R]) { + cleanFun(in.iterator().asScala, out) + } + } + wrap( + new GroupCombineOperator[T, R](maybeCreateSortedGrouping(), + implicitly[TypeInformation[R]], combiner, getCallLocationName())) + } + + /** + * Applies a CombineFunction on a grouped [[DataSet]]. A + * CombineFunction is similar to a GroupReduceFunction but does not + * perform a full data exchange. Instead, the CombineFunction calls + * the combine method once per partition for combining a group of + * results. This operator is suitable for combining values into an + * intermediate format before doing a proper groupReduce where the + * data is shuffled across the nodes for further reduction. The + * GroupReduce operator can also be supplied with a combiner by + * implementing the RichGroupReduceFunction. The combine method of + * the RichGroupReduceFunction requires the input and output types to be + * the same. The CombineFunction, on the other hand, can have an + * arbitrary output type. + */ + def combineGroup[R: TypeInformation: ClassTag]( + combiner: FlatCombineFunction[T, R]): DataSet[R] = { + Validate.notNull(combiner, "GroupCombine function must not be null.") + wrap( + new GroupCombineOperator[T, R](maybeCreateSortedGrouping(), + implicitly[TypeInformation[R]], combiner, getCallLocationName())) + } + + /** + * Creates a new DataSet containing the first `n` elements of each group of this DataSet.
*/ def first(n: Int): DataSet[T] = { http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/customdanglingpagerank/CustomRankCombiner.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/customdanglingpagerank/CustomRankCombiner.java b/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/customdanglingpagerank/CustomRankCombiner.java index 8de861a..6631f07 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/customdanglingpagerank/CustomRankCombiner.java +++ b/flink-tests/src/test/java/org/apache/flink/test/iterative/nephele/customdanglingpagerank/CustomRankCombiner.java @@ -28,7 +28,7 @@ import org.apache.flink.util.Collector; public class CustomRankCombiner extends AbstractRichFunction implements GroupReduceFunction<VertexWithRank, VertexWithRank>, - FlatCombineFunction<VertexWithRank> + FlatCombineFunction<VertexWithRank, VertexWithRank> { private static final long serialVersionUID = 1L; http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupCombineITCase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupCombineITCase.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupCombineITCase.java new file mode 100644 index 0000000..0808389 --- /dev/null +++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupCombineITCase.java @@ -0,0 +1,522 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.flink.test.javaApiOperators; + +import org.apache.flink.api.common.functions.FlatCombineFunction; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.operators.Order; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.io.DiscardingOutputFormat; +import org.apache.flink.api.java.operators.UnsortedGrouping; +import org.apache.flink.api.java.tuple.Tuple1; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.test.javaApiOperators.util.CollectionDataSets; +import org.apache.flink.test.util.MultipleProgramsTestBase; +import org.apache.flink.util.Collector; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + + +@SuppressWarnings("serial") +@RunWith(Parameterized.class) +/** +* The GroupCombine operator is not easy to test because it is essentially just a combiner. The result can be +* the result of a normal groupReduce at any stage of its execution. The basic idea is to preserve the grouping key +* in the partial result, so that we can do a reduceGroup afterwards to finalize the results for verification. +* In addition, we can use hashPartition to partition the data and check if no shuffling (just combining) has +* been performed. +*/ +public class GroupCombineITCase extends MultipleProgramsTestBase { + + public GroupCombineITCase(ExecutionMode mode) { + super(mode); + } + + private String resultPath; + + private String expected; + + private static String identityResult = "1,1,Hi\n" + + "2,2,Hello\n" + + "3,2,Hello world\n" + + "4,3,Hello world, how are you?\n" + + "5,3,I am fine.\n" + + "6,3,Luke Skywalker\n" + + "7,4,Comment#1\n" + + "8,4,Comment#2\n" + + "9,4,Comment#3\n" + + "10,4,Comment#4\n" + + "11,5,Comment#5\n" + + "12,5,Comment#6\n" + + "13,5,Comment#7\n" + + "14,5,Comment#8\n" + + "15,5,Comment#9\n" + + "16,6,Comment#10\n" + + "17,6,Comment#11\n" + + "18,6,Comment#12\n" + + "19,6,Comment#13\n" + + "20,6,Comment#14\n" + + "21,6,Comment#15\n"; + + @Rule + public TemporaryFolder tempFolder = new TemporaryFolder(); + + @Before + public void before() throws Exception { + resultPath = tempFolder.newFile().toURI().toString(); + } + + @After + public void after() throws Exception { + if (expected != null) { + compareResultsByLinesInMemory(expected, resultPath); + } + } + + @Test + public void testAllGroupCombineIdentity() throws Exception { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env); + + + DataSet<Tuple3<Integer, Long, String>> reduceDs = ds + // combine + .combineGroup(new IdentityFunction()) + // fully reduce + .reduceGroup(new IdentityFunction()); + + + reduceDs.writeAsCsv(resultPath); + + env.execute(); + + expected = identityResult; + } + + @Test + public void testIdentity() throws Exception { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env); + 
DataSet<Tuple3<Integer, Long, String>> reduceDs = ds + // combine + .combineGroup(new IdentityFunction()) + // fully reduce + .reduceGroup(new IdentityFunction()); + + reduceDs.writeAsCsv(resultPath); + + env.execute(); + + expected = identityResult; + } + + @Test + public void testIdentityWithGroupBy() throws Exception { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env); + + DataSet<Tuple3<Integer, Long, String>> reduceDs = ds + .groupBy(1) + // combine + .combineGroup(new IdentityFunction()) + // fully reduce + .reduceGroup(new IdentityFunction()); + + + reduceDs.writeAsCsv(resultPath); + + env.execute(); + + expected = identityResult; + } + + @Test + public void testIdentityWithGroupByAndSort() throws Exception { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env); + + DataSet<Tuple3<Integer, Long, String>> reduceDs = ds + .groupBy(1) + .sortGroup(1, Order.DESCENDING) + // reduce partially + .combineGroup(new IdentityFunction()) + .groupBy(1) + .sortGroup(1, Order.DESCENDING) + // fully reduce + .reduceGroup(new IdentityFunction()); + + reduceDs.writeAsCsv(resultPath); + + env.execute(); + + expected = identityResult; + } + + @Test + public void testPartialReduceWithIdenticalInputOutputType() throws Exception { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + // data + DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env); + + DataSet<Tuple2<Long, Tuple3<Integer, Long, String>>> dsWrapped = ds + // wrap values as Kv pairs with the grouping key as key + .map(new Tuple3KvWrapper()); + + dsWrapped + .groupBy(0) + // reduce partially + .combineGroup(new Tuple3toTuple3GroupReduce()) + .groupBy(0) + // reduce fully to check result + .reduceGroup(new Tuple3toTuple3GroupReduce()) + //unwrap + .map(new MapFunction<Tuple2<Long, Tuple3<Integer, Long, String>>, Tuple3<Integer, Long, String>>() { + @Override + public Tuple3<Integer, Long, String> map(Tuple2<Long, Tuple3<Integer, Long, String>> value) throws Exception { + return value.f1; + } + }) + .writeAsCsv(resultPath); + + + + env.execute(); + + expected = "1,1,combined\n" + + "5,4,combined\n" + + "15,9,combined\n" + + "34,16,combined\n" + + "65,25,combined\n" + + "111,36,combined\n"; + } + + @Test + public void testPartialReduceWithDifferentInputOutputType() throws Exception { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + // data + DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env); + + DataSet<Tuple2<Long, Tuple3<Integer, Long, String>>> dsWrapped = ds + // wrap values as Kv pairs with the grouping key as key + .map(new Tuple3KvWrapper()); + + dsWrapped + .groupBy(0) + // reduce partially + .combineGroup(new Tuple3toTuple2GroupReduce()) + .groupBy(0) + // reduce fully to check result + .reduceGroup(new Tuple2toTuple2GroupReduce()) + //unwrap + .map(new MapFunction<Tuple2<Long,Tuple2<Integer,Long>>, Tuple2<Integer,Long>>() { + @Override + public Tuple2<Integer, Long> map(Tuple2<Long, Tuple2<Integer, Long>> value) throws Exception { + return value.f1; + } + }) + .writeAsCsv(resultPath); + + + + env.execute(); + + expected = "1,3\n" + + "5,20\n" + + "15,58\n" + + "34,52\n" + + "65,70\n" + + "111,96\n"; + + } + + @Test + // check if no shuffle is being executed 
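+ // (the data is hash-partitioned on field 0 but grouped on field 1; if only combining and no shuffle took place, the complete per-group counts cannot show up in the result)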
+ public void testCheckPartitionShuffleGroupBy() throws Exception { + + org.junit.Assume.assumeTrue(mode != ExecutionMode.COLLECTION); + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + // data + DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env); + + // partition and group data + UnsortedGrouping<Tuple3<Integer, Long, String>> partitionedDS = ds.partitionByHash(0).groupBy(1); + + partitionedDS.combineGroup(new FlatCombineFunction<Tuple3<Integer, Long, String>, Tuple2<Long, Integer>>() { + @Override + public void combine(Iterable<Tuple3<Integer, Long, String>> values, Collector<Tuple2<Long, Integer>> out) throws Exception { + int count = 0; + long key = 0; + for (Tuple3<Integer, Long, String> value : values) { + key = value.f1; + count++; + } + out.collect(new Tuple2<Long, Integer>(key, count)); + } + }).writeAsCsv(resultPath); + + env.execute(); + + String notExpected = "6,6\n" + + "5,5\n" + + "4,4\n" + + "3,3\n" + + "2,2\n" + + "1,1\n"; + + // check + + ArrayList<String> list = new ArrayList<String>(); + readAllResultLines(list, resultPath); + + String[] result = list.toArray(new String[list.size()]); + Arrays.sort(result); + + String[] expected = notExpected.split("\n"); + Arrays.sort(expected); + + Assert.assertFalse("The two arrays must not be identical.", Arrays.equals(expected, result)); + } + + @Test + // check if dop 1 produces the same result as a shuffle would + public void testCheckPartitionShuffleDOP1() throws Exception { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + env.setDegreeOfParallelism(1); + + // data + DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env); + + // partition and group data + UnsortedGrouping<Tuple3<Integer, Long, String>> partitionedDS = ds.partitionByHash(0).groupBy(1); + + partitionedDS.combineGroup(new FlatCombineFunction<Tuple3<Integer, Long, String>, Tuple2<Long, Integer>>() { + @Override + public void combine(Iterable<Tuple3<Integer, Long, String>> values, Collector<Tuple2<Long, Integer>> out) throws Exception { + int count = 0; + long key = 0; + for (Tuple3<Integer, Long, String> value : values) { + key = value.f1; + count++; + } + out.collect(new Tuple2<Long, Integer>(key, count)); + } + }).writeAsCsv(resultPath); + + env.execute(); + + expected = "6,6\n" + + "5,5\n" + + "4,4\n" + + "3,3\n" + + "2,2\n" + + "1,1\n"; + + } + + @Test + // check if all API methods are callable + public void testAPI() throws Exception { + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple1<String>> ds = CollectionDataSets.getStringDataSet(env).map(new MapFunction<String, Tuple1<String>>() { + @Override + public Tuple1<String> map(String value) throws Exception { + return new Tuple1<String>(value); + } + }); + + // all methods on DataSet + ds.combineGroup(new FlatCombineFunctionExample()) + .output(new DiscardingOutputFormat<Tuple1<String>>()); + + // all methods on UnsortedGrouping + ds.groupBy(0).combineGroup(new FlatCombineFunctionExample()) + .output(new DiscardingOutputFormat<Tuple1<String>>()); + + // all methods on SortedGrouping + ds.groupBy(0).sortGroup(0, Order.ASCENDING).combineGroup(new FlatCombineFunctionExample()) + .output(new DiscardingOutputFormat<Tuple1<String>>()); + + env.execute(); + } +
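+ /** Identity combiner that simply forwards every element; used by testAPI to check that all combineGroup variants are callable. */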
+ public static class FlatCombineFunctionExample implements FlatCombineFunction<Tuple1<String>, Tuple1<String>> { + + @Override + public void combine(Iterable<Tuple1<String>> values, Collector<Tuple1<String>> out) throws Exception { + for (Tuple1<String> value : values) { + out.collect(value); + } + } + } + + public static class ScalaFlatCombineFunctionExample implements FlatCombineFunction<scala.Tuple1<String>, scala.Tuple1<String>> { + + @Override + public void combine(Iterable<scala.Tuple1<String>> values, Collector<scala.Tuple1<String>> out) throws Exception { + for (scala.Tuple1<String> value : values) { + out.collect(value); + } + } + } + + public static class IdentityFunction implements FlatCombineFunction<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>>, + GroupReduceFunction<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> { + + @Override + public void combine(Iterable<Tuple3<Integer, Long, String>> values, Collector<Tuple3<Integer, Long, String>> out) throws Exception { + for (Tuple3<Integer, Long, String> value : values) { + out.collect(new Tuple3<Integer, Long, String>(value.f0, value.f1, value.f2)); + } + } + + @Override + public void reduce(Iterable<Tuple3<Integer, Long, String>> values, Collector<Tuple3<Integer, Long, String>> out) throws Exception { + for (Tuple3<Integer, Long, String> value : values) { + out.collect(new Tuple3<Integer, Long, String>(value.f0, value.f1, value.f2)); + } + } + } + + + public static class Tuple3toTuple3GroupReduce implements KvGroupReduce<Long, Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> { + + @Override + public void combine(Iterable<Tuple2<Long, Tuple3<Integer, Long, String>>> values, Collector<Tuple2<Long, Tuple3<Integer, Long, String>>> out) throws Exception { + int i = 0; + long l = 0; + long key = 0; + + // collapse groups + for (Tuple2<Long, Tuple3<Integer, Long, String>> value : values) { + key = value.f0; + Tuple3<Integer, Long, String> extracted = value.f1; + i += extracted.f0; + l += extracted.f1; + } + + Tuple3<Integer, Long, String> result = new Tuple3<Integer, Long, String>(i, l, "combined"); + out.collect(new Tuple2<Long, Tuple3<Integer, Long, String>>(key, result)); + } + + @Override + public void reduce(Iterable<Tuple2<Long, Tuple3<Integer, Long, String>>> values, Collector<Tuple2<Long, Tuple3<Integer, Long, String>>> out) throws Exception { + combine(values, out); + } + } + + public static class Tuple3toTuple2GroupReduce implements KvGroupReduce<Long, Tuple3<Integer, Long, String>, Tuple2<Integer, Long>, Tuple2<Integer, Long>> { + + @Override + public void combine(Iterable<Tuple2<Long, Tuple3<Integer, Long, String>>> values, Collector<Tuple2<Long, Tuple2<Integer, Long>>> out) throws Exception { + int i = 0; + long l = 0; + long key = 0; + + // collapse groups + for (Tuple2<Long, Tuple3<Integer, Long, String>> value : values) { + key = value.f0; + Tuple3<Integer, Long, String> extracted = value.f1; + i += extracted.f0; + l += extracted.f1 + extracted.f2.length(); + } + + Tuple2<Integer, Long> result = new Tuple2<Integer, Long>(i, l); + out.collect(new Tuple2<Long, Tuple2<Integer, Long>>(key, result)); + } + + @Override + public void reduce(Iterable<Tuple2<Long, Tuple2<Integer, Long>>> values, Collector<Tuple2<Long, Tuple2<Integer, Long>>> out) throws Exception { + new Tuple2toTuple2GroupReduce().reduce(values, out); + } + } +
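+ /** Sums the wrapped (Integer, Long) values per key; combine and reduce perform the same aggregation. */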
+ public static class Tuple2toTuple2GroupReduce implements KvGroupReduce<Long, Tuple2<Integer, Long>, Tuple2<Integer, Long>, Tuple2<Integer, Long>> { + + @Override + public void combine(Iterable<Tuple2<Long, Tuple2<Integer, Long>>> values, Collector<Tuple2<Long, Tuple2<Integer, Long>>> out) throws Exception { + int i = 0; + long l = 0; + long key = 0; + + // collapse groups + for (Tuple2<Long, Tuple2<Integer, Long>> value : values) { + key = value.f0; + Tuple2<Integer, Long> extracted = value.f1; + i += extracted.f0; + l += extracted.f1; + } + + Tuple2<Integer, Long> result = new Tuple2<Integer, Long>(i, l); + + out.collect(new Tuple2<Long, Tuple2<Integer, Long>>(key, result)); + } + + @Override + public void reduce(Iterable<Tuple2<Long, Tuple2<Integer, Long>>> values, Collector<Tuple2<Long, Tuple2<Integer, Long>>> out) throws Exception { + combine(values, out); + } + } + + public static class Tuple3KvWrapper implements MapFunction<Tuple3<Integer, Long, String>, Tuple2<Long, Tuple3<Integer, Long, String>>> { + @Override + public Tuple2<Long, Tuple3<Integer, Long, String>> map(Tuple3<Integer, Long, String> value) throws Exception { + return new Tuple2<Long, Tuple3<Integer, Long, String>>(value.f1, value); + } + } + + + public interface CombineAndReduceGroup<IN, INT, OUT> extends FlatCombineFunction<IN, INT>, GroupReduceFunction<INT, OUT> { + } + + public interface KvGroupReduce<K, V, INT, OUT> extends CombineAndReduceGroup<Tuple2<K, V>, Tuple2<K, INT>, Tuple2<K, OUT>> { + } + + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/flink/blob/e93e0cb8/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupCombineITCase.scala ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupCombineITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupCombineITCase.scala new file mode 100644 index 0000000..05346ba --- /dev/null +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/GroupCombineITCase.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.api.scala.operators + +import org.apache.flink.api.common.operators.Order +import org.apache.flink.api.java.io.DiscardingOutputFormat +import org.apache.flink.api.scala.ExecutionEnvironment +import org.apache.flink.api.scala.util.CollectionDataSets +import org.apache.flink.test.javaApiOperators.GroupCombineITCase +import org.apache.flink.test.util.MultipleProgramsTestBase.ExecutionMode +import org.apache.flink.test.util.MultipleProgramsTestBase +import org.apache.flink.util.Collector +import org.junit._ +import org.junit.runner.RunWith +import org.junit.runners.Parameterized +import org.apache.flink.api.scala._ + +/** + * Java interoperability tests. The main tests are implemented in the Java GroupCombineITCase.
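+ * This suite only checks that combineGroup is callable from the Scala API, with both rich functions and lambdas.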
+ */ +@RunWith(classOf[Parameterized]) +class GroupCombineITCase(mode: ExecutionMode) extends MultipleProgramsTestBase(mode) { + + @Test + def testApi(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + + val ds: DataSet[Tuple1[String]] = CollectionDataSets.getStringDataSet(env) + .map(str => Tuple1(str)) + + // all methods on DataSet + ds.combineGroup(new GroupCombineITCase.ScalaFlatCombineFunctionExample()) + .output(new DiscardingOutputFormat[Tuple1[String]]) + + ds.combineGroup((in, out: Collector[Tuple1[String]]) => in.toSet foreach (out.collect)) + .output(new DiscardingOutputFormat[Tuple1[String]]) + + // all methods on UnsortedGrouping + ds.groupBy(0) + .combineGroup(new GroupCombineITCase.ScalaFlatCombineFunctionExample()) + .output(new DiscardingOutputFormat[Tuple1[String]]) + + ds.groupBy(0) + .combineGroup((in, out: Collector[Tuple1[String]]) => in.toSet foreach (out.collect)) + .output(new DiscardingOutputFormat[Tuple1[String]]) + + // all methods on SortedGrouping + ds.groupBy(0).sortGroup(0, Order.ASCENDING) + .combineGroup(new GroupCombineITCase.ScalaFlatCombineFunctionExample()) + .output(new DiscardingOutputFormat[Tuple1[String]]) + + ds.groupBy(0).sortGroup(0, Order.ASCENDING) + .combineGroup((in, out: Collector[Tuple1[String]]) => in.toSet foreach (out.collect)) + .output(new DiscardingOutputFormat[Tuple1[String]]) + + env.execute + } + +} + +
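For context, the end-user pattern that these tests exercise looks roughly as follows. This is a minimal sketch, not part of the commit: the class name GroupCombineExample and the word-count input are hypothetical, and it only uses APIs that appear in the diffs above (combineGroup, groupBy, reduceGroup, DiscardingOutputFormat). combineGroup pre-aggregates each group within a partition and may emit several partial results per key; the second groupBy plus reduceGroup merges those partials into the final values.

import org.apache.flink.api.common.functions.FlatCombineFunction;
import org.apache.flink.api.common.functions.GroupReduceFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.util.Collector;

public class GroupCombineExample {

	public static void main(String[] args) throws Exception {
		ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

		DataSet<Tuple2<String, Integer>> counts = env
			.fromElements("to", "be", "or", "not", "to", "be")
			.map(new MapFunction<String, Tuple2<String, Integer>>() {
				@Override
				public Tuple2<String, Integer> map(String word) {
					return new Tuple2<String, Integer>(word, 1);
				}
			})
			.groupBy(0)
			// pre-aggregate: runs before any shuffle and may emit several partial counts per word
			.combineGroup(new FlatCombineFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>() {
				@Override
				public void combine(Iterable<Tuple2<String, Integer>> values, Collector<Tuple2<String, Integer>> out) {
					String word = null;
					int sum = 0;
					for (Tuple2<String, Integer> value : values) {
						word = value.f0;
						sum += value.f1;
					}
					out.collect(new Tuple2<String, Integer>(word, sum));
				}
			})
			.groupBy(0)
			// finalize: merge the partial counts into one total per word
			.reduceGroup(new GroupReduceFunction<Tuple2<String, Integer>, Tuple2<String, Integer>>() {
				@Override
				public void reduce(Iterable<Tuple2<String, Integer>> values, Collector<Tuple2<String, Integer>> out) {
					String word = null;
					int sum = 0;
					for (Tuple2<String, Integer> value : values) {
						word = value.f0;
						sum += value.f1;
					}
					out.collect(new Tuple2<String, Integer>(word, sum));
				}
			});

		counts.output(new DiscardingOutputFormat<Tuple2<String, Integer>>());
		env.execute();
	}
}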