[FLINK-703] [java api] Use complete element as join key
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/30a74c76 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/30a74c76 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/30a74c76 Branch: refs/heads/master Commit: 30a74c768b88043c173053f5d37bfb0fde85f149 Parents: e1618e2 Author: Chiwan Park <chiwanp...@icloud.com> Authored: Mon Apr 6 03:18:23 2015 +0900 Committer: Fabian Hueske <fhue...@apache.org> Committed: Tue Apr 21 19:03:15 2015 +0200 ---------------------------------------------------------------------- .../operators/base/CoGroupOperatorBase.java | 10 ++- .../base/GroupCombineOperatorBase.java | 42 +++++---- .../operators/base/GroupReduceOperatorBase.java | 41 +++++---- .../apache/flink/api/java/operators/Keys.java | 45 ++++++---- .../api/java/operator/CoGroupOperatorTest.java | 94 +++++++++++++++++--- .../flink/api/java/operator/GroupingTest.java | 34 ++++++- .../api/java/operator/JoinOperatorTest.java | 72 +++++++++++++++ .../optimizer/postpass/JavaApiPostPass.java | 5 -- .../test/javaApiOperators/CoGroupITCase.java | 82 ++++++++++++++++- .../javaApiOperators/GroupReduceITCase.java | 38 ++++++-- .../flink/test/javaApiOperators/JoinITCase.java | 31 +++++++ 11 files changed, 406 insertions(+), 88 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/30a74c76/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java ---------------------------------------------------------------------- diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java index dbebeb4..7be5650 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java +++ 
b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/CoGroupOperatorBase.java @@ -32,6 +32,7 @@ import org.apache.flink.api.common.operators.util.ListKeyGroupedIterator; import org.apache.flink.api.common.operators.util.UserCodeClassWrapper; import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; import org.apache.flink.api.common.operators.util.UserCodeWrapper; +import org.apache.flink.api.common.typeinfo.AtomicType; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.common.typeutils.GenericPairComparator; @@ -273,12 +274,15 @@ public class CoGroupOperatorBase<IN1, IN2, OUT, FT extends CoGroupFunction<IN1, return result; } + @SuppressWarnings("unchecked") private <T> TypeComparator<T> getTypeComparator(ExecutionConfig executionConfig, TypeInformation<T> inputType, int[] inputKeys, boolean[] inputSortDirections) { - if (!(inputType instanceof CompositeType)) { - throw new InvalidProgramException("Input types of coGroup must be composite types."); + if (inputType instanceof CompositeType) { + return ((CompositeType<T>) inputType).createComparator(inputKeys, inputSortDirections, 0, executionConfig); + } else if (inputType instanceof AtomicType) { + return ((AtomicType<T>) inputType).createComparator(inputSortDirections[0], executionConfig); } - return ((CompositeType<T>) inputType).createComparator(inputKeys, inputSortDirections, 0, executionConfig); + throw new InvalidProgramException("Input type of coGroup must be one of composite types or atomic types."); } private static class CoGroupSortListIterator<IN1, IN2> { http://git-wip-us.apache.org/repos/asf/flink/blob/30a74c76/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupCombineOperatorBase.java ---------------------------------------------------------------------- diff --git 
a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupCombineOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupCombineOperatorBase.java index 27fbc1c..c7ba92b 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupCombineOperatorBase.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupCombineOperatorBase.java @@ -31,6 +31,7 @@ import org.apache.flink.api.common.operators.SingleInputOperator; import org.apache.flink.api.common.operators.UnaryOperatorInformation; import org.apache.flink.api.common.operators.util.ListKeyGroupedIterator; import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; +import org.apache.flink.api.common.typeinfo.AtomicType; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.common.typeutils.TypeComparator; @@ -77,6 +78,16 @@ public class GroupCombineOperatorBase<IN, OUT, FT extends GroupCombineFunction<I return this.groupOrder; } + private TypeComparator<IN> getTypeComparator(TypeInformation<IN> typeInfo, int[] sortColumns, boolean[] sortOrderings, ExecutionConfig executionConfig) { + if (typeInfo instanceof CompositeType) { + return ((CompositeType<IN>) typeInfo).createComparator(sortColumns, sortOrderings, 0, executionConfig); + } else if (typeInfo instanceof AtomicType) { + return ((AtomicType<IN>) typeInfo).createComparator(sortOrderings[0], executionConfig); + } + + throw new InvalidProgramException("Input type of GroupCombine must be one of composite types or atomic types."); + } + // -------------------------------------------------------------------------------------------- @Override @@ -87,11 +98,6 @@ public class GroupCombineOperatorBase<IN, OUT, FT extends GroupCombineFunction<I TypeInformation<IN> inputType = operatorInfo.getInputType(); int[] keyColumns = getKeyColumns(0); - - if (!(inputType 
instanceof CompositeType) && (keyColumns.length > 0 || groupOrder != null)) { - throw new InvalidProgramException("Grouping or group-sorting is only possible on composite type."); - } - int[] sortColumns = keyColumns; boolean[] sortOrderings = new boolean[sortColumns.length]; @@ -100,19 +106,17 @@ public class GroupCombineOperatorBase<IN, OUT, FT extends GroupCombineFunction<I sortOrderings = ArrayUtils.addAll(sortOrderings, groupOrder.getFieldSortDirections()); } - if (inputType instanceof CompositeType) { - if(sortColumns.length == 0) { // => all reduce. No comparator - Preconditions.checkArgument(sortOrderings.length == 0); - } else { - final TypeComparator<IN> sortComparator = ((CompositeType<IN>) inputType).createComparator(sortColumns, sortOrderings, 0, executionConfig); - - Collections.sort(inputData, new Comparator<IN>() { - @Override - public int compare(IN o1, IN o2) { - return sortComparator.compare(o1, o2); - } - }); - } + if(sortColumns.length == 0) { // => all reduce. No comparator + Preconditions.checkArgument(sortOrderings.length == 0); + } else { + final TypeComparator<IN> sortComparator = getTypeComparator(inputType, sortColumns, sortOrderings, executionConfig); + + Collections.sort(inputData, new Comparator<IN>() { + @Override + public int compare(IN o1, IN o2) { + return sortComparator.compare(o1, o2); + } + }); } FunctionUtils.setFunctionRuntimeContext(function, ctx); @@ -133,7 +137,7 @@ public class GroupCombineOperatorBase<IN, OUT, FT extends GroupCombineFunction<I } else { final TypeSerializer<IN> inputSerializer = inputType.createSerializer(executionConfig); boolean[] keyOrderings = new boolean[keyColumns.length]; - final TypeComparator<IN> comparator = ((CompositeType<IN>) inputType).createComparator(keyColumns, keyOrderings, 0, executionConfig); + final TypeComparator<IN> comparator = getTypeComparator(inputType, keyColumns, keyOrderings, executionConfig); ListKeyGroupedIterator<IN> keyedIterator = new 
ListKeyGroupedIterator<IN>(inputData, inputSerializer, comparator); http://git-wip-us.apache.org/repos/asf/flink/blob/30a74c76/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java ---------------------------------------------------------------------- diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java index 57f07f3..3056fe7 100644 --- a/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java +++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/base/GroupReduceOperatorBase.java @@ -34,6 +34,7 @@ import org.apache.flink.api.common.operators.util.ListKeyGroupedIterator; import org.apache.flink.api.common.operators.util.UserCodeClassWrapper; import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper; import org.apache.flink.api.common.operators.util.UserCodeWrapper; +import org.apache.flink.api.common.typeinfo.AtomicType; import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeutils.CompositeType; import org.apache.flink.api.common.typeutils.TypeComparator; @@ -148,6 +149,16 @@ public class GroupReduceOperatorBase<IN, OUT, FT extends GroupReduceFunction<IN, public Partitioner<?> getCustomPartitioner() { return customPartitioner; } + + private TypeComparator<IN> getTypeComparator(TypeInformation<IN> typeInfo, int[] sortColumns, boolean[] sortOrderings, ExecutionConfig executionConfig) { + if (typeInfo instanceof CompositeType) { + return ((CompositeType<IN>) typeInfo).createComparator(sortColumns, sortOrderings, 0, executionConfig); + } else if (typeInfo instanceof AtomicType) { + return ((AtomicType<IN>) typeInfo).createComparator(sortOrderings[0], executionConfig); + } + + throw new InvalidProgramException("Input type of GroupReduce must be one of composite 
types or atomic types."); + } // -------------------------------------------------------------------------------------------- @@ -159,11 +170,6 @@ public class GroupReduceOperatorBase<IN, OUT, FT extends GroupReduceFunction<IN, TypeInformation<IN> inputType = operatorInfo.getInputType(); int[] keyColumns = getKeyColumns(0); - - if (!(inputType instanceof CompositeType) && (keyColumns.length > 0 || groupOrder != null)) { - throw new InvalidProgramException("Grouping or group-sorting is only possible on composite type."); - } - int[] sortColumns = keyColumns; boolean[] sortOrderings = new boolean[sortColumns.length]; @@ -172,19 +178,16 @@ public class GroupReduceOperatorBase<IN, OUT, FT extends GroupReduceFunction<IN, sortOrderings = ArrayUtils.addAll(sortOrderings, groupOrder.getFieldSortDirections()); } - if (inputType instanceof CompositeType) { - if(sortColumns.length == 0) { // => all reduce. No comparator - Preconditions.checkArgument(sortOrderings.length == 0); - } else { - final TypeComparator<IN> sortComparator = ((CompositeType<IN>) inputType).createComparator(sortColumns, sortOrderings, 0, executionConfig); - - Collections.sort(inputData, new Comparator<IN>() { - @Override - public int compare(IN o1, IN o2) { - return sortComparator.compare(o1, o2); - } - }); - } + if(sortColumns.length == 0) { // => all reduce. 
No comparator + Preconditions.checkArgument(sortOrderings.length == 0); + } else { + final TypeComparator<IN> sortComparator = getTypeComparator(inputType, sortColumns, sortOrderings, executionConfig); + Collections.sort(inputData, new Comparator<IN>() { + @Override + public int compare(IN o1, IN o2) { + return sortComparator.compare(o1, o2); + } + }); } FunctionUtils.setFunctionRuntimeContext(function, ctx); @@ -205,7 +208,7 @@ public class GroupReduceOperatorBase<IN, OUT, FT extends GroupReduceFunction<IN, } else { final TypeSerializer<IN> inputSerializer = inputType.createSerializer(executionConfig); boolean[] keyOrderings = new boolean[keyColumns.length]; - final TypeComparator<IN> comparator = ((CompositeType<IN>) inputType).createComparator(keyColumns, keyOrderings, 0, executionConfig); + final TypeComparator<IN> comparator = getTypeComparator(inputType, keyColumns, keyOrderings, executionConfig); ListKeyGroupedIterator<IN> keyedIterator = new ListKeyGroupedIterator<IN>(inputData, inputSerializer, comparator); http://git-wip-us.apache.org/repos/asf/flink/blob/30a74c76/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java index a2cde07..ee233e8 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/Keys.java @@ -274,24 +274,33 @@ public abstract class Keys<T> { * Create ExpressionKeys from String-expressions */ public ExpressionKeys(String[] expressionsIn, TypeInformation<T> type) { - if(!(type instanceof CompositeType<?>)) { - throw new IllegalArgumentException("Key expressions are only supported on POJO types and Tuples. 
" - + "A type is considered a POJO if all its fields are public, or have both getters and setters defined"); - } - CompositeType<T> cType = (CompositeType<T>) type; - - String[] expressions = removeDuplicates(expressionsIn); - if(expressionsIn.length != expressions.length) { - LOG.warn("The key expressions contained duplicates. They are now unique"); - } - // extract the keys on their flat position - keyFields = new ArrayList<FlatFieldDescriptor>(expressions.length); - for (int i = 0; i < expressions.length; i++) { - List<FlatFieldDescriptor> keys = cType.getFlatFields(expressions[i]); // use separate list to do a size check - if(keys.size() == 0) { - throw new IllegalArgumentException("Unable to extract key from expression '"+expressions[i]+"' on key "+cType); + Preconditions.checkNotNull(expressionsIn, "Field expression cannot be null."); + + if (type instanceof AtomicType) { + if (!type.isKeyType()) { + throw new InvalidProgramException("This type (" + type + ") cannot be used as key."); + } else if (expressionsIn.length != 1 || !(Keys.ExpressionKeys.SELECT_ALL_CHAR.equals(expressionsIn[0]) || Keys.ExpressionKeys.SELECT_ALL_CHAR_SCALA.equals(expressionsIn[0]))) { + throw new IllegalArgumentException("Field expression for atomic type must be equal to '*' or '_'."); + } + + keyFields = new ArrayList<FlatFieldDescriptor>(1); + keyFields.add(new FlatFieldDescriptor(0, type)); + } else { + CompositeType<T> cType = (CompositeType<T>) type; + + String[] expressions = removeDuplicates(expressionsIn); + if(expressionsIn.length != expressions.length) { + LOG.warn("The key expressions contained duplicates. 
They are now unique"); + } + // extract the keys on their flat position + keyFields = new ArrayList<FlatFieldDescriptor>(expressions.length); + for (int i = 0; i < expressions.length; i++) { + List<FlatFieldDescriptor> keys = cType.getFlatFields(expressions[i]); // use separate list to do a size check + if(keys.size() == 0) { + throw new IllegalArgumentException("Unable to extract key from expression '"+expressions[i]+"' on key "+cType); + } + keyFields.addAll(keys); } - keyFields.addAll(keys); } } @@ -410,7 +419,7 @@ public abstract class Keys<T> { return Arrays.copyOfRange(fields, 0, k+1); } } - + public static class IncompatibleKeysException extends Exception { private static final long serialVersionUID = 1L; public static final String SIZE_MISMATCH_MESSAGE = "The number of specified keys is different."; http://git-wip-us.apache.org/repos/asf/flink/blob/30a74c76/flink-java/src/test/java/org/apache/flink/api/java/operator/CoGroupOperatorTest.java ---------------------------------------------------------------------- diff --git a/flink-java/src/test/java/org/apache/flink/api/java/operator/CoGroupOperatorTest.java b/flink-java/src/test/java/org/apache/flink/api/java/operator/CoGroupOperatorTest.java index 60754e6..f32f6a9 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/operator/CoGroupOperatorTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/operator/CoGroupOperatorTest.java @@ -18,26 +18,26 @@ package org.apache.flink.api.java.operator; -import java.util.ArrayList; -import java.util.List; - +import org.apache.flink.api.common.InvalidProgramException; import org.apache.flink.api.common.functions.CoGroupFunction; import org.apache.flink.api.common.operators.SemanticProperties; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; import org.apache.flink.api.java.functions.FunctionAnnotation; +import 
org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.operator.JoinOperatorTest.CustomType; import org.apache.flink.api.java.operators.CoGroupOperator; import org.apache.flink.api.java.tuple.Tuple2; -import org.apache.flink.util.Collector; -import org.junit.Assert; -import org.apache.flink.api.common.InvalidProgramException; -import org.apache.flink.api.common.typeinfo.BasicTypeInfo; -import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple5; import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.util.Collector; +import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.ExecutionEnvironment; -import org.apache.flink.api.java.operator.JoinOperatorTest.CustomType; + +import java.util.ArrayList; +import java.util.List; import static org.junit.Assert.assertTrue; @@ -181,6 +181,78 @@ public class CoGroupOperatorTest { // should not work, cogroup key non-existent ds1.coGroup(ds2).where("myNonExistent").equalTo("myInt"); } + + @Test + public void testCoGroupKeyAtomicExpression1() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<CustomType> ds1 = env.fromCollection(customTypeData); + DataSet<Integer> ds2 = env.fromElements(0, 0, 1); + + ds1.coGroup(ds2).where("myInt").equalTo("*"); + } + + @Test + public void testCoGroupKeyAtomicExpression2() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Integer> ds1 = env.fromElements(0, 0, 1); + DataSet<CustomType> ds2 = env.fromCollection(customTypeData); + + ds1.coGroup(ds2).where("*").equalTo("myInt"); + } + + @Test(expected = InvalidProgramException.class) + public void testCoGroupKeyAtomicInvalidExpression1() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Integer> ds1 = env.fromElements(0, 0, 1); + 
DataSet<CustomType> ds2 = env.fromCollection(customTypeData); + + ds1.coGroup(ds2).where("*", "invalidKey"); + } + + @Test(expected = InvalidProgramException.class) + public void testCoGroupKeyAtomicInvalidExpression2() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Integer> ds1 = env.fromElements(0, 0, 1); + DataSet<CustomType> ds2 = env.fromCollection(customTypeData); + + ds1.coGroup(ds2).where("invalidKey"); + } + + @Test(expected = InvalidProgramException.class) + public void testCoGroupKeyAtomicInvalidExpression3() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<CustomType> ds1 = env.fromCollection(customTypeData); + DataSet<Integer> ds2 = env.fromElements(0, 0, 1); + + ds1.coGroup(ds2).where("myInt").equalTo("invalidKey"); + } + + @Test(expected = InvalidProgramException.class) + public void testCoGroupKeyAtomicInvalidExpression4() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<CustomType> ds1 = env.fromCollection(customTypeData); + DataSet<Integer> ds2 = env.fromElements(0, 0, 1); + + ds1.coGroup(ds2).where("myInt").equalTo("*", "invalidKey"); + } + + @Test(expected = InvalidProgramException.class) + public void testCoGroupKeyAtomicInvalidExpression5() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<ArrayList<Integer>> ds1 = env.fromElements(new ArrayList<Integer>()); + DataSet<Integer> ds2 = env.fromElements(0, 0, 0); + + ds1.coGroup(ds2).where("*"); + } + + @Test(expected = InvalidProgramException.class) + public void testCoGroupKeyAtomicInvalidExpression6() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Integer> ds1 = env.fromElements(0, 0, 0); + DataSet<ArrayList<Integer>> ds2 = env.fromElements(new ArrayList<Integer>()); + + ds1.coGroup(ds2).where("*").equalTo("*"); + } @Test public void 
testCoGroupKeyExpressions1Nested() { http://git-wip-us.apache.org/repos/asf/flink/blob/30a74c76/flink-java/src/test/java/org/apache/flink/api/java/operator/GroupingTest.java ---------------------------------------------------------------------- diff --git a/flink-java/src/test/java/org/apache/flink/api/java/operator/GroupingTest.java b/flink-java/src/test/java/org/apache/flink/api/java/operator/GroupingTest.java index 314695f..b3922b3 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/operator/GroupingTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/operator/GroupingTest.java @@ -143,7 +143,7 @@ public class GroupingTest { } } - @Test(expected = IllegalArgumentException.class) + @Test(expected = InvalidProgramException.class) public void testGroupByKeyExpressions2() { final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); @@ -551,6 +551,38 @@ public class GroupingTest { }, Order.ASCENDING); } + @Test + public void testGroupingAtomicType() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Integer> dataSet = env.fromElements(0, 1, 1, 2, 0, 0); + + dataSet.groupBy("*"); + } + + @Test(expected = InvalidProgramException.class) + public void testGroupAtomicTypeWithInvalid1() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Integer> dataSet = env.fromElements(0, 1, 2, 3); + + dataSet.groupBy("*", "invalidField"); + } + + @Test(expected = InvalidProgramException.class) + public void testGroupAtomicTypeWithInvalid2() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Integer> dataSet = env.fromElements(0, 1, 2, 3); + + dataSet.groupBy("invalidField"); + } + + @Test(expected = InvalidProgramException.class) + public void testGroupAtomicTypeWithInvalid3() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<ArrayList<Integer>> dataSet = 
env.fromElements(new ArrayList<Integer>()); + + dataSet.groupBy("*"); + } + public static class CustomType implements Serializable { http://git-wip-us.apache.org/repos/asf/flink/blob/30a74c76/flink-java/src/test/java/org/apache/flink/api/java/operator/JoinOperatorTest.java ---------------------------------------------------------------------- diff --git a/flink-java/src/test/java/org/apache/flink/api/java/operator/JoinOperatorTest.java b/flink-java/src/test/java/org/apache/flink/api/java/operator/JoinOperatorTest.java index f1aadca..be964cc 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/operator/JoinOperatorTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/operator/JoinOperatorTest.java @@ -585,6 +585,78 @@ public class JoinOperatorTest { } ); } + + @Test + public void testJoinKeyAtomic1() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Integer> ds1 = env.fromElements(0, 0, 0); + DataSet<Tuple5<Integer, Long, String, Long, Integer>> ds2 = env.fromCollection(emptyTupleData, tupleTypeInfo); + + ds1.join(ds2).where("*").equalTo(0); + } + + @Test + public void testJoinKeyAtomic2() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Tuple5<Integer, Long, String, Long, Integer>> ds1 = env.fromCollection(emptyTupleData, tupleTypeInfo); + DataSet<Integer> ds2 = env.fromElements(0, 0, 0); + + ds1.join(ds2).where(0).equalTo("*"); + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyInvalidAtomic1() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Integer> ds1 = env.fromElements(0, 0, 0); + DataSet<Tuple5<Integer, Long, String, Long, Integer>> ds2 = env.fromCollection(emptyTupleData, tupleTypeInfo); + + ds1.join(ds2).where("*", "invalidKey"); + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyInvalidAtomic2() { + final ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Tuple5<Integer, Long, String, Long, Integer>> ds1 = env.fromCollection(emptyTupleData, tupleTypeInfo); + DataSet<Integer> ds2 = env.fromElements(0, 0, 0); + + ds1.join(ds2).where(0).equalTo("*", "invalidKey"); + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyInvalidAtomic3() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Integer> ds1 = env.fromElements(0, 0, 0); + DataSet<Tuple5<Integer, Long, String, Long, Integer>> ds2 = env.fromCollection(emptyTupleData, tupleTypeInfo); + + ds1.join(ds2).where("invalidKey"); + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyInvalidAtomic4() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Tuple5<Integer, Long, String, Long, Integer>> ds1 = env.fromCollection(emptyTupleData, tupleTypeInfo); + DataSet<Integer> ds2 = env.fromElements(0, 0, 0); + + ds1.join(ds2).where(0).equalTo("invalidKey"); + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyInvalidAtomic5() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<ArrayList<Integer>> ds1 = env.fromElements(new ArrayList<Integer>()); + DataSet<Integer> ds2 = env.fromElements(0, 0, 0); + + ds1.join(ds2).where("*").equalTo("*"); + } + + @Test(expected = InvalidProgramException.class) + public void testJoinKeyInvalidAtomic6() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Integer> ds1 = env.fromElements(0, 0, 0); + DataSet<ArrayList<Integer>> ds2 = env.fromElements(new ArrayList<Integer>()); + + ds1.join(ds2).where("*").equalTo("*"); + } @Test public void testJoinProjection1() { http://git-wip-us.apache.org/repos/asf/flink/blob/30a74c76/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/JavaApiPostPass.java 
---------------------------------------------------------------------- diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/JavaApiPostPass.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/JavaApiPostPass.java index 5fdf3dd..a685ff4 100644 --- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/JavaApiPostPass.java +++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/JavaApiPostPass.java @@ -41,7 +41,6 @@ import org.apache.flink.api.common.typeutils.TypeSerializer; import org.apache.flink.api.common.typeutils.TypeSerializerFactory; import org.apache.flink.api.java.operators.translation.PlanUnwrappingReduceGroupOperator; import org.apache.flink.api.java.tuple.Tuple; -import org.apache.flink.api.java.typeutils.PojoTypeInfo; import org.apache.flink.api.java.typeutils.runtime.RuntimeComparatorFactory; import org.apache.flink.api.java.typeutils.runtime.RuntimePairComparatorFactory; import org.apache.flink.api.java.typeutils.runtime.RuntimeSerializerFactory; @@ -305,10 +304,6 @@ public class JavaApiPostPass implements OptimizerPostPass { } private static <T1 extends Tuple, T2 extends Tuple> TypePairComparatorFactory<T1,T2> createPairComparator(TypeInformation<?> typeInfo1, TypeInformation<?> typeInfo2) { - if (!(typeInfo1.isTupleType() || typeInfo1 instanceof PojoTypeInfo) && (typeInfo2.isTupleType() || typeInfo2 instanceof PojoTypeInfo)) { - throw new RuntimeException("The runtime currently supports only keyed binary operations (such as joins) on tuples and POJO types."); - } - // @SuppressWarnings("unchecked") // TupleTypeInfo<T1> info1 = (TupleTypeInfo<T1>) typeInfo1; // @SuppressWarnings("unchecked") http://git-wip-us.apache.org/repos/asf/flink/blob/30a74c76/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/CoGroupITCase.java ---------------------------------------------------------------------- diff --git 
a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/CoGroupITCase.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/CoGroupITCase.java index 99f568e..84c05d6 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/CoGroupITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/CoGroupITCase.java @@ -18,10 +18,6 @@ package org.apache.flink.test.javaApiOperators; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; - import org.apache.flink.api.common.functions.CoGroupFunction; import org.apache.flink.api.common.functions.RichCoGroupFunction; import org.apache.flink.api.java.DataSet; @@ -47,6 +43,10 @@ import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + @RunWith(Parameterized.class) public class CoGroupITCase extends MultipleProgramsTestBase { @@ -488,6 +488,36 @@ public class CoGroupITCase extends MultipleProgramsTestBase { "-1,30000,Flink\n"; } + @Test + public void testCoGroupWithAtomicType1() throws Exception { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env); + DataSet<Integer> ds2 = env.fromElements(0, 1, 2); + + DataSet<Tuple3<Integer, Long, String>> coGroupDs = ds1.coGroup(ds2).where(0).equalTo("*").with(new CoGroupAtomic1()); + + coGroupDs.writeAsText(resultPath); + env.execute(); + + expected = "(1,1,Hi)\n" + + "(2,2,Hello)"; + } + + @Test + public void testCoGroupWithAtomicType2() throws Exception { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Integer> ds1 = env.fromElements(0, 1, 2); + DataSet<Tuple3<Integer, Long, String>> ds2 = CollectionDataSets.getSmall3TupleDataSet(env); + + DataSet<Tuple3<Integer, Long, String>> coGroupDs 
= ds1.coGroup(ds2).where("*").equalTo(0).with(new CoGroupAtomic2()); + + coGroupDs.writeAsText(resultPath); + env.execute(); + + expected = "(1,1,Hi)\n" + + "(2,2,Hello)"; + } + public static class KeySelector1 implements KeySelector<POJO, Long> { private static final long serialVersionUID = 1L; @@ -719,4 +749,48 @@ public class CoGroupITCase extends MultipleProgramsTestBase { } } } + + public static class CoGroupAtomic1 implements CoGroupFunction<Tuple3<Integer, Long, String>, Integer, Tuple3<Integer, Long, String>> { + + private static final long serialVersionUID = 1L; + + @Override + public void coGroup(Iterable<Tuple3<Integer, Long, String>> first, Iterable<Integer> second, Collector<Tuple3<Integer, Long, String>> out) throws Exception { + List<Integer> ints = new ArrayList<Integer>(); + + for (Integer i : second) { + ints.add(i); + } + + for (Tuple3<Integer, Long, String> t : first) { + for (Integer i : ints) { + if (t.f0.equals(i)) { + out.collect(t); + } + } + } + } + } + + public static class CoGroupAtomic2 implements CoGroupFunction<Integer, Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> { + + private static final long serialVersionUID = 1L; + + @Override + public void coGroup(Iterable<Integer> first, Iterable<Tuple3<Integer, Long, String>> second, Collector<Tuple3<Integer, Long, String>> out) throws Exception { + List<Integer> ints = new ArrayList<Integer>(); + + for (Integer i : first) { + ints.add(i); + } + + for (Tuple3<Integer, Long, String> t : second) { + for (Integer i : ints) { + if (t.f0.equals(i)) { + out.collect(t); + } + } + } + } + } } http://git-wip-us.apache.org/repos/asf/flink/blob/30a74c76/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupReduceITCase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupReduceITCase.java 
b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupReduceITCase.java index 9eb9a37..cf6b529 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupReduceITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/GroupReduceITCase.java @@ -18,21 +18,20 @@ package org.apache.flink.test.javaApiOperators; -import java.util.Collection; -import java.util.Iterator; - import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.GroupReduceFunction; -import org.apache.flink.api.common.operators.Order; -import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.common.functions.RichGroupReduceFunction; import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.operators.Order; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.tuple.Tuple5; -import org.apache.flink.optimizer.Optimizer; import org.apache.flink.configuration.Configuration; +import org.apache.flink.optimizer.Optimizer; import org.apache.flink.test.javaApiOperators.util.CollectionDataSets; import org.apache.flink.test.javaApiOperators.util.CollectionDataSets.CrazyNested; import org.apache.flink.test.javaApiOperators.util.CollectionDataSets.CustomType; @@ -49,10 +48,13 @@ import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.ExecutionEnvironment; import scala.math.BigInt; +import java.util.Collection; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Set; + @SuppressWarnings("serial") 
@RunWith(Parameterized.class)
 public class GroupReduceITCase extends MultipleProgramsTestBase {
@@ -1063,6 +1065,30 @@ public class GroupReduceITCase extends MultipleProgramsTestBase {
 
 	}
 
+	/**
+	 * Groups a data set of atomic {@code Integer} values on the whole element ("*")
+	 * and emits the first value of every group.
+	 */
+	@Test
+	public void testGroupReduceWithAtomicValue() throws Exception {
+		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		DataSet<Integer> ds = env.fromElements(1, 1, 2, 3, 4);
+		DataSet<Integer> reduceDs = ds.groupBy("*").reduceGroup(new GroupReduceFunction<Integer, Integer>() {
+			@Override
+			public void reduce(Iterable<Integer> values, Collector<Integer> out) throws Exception {
+				out.collect(values.iterator().next());
+			}
+		});
+
+		reduceDs.writeAsText(resultPath);
+		env.execute();
+
+		expected = "1\n" +
+			"2\n" +
+			"3\n" +
+			"4";
+	}
+
 	public static class GroupReducer8 implements GroupReduceFunction<CollectionDataSets.PojoWithCollection, String> {
 		@Override
 		public void reduce(
http://git-wip-us.apache.org/repos/asf/flink/blob/30a74c76/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/JoinITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/JoinITCase.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/JoinITCase.java
index 0080fb1..fe436a3 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/JoinITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/JoinITCase.java
@@ -663,6 +663,46 @@ public class JoinITCase extends MultipleProgramsTestBase {
 				"((3,2,Hello world),(3,2,Hello world)),((3,2,Hello world),(3,2,Hello world))\n";
 	}
 
+	/**
+	 * Joins a 3-tuple data set keyed on field 0 with an atomic {@code Integer} data set
+	 * keyed on the whole element ("*").
+	 */
+	@Test
+	public void testJoinWithAtomicType1() throws Exception {
+		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+		DataSet<Tuple3<Integer, Long, String>> ds1 = CollectionDataSets.getSmall3TupleDataSet(env);
+		DataSet<Integer> ds2 = env.fromElements(1, 2);
+
+		DataSet<Tuple2<Tuple3<Integer, Long, String>, Integer>> joinDs = ds1.join(ds2).where(0).equalTo("*");
+
+		joinDs.writeAsCsv(resultPath);
+		env.execute();
+
+		expected = "(1,1,Hi),1\n" +
+			"(2,2,Hello),2";
+	}
+
+	/**
+	 * Joins an atomic {@code Integer} data set keyed on the whole element ("*") with a
+	 * 3-tuple data set keyed on field 0 (mirror image of the atomic-on-the-right case).
+	 */
+	@Test
+	public void testJoinWithAtomicType2() throws Exception {
+		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+		DataSet<Integer> ds1 = env.fromElements(1, 2);
+		DataSet<Tuple3<Integer, Long, String>> ds2 = CollectionDataSets.getSmall3TupleDataSet(env);
+
+		DataSet<Tuple2<Integer, Tuple3<Integer, Long, String>>> joinDs = ds1.join(ds2).where("*").equalTo(0);
+
+		joinDs.writeAsCsv(resultPath);
+		env.execute();
+
+		expected = "1,(1,1,Hi)\n" +
+			"2,(2,2,Hello)";
+	}
+
 	public static class T3T5FlatJoin implements FlatJoinFunction<Tuple3<Integer, Long, String>, Tuple5<Integer, Long, Integer, String, Long>, Tuple2<String, String>> {
 
 		@Override