Repository: flink Updated Branches: refs/heads/master 59b237b5d -> 0a63797a6
[FLINK-3234] [dataSet] Add KeySelector support to sortPartition operation. This closes #1585 Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/0a63797a Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/0a63797a Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/0a63797a Branch: refs/heads/master Commit: 0a63797a6a5418b2363bca25bd77c33c217ff257 Parents: 572855d Author: Chiwan Park <chiwanp...@apache.org> Authored: Thu Feb 4 20:46:10 2016 +0900 Committer: Fabian Hueske <fhue...@apache.org> Committed: Wed Feb 10 11:51:26 2016 +0100 ---------------------------------------------------------------------- .../java/org/apache/flink/api/java/DataSet.java | 18 ++ .../java/operators/SortPartitionOperator.java | 174 +++++++++++++------ .../api/java/operator/SortPartitionTest.java | 82 +++++++++ .../org/apache/flink/api/scala/DataSet.scala | 25 +++ .../api/scala/PartitionSortedDataSet.scala | 22 ++- .../javaApiOperators/SortPartitionITCase.java | 61 +++++++ .../scala/operators/SortPartitionITCase.scala | 59 +++++++ 7 files changed, 385 insertions(+), 56 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java index bfb97f4..c315920 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java @@ -1381,6 +1381,24 @@ public abstract class DataSet<T> { return new SortPartitionOperator<>(this, field, order, Utils.getCallLocationName()); } + /** + * Locally sorts the partitions of the DataSet on the extracted key in the specified order. 
+ * The DataSet can be sorted on multiple values by returning a tuple from the KeySelector. + * + * Note that no additional sort keys can be appended to a KeySelector sort key. To sort + * the partitions by multiple values using KeySelector, the KeySelector must return a tuple + * consisting of the values. + * + * @param keyExtractor The KeySelector function which extracts the key values from the DataSet + * on which the DataSet is sorted. + * @param order The order in which the DataSet is sorted. + * @return The DataSet with sorted local partitions. + */ + public <K> SortPartitionOperator<T> sortPartition(KeySelector<T, K> keyExtractor, Order order) { + final TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keyExtractor, getType()); + return new SortPartitionOperator<>(this, new Keys.SelectorFunctionKeys<>(clean(keyExtractor), getType(), keyType), order, Utils.getCallLocationName()); + } + // -------------------------------------------------------------------------------------------- // Top-K // -------------------------------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java ---------------------------------------------------------------------- diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java index 354a0cd..7f30a30 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java @@ -26,9 +26,13 @@ import org.apache.flink.api.common.operators.Order; import org.apache.flink.api.common.operators.Ordering; import org.apache.flink.api.common.operators.UnaryOperatorInformation; import 
org.apache.flink.api.common.operators.base.SortPartitionOperatorBase; +import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; -import java.util.Arrays; +import java.util.ArrayList; +import java.util.List; /** * This operator represents a DataSet with locally sorted partitions. @@ -38,27 +42,58 @@ import java.util.Arrays; @Public public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPartitionOperator<T>> { - private int[] sortKeyPositions; + private List<Keys<T>> keys; - private Order[] sortOrders; + private List<Order> orders; private final String sortLocationName; + private boolean useKeySelector; - public SortPartitionOperator(DataSet<T> dataSet, int sortField, Order sortOrder, String sortLocationName) { + private SortPartitionOperator(DataSet<T> dataSet, String sortLocationName) { super(dataSet, dataSet.getType()); + + keys = new ArrayList<>(); + orders = new ArrayList<>(); this.sortLocationName = sortLocationName; + } + + + public SortPartitionOperator(DataSet<T> dataSet, int sortField, Order sortOrder, String sortLocationName) { + this(dataSet, sortLocationName); + this.useKeySelector = false; + + ensureSortableKey(sortField); - int[] flatOrderKeys = getFlatFields(sortField); - this.appendSorting(flatOrderKeys, sortOrder); + keys.add(new Keys.ExpressionKeys<>(sortField, getType())); + orders.add(sortOrder); } public SortPartitionOperator(DataSet<T> dataSet, String sortField, Order sortOrder, String sortLocationName) { - super(dataSet, dataSet.getType()); - this.sortLocationName = sortLocationName; + this(dataSet, sortLocationName); + this.useKeySelector = false; + + ensureSortableKey(sortField); + + keys.add(new Keys.ExpressionKeys<>(sortField, getType())); + orders.add(sortOrder); + } + + public <K> SortPartitionOperator(DataSet<T> dataSet, Keys.SelectorFunctionKeys<T, K> sortKey, Order 
sortOrder, String sortLocationName) { + this(dataSet, sortLocationName); + this.useKeySelector = true; + + ensureSortableKey(sortKey); - int[] flatOrderKeys = getFlatFields(sortField); - this.appendSorting(flatOrderKeys, sortOrder); + keys.add(sortKey); + orders.add(sortOrder); + } + + /** + * Returns whether using key selector or not. + */ + public boolean useKeySelector() { + return useKeySelector; } /** @@ -70,9 +105,14 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPart * @return The DataSet with sorted local partitions. */ public SortPartitionOperator<T> sortPartition(int field, Order order) { + if (useKeySelector) { + throw new InvalidProgramException("Expression keys cannot be appended after a KeySelector"); + } + + ensureSortableKey(field); + keys.add(new Keys.ExpressionKeys<>(field, getType())); + orders.add(order); - int[] flatOrderKeys = getFlatFields(field); - this.appendSorting(flatOrderKeys, order); return this; } @@ -81,58 +121,41 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPart * local partition sorting of the DataSet. * * @param field The field expression referring to the field of the additional sort order of - * the local partition sorting. - * @param order The order of the additional sort order of the local partition sorting. + * the local partition sorting. + * @param order The order of the additional sort order of the local partition sorting. * @return The DataSet with sorted local partitions. 
*/ public SortPartitionOperator<T> sortPartition(String field, Order order) { - int[] flatOrderKeys = getFlatFields(field); - this.appendSorting(flatOrderKeys, order); + if (useKeySelector) { + throw new InvalidProgramException("Expression keys cannot be appended after a KeySelector"); + } + + ensureSortableKey(field); + keys.add(new Keys.ExpressionKeys<>(field, getType())); + orders.add(order); + return this; } - // -------------------------------------------------------------------------------------------- - // Key Extraction - // -------------------------------------------------------------------------------------------- - - private int[] getFlatFields(int field) { + public <K> SortPartitionOperator<T> sortPartition(KeySelector<T, K> keyExtractor, Order order) { + throw new InvalidProgramException("KeySelector cannot be chained."); + } - if (!Keys.ExpressionKeys.isSortKey(field, super.getType())) { + private void ensureSortableKey(int field) throws InvalidProgramException { + if (!Keys.ExpressionKeys.isSortKey(field, getType())) { throw new InvalidProgramException("Selected sort key is not a sortable type"); } - - Keys.ExpressionKeys<T> ek = new Keys.ExpressionKeys<>(field, super.getType()); - return ek.computeLogicalKeyPositions(); } - private int[] getFlatFields(String fields) { - - if (!Keys.ExpressionKeys.isSortKey(fields, super.getType())) { + private void ensureSortableKey(String field) throws InvalidProgramException { + if (!Keys.ExpressionKeys.isSortKey(field, getType())) { throw new InvalidProgramException("Selected sort key is not a sortable type"); } - - Keys.ExpressionKeys<T> ek = new Keys.ExpressionKeys<>(fields, super.getType()); - return ek.computeLogicalKeyPositions(); } - private void appendSorting(int[] flatOrderFields, Order order) { - - if(this.sortKeyPositions == null) { - // set sorting info - this.sortKeyPositions = flatOrderFields; - this.sortOrders = new Order[flatOrderFields.length]; - Arrays.fill(this.sortOrders, order); - } else { - 
// append sorting info to exising info - int oldLength = this.sortKeyPositions.length; - int newLength = oldLength + flatOrderFields.length; - this.sortKeyPositions = Arrays.copyOf(this.sortKeyPositions, newLength); - this.sortOrders = Arrays.copyOf(this.sortOrders, newLength); - - for(int i=0; i<flatOrderFields.length; i++) { - this.sortKeyPositions[oldLength+i] = flatOrderFields[i]; - this.sortOrders[oldLength+i] = order; - } + private <K> void ensureSortableKey(Keys.SelectorFunctionKeys<T, K> sortKey) { + if (!sortKey.getKeyType().isSortKeyType()) { + throw new InvalidProgramException("Selected sort key is not a sortable type"); } } @@ -144,16 +167,33 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPart String name = "Sort at " + sortLocationName; + if (useKeySelector) { + return translateToDataFlowWithKeyExtractor(input, (Keys.SelectorFunctionKeys<T, ?>) keys.get(0), orders.get(0), name); + } + + // flatten sort key positions + List<Integer> allKeyPositions = new ArrayList<>(); + List<Order> allOrders = new ArrayList<>(); + for (int i = 0, length = keys.size(); i < length; i++) { + int[] sortKeyPositions = keys.get(i).computeLogicalKeyPositions(); + Order order = orders.get(i); + + for (int sortKeyPosition : sortKeyPositions) { + allKeyPositions.add(sortKeyPosition); + allOrders.add(order); + } + } + Ordering partitionOrdering = new Ordering(); - for (int i = 0; i < this.sortKeyPositions.length; i++) { - partitionOrdering.appendOrdering(this.sortKeyPositions[i], null, this.sortOrders[i]); + for (int i = 0, length = allKeyPositions.size(); i < length; i++) { + partitionOrdering.appendOrdering(allKeyPositions.get(i), null, allOrders.get(i)); } // distinguish between partition types UnaryOperatorInformation<T, T> operatorInfo = new UnaryOperatorInformation<>(getType(), getType()); - SortPartitionOperatorBase<T> noop = new SortPartitionOperatorBase<>(operatorInfo, partitionOrdering, name); + SortPartitionOperatorBase<T> noop = new 
SortPartitionOperatorBase<>(operatorInfo, partitionOrdering, name); noop.setInput(input); - if(this.getParallelism() < 0) { + if (this.getParallelism() < 0) { // use parallelism of input if not explicitly specified noop.setParallelism(input.getParallelism()); } else { @@ -165,4 +205,32 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPart } + private <K> org.apache.flink.api.common.operators.SingleInputOperator<?, T, ?> translateToDataFlowWithKeyExtractor( + Operator<T> input, Keys.SelectorFunctionKeys<T, K> keys, Order order, String name) { + TypeInformation<Tuple2<K, T>> typeInfoWithKey = KeyFunctions.createTypeWithKey(keys); + Keys.ExpressionKeys<Tuple2<K, T>> newKey = new Keys.ExpressionKeys<>(0, typeInfoWithKey); + + Operator<Tuple2<K, T>> keyedInput = KeyFunctions.appendKeyExtractor(input, keys); + + int[] sortKeyPositions = newKey.computeLogicalKeyPositions(); + Ordering partitionOrdering = new Ordering(); + for (int keyPosition : sortKeyPositions) { + partitionOrdering.appendOrdering(keyPosition, null, order); + } + + // distinguish between partition types + UnaryOperatorInformation<Tuple2<K, T>, Tuple2<K, T>> operatorInfo = new UnaryOperatorInformation<>(typeInfoWithKey, typeInfoWithKey); + SortPartitionOperatorBase<Tuple2<K, T>> noop = new SortPartitionOperatorBase<>(operatorInfo, partitionOrdering, name); + noop.setInput(keyedInput); + if (this.getParallelism() < 0) { + // use parallelism of input if not explicitly specified + noop.setParallelism(input.getParallelism()); + } else { + // use explicitly specified parallelism + noop.setParallelism(this.getParallelism()); + } + + return KeyFunctions.appendKeyRemover(noop, keys); + } + } http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java ---------------------------------------------------------------------- diff --git 
a/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java b/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java index a4e2bbc..3540e6a 100644 --- a/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java +++ b/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java @@ -169,6 +169,88 @@ public class SortPartitionTest { tupleDs.sortPartition("f3", Order.ASCENDING); } + @Test + public void testSortPartitionWithKeySelector1() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo); + + // should work + try { + tupleDs.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, Integer>() { + @Override + public Integer getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception { + return value.f0; + } + }, Order.ASCENDING); + } catch (Exception e) { + Assert.fail(); + } + } + + @Test(expected = InvalidProgramException.class) + public void testSortPartitionWithKeySelector2() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo); + + // must not work + tupleDs.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, Long[]>() { + @Override + public Long[] getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception { + return value.f3; + } + }, Order.ASCENDING); + } + + @Test(expected = InvalidProgramException.class) + public void testSortPartitionWithKeySelector3() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo); + + // must not work + tupleDs + 
.sortPartition("f1", Order.ASCENDING) + .sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, CustomType>() { + @Override + public CustomType getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception { + return value.f2; + } + }, Order.ASCENDING); + } + + @Test + public void testSortPartitionWithKeySelector4() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo); + + // should work + try { + tupleDs.sortPartition(new KeySelector<Tuple4<Integer,Long,CustomType,Long[]>, Tuple2<Integer, Long>>() { + @Override + public Tuple2<Integer, Long> getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception { + return new Tuple2<>(value.f0, value.f1); + } + }, Order.ASCENDING); + } catch (Exception e) { + Assert.fail(); + } + } + + @Test(expected = InvalidProgramException.class) + public void testSortPartitionWithKeySelector5() { + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo); + + // must not work + tupleDs + .sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, CustomType>() { + @Override + public CustomType getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception { + return value.f2; + } + }, Order.ASCENDING) + .sortPartition("f1", Order.ASCENDING); + } + public static class CustomType implements Serializable { public static class Nest { http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala 
index e47bc42..5735b32 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala @@ -1511,6 +1511,31 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) { new SortPartitionOperator[T](javaSet, field, order, getCallLocationName())) } + /** + * Locally sorts the partitions of the DataSet on the extracted key in the specified order. + * The DataSet can be sorted on multiple values by returning a tuple from the KeySelector. + * + * Note that no additional sort keys can be appended to a KeySelector sort key. To sort + * the partitions by multiple values using KeySelector, the KeySelector must return a tuple + * consisting of the values. + */ + def sortPartition[K: TypeInformation](fun: T => K, order: Order): DataSet[T] ={ + val keyExtractor = new KeySelector[T, K] { + val cleanFun = clean(fun) + def getKey(in: T) = cleanFun(in) + } + + val keyType = implicitly[TypeInformation[K]] + new PartitionSortedDataSet[T]( + new SortPartitionOperator[T](javaSet, + new Keys.SelectorFunctionKeys[T, K]( + keyExtractor, + javaSet.getType, + keyType), + order, + getCallLocationName())) + } + // -------------------------------------------------------------------------------------------- // Result writing // -------------------------------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala ---------------------------------------------------------------------- diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala index c924a76..a402dd9 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala @@ 
-18,7 +18,9 @@ package org.apache.flink.api.scala import org.apache.flink.annotation.Public +import org.apache.flink.api.common.InvalidProgramException import org.apache.flink.api.common.operators.Order +import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.operators.SortPartitionOperator import scala.reflect.ClassTag @@ -37,16 +39,30 @@ class PartitionSortedDataSet[T: ClassTag](set: SortPartitionOperator[T]) * Appends the given field and order to the sort-partition operator. */ override def sortPartition(field: Int, order: Order): DataSet[T] = { + if (set.useKeySelector()) { + throw new InvalidProgramException("Expression keys cannot be appended after selector " + + "function keys") + } + this.set.sortPartition(field, order) this } -/** - * Appends the given field and order to the sort-partition operator. - */ + /** + * Appends the given field and order to the sort-partition operator. + */ override def sortPartition(field: String, order: Order): DataSet[T] = { + if (set.useKeySelector()) { + throw new InvalidProgramException("Expression keys cannot be appended after selector " + + "function keys") + } + this.set.sortPartition(field, order) this } + override def sortPartition[K: TypeInformation](fun: T => K, order: Order): DataSet[T] = { + throw new InvalidProgramException("KeySelector cannot be chained.") + } + } http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java index 2423420..c7f07f6 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java +++ 
b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java @@ -23,6 +23,7 @@ import org.apache.flink.api.common.functions.MapPartitionFunction; import org.apache.flink.api.common.operators.Order; import org.apache.flink.api.java.DataSet; import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.tuple.Tuple1; import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple3; @@ -197,6 +198,58 @@ public class SortPartitionITCase extends MultipleProgramsTestBase { compareResultAsText(result, expected); } + @Test + public void testSortPartitionWithKeySelector1() throws Exception { + /* + * Test sort partition on an extracted key + */ + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(4); + + DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env); + List<Tuple1<Boolean>> result = ds + .map(new IdMapper<Tuple3<Integer, Long, String>>()).setParallelism(4) // parallelize input + .sortPartition(new KeySelector<Tuple3<Integer, Long, String>, Long>() { + @Override + public Long getKey(Tuple3<Integer, Long, String> value) throws Exception { + return value.f1; + } + }, Order.ASCENDING) + .mapPartition(new OrderCheckMapper<>(new Tuple3AscendingChecker())) + .distinct().collect(); + + String expected = "(true)\n"; + + compareResultAsText(result, expected); + } + + @Test + public void testSortPartitionWithKeySelector2() throws Exception { + /* + * Test sort partition on an extracted key + */ + + final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setParallelism(4); + + DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env); + List<Tuple1<Boolean>> result = ds + .map(new IdMapper<Tuple3<Integer, Long, String>>()).setParallelism(4) // parallelize input + .sortPartition(new 
KeySelector<Tuple3<Integer, Long, String>, Tuple2<Integer, Long>>() { + @Override + public Tuple2<Integer, Long> getKey(Tuple3<Integer, Long, String> value) throws Exception { + return new Tuple2<>(value.f0, value.f1); + } + }, Order.DESCENDING) + .mapPartition(new OrderCheckMapper<>(new Tuple3Checker())) + .distinct().collect(); + + String expected = "(true)\n"; + + compareResultAsText(result, expected); + } + public interface OrderChecker<T> extends Serializable { boolean inOrder(T t1, T t2); } @@ -210,6 +263,14 @@ public class SortPartitionITCase extends MultipleProgramsTestBase { } @SuppressWarnings("serial") + public static class Tuple3AscendingChecker implements OrderChecker<Tuple3<Integer, Long, String>> { + @Override + public boolean inOrder(Tuple3<Integer, Long, String> t1, Tuple3<Integer, Long, String> t2) { + return t1.f1 <= t2.f1; + } + } + + @SuppressWarnings("serial") public static class Tuple5Checker implements OrderChecker<Tuple5<Integer, Long, Integer, String, Long>> { @Override public boolean inOrder(Tuple5<Integer, Long, Integer, String, Long> t1, http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala index 3f67063..cda8f4f 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.flink.api.common.functions.MapPartitionFunction import org.apache.flink.api.common.operators.Order +import org.apache.flink.api.common.InvalidProgramException import org.apache.flink.api.scala._ import 
org.apache.flink.api.scala.util.CollectionDataSets import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode @@ -166,6 +167,58 @@ class SortPartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestB TestBaseUtils.compareResultAsText(result.asJava, expected) } + @Test + def testSortPartitionWithKeySelector1(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + env.setParallelism(4) + val ds = CollectionDataSets.get3TupleDataSet(env) + + val result = ds + .map { x => x }.setParallelism(4) + .sortPartition(_._2, Order.ASCENDING) + .mapPartition(new OrderCheckMapper(new Tuple3AscendingChecker)) + .distinct() + .collect() + + val expected: String = "(true)\n" + TestBaseUtils.compareResultAsText(result.asJava, expected) + } + + @Test + def testSortPartitionWithKeySelector2(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + env.setParallelism(4) + val ds = CollectionDataSets.get3TupleDataSet(env) + + val result = ds + .map { x => x }.setParallelism(4) + .sortPartition(x => (x._2, x._1), Order.DESCENDING) + .mapPartition(new OrderCheckMapper(new Tuple3Checker)) + .distinct() + .collect() + + val expected: String = "(true)\n" + TestBaseUtils.compareResultAsText(result.asJava, expected) + } + + @Test(expected = classOf[InvalidProgramException]) + def testSortPartitionWithKeySelector3(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + env.setParallelism(4) + val ds = CollectionDataSets.get3TupleDataSet(env) + + val result = ds + .map { x => x }.setParallelism(4) + .sortPartition(x => (x._2, x._1), Order.DESCENDING) + .sortPartition(0, Order.DESCENDING) + .mapPartition(new OrderCheckMapper(new Tuple3Checker)) + .distinct() + .collect() + + val expected: String = "(true)\n" + TestBaseUtils.compareResultAsText(result.asJava, expected) + } + } trait OrderChecker[T] extends Serializable { @@ -178,6 +231,12 @@ class Tuple3Checker extends OrderChecker[(Int, Long, String)] { } } +class 
Tuple3AscendingChecker extends OrderChecker[(Int, Long, String)] { + def inOrder(t1: (Int, Long, String), t2: (Int, Long, String)): Boolean = { + t1._2 <= t2._2 + } +} + class Tuple5Checker extends OrderChecker[(Int, Long, Int, String, Long)] { def inOrder(t1: (Int, Long, Int, String, Long), t2: (Int, Long, Int, String, Long)): Boolean = { t1._5 < t2._5 || t1._5 == t2._5 && t1._3 >= t2._3