Repository: flink Updated Branches: refs/heads/master aec6ded5e -> c08bcf1e0
http://git-wip-us.apache.org/repos/asf/flink/blob/f5957ce3/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/PartitionITCase.scala ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/PartitionITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/PartitionITCase.scala index 7d26643..ca8bcd9 100644 --- a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/PartitionITCase.scala +++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/PartitionITCase.scala @@ -66,6 +66,22 @@ class PartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase( } @Test + def testRangePartitionByTupleField(): Unit = { + /* + * Test range partition by tuple field + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + + val unique = ds.partitionByRange(1).mapPartition( _.map(_._2).toSet ) + + unique.writeAsText(resultPath, WriteMode.OVERWRITE) + env.execute() + + expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n" + } + + @Test def testHashPartitionByKeySelector(): Unit = { /* * Test hash partition by key selector @@ -80,6 +96,20 @@ class PartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase( } @Test + def testRangePartitionByKeySelector(): Unit = { + /* + * Test range partition by key selector + */ + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + val unique = ds.partitionByRange( _._2 ).mapPartition( _.map(_._2).toSet ) + + unique.writeAsText(resultPath, WriteMode.OVERWRITE) + env.execute() + expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n" + } + + @Test def testForcedRebalancing(): Unit = { /* * Test forced rebalancing @@ -129,6 +159,24 @@ class PartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase( } @Test + def 
testMapPartitionAfterRepartitionHasCorrectParallelism2(): Unit = { + // Verify that mapPartition operation after repartition picks up correct + // parallelism + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + env.setParallelism(1) + + val unique = ds.partitionByRange(1) + .setParallelism(4) + .mapPartition( _.map(_._2).toSet ) + + unique.writeAsText(resultPath, WriteMode.OVERWRITE) + env.execute() + + expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n" + } + + @Test def testMapAfterRepartitionHasCorrectParallelism(): Unit = { // Verify that map operation after repartition picks up correct // parallelism @@ -157,6 +205,35 @@ class PartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase( } @Test + def testMapAfterRepartitionHasCorrectParallelism2(): Unit = { + // Verify that map operation after repartition picks up correct + // parallelism + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + env.setParallelism(1) + + val count = ds.partitionByRange(0).setParallelism(4).map( + new RichMapFunction[(Int, Long, String), Tuple1[Int]] { + var first = true + override def map(in: (Int, Long, String)): Tuple1[Int] = { + // only output one value with count 1 + if (first) { + first = false + Tuple1(1) + } else { + Tuple1(0) + } + } + }).sum(0) + + count.writeAsText(resultPath, WriteMode.OVERWRITE) + env.execute() + + expected = if (mode == TestExecutionMode.COLLECTION) "(1)\n" else "(4)\n" + } + + + @Test def testFilterAfterRepartitionHasCorrectParallelism(): Unit = { // Verify that filter operation after repartition picks up correct // parallelism @@ -186,7 +263,36 @@ class PartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase( } @Test - def testPartitionNestedPojo(): Unit = { + def testFilterAfterRepartitionHasCorrectParallelism2(): Unit = { + // Verify that filter operation after repartition picks up correct + // 
parallelism + val env = ExecutionEnvironment.getExecutionEnvironment + val ds = CollectionDataSets.get3TupleDataSet(env) + env.setParallelism(1) + + val count = ds.partitionByRange(0).setParallelism(4).filter( + new RichFilterFunction[(Int, Long, String)] { + var first = true + override def filter(in: (Int, Long, String)): Boolean = { + // only output one value with count 1 + if (first) { + first = false + true + } else { + false + } + } + }) + .map( _ => Tuple1(1)).sum(0) + + count.writeAsText(resultPath, WriteMode.OVERWRITE) + env.execute() + + expected = if (mode == TestExecutionMode.COLLECTION) "(1)\n" else "(4)\n" + } + + @Test + def testHashPartitionNestedPojo(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment env.setParallelism(3) val ds = CollectionDataSets.getDuplicatePojoDataSet(env) @@ -199,4 +305,19 @@ class PartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase( env.execute() expected = "10000\n" + "20000\n" + "30000\n" } + + @Test + def testRangePartitionNestedPojo(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + env.setParallelism(3) + val ds = CollectionDataSets.getDuplicatePojoDataSet(env) + val uniqLongs = ds + .partitionByRange("nestedPojo.longNumber") + .setParallelism(4) + .mapPartition( _.map(_.nestedPojo.longNumber).toSet ) + + uniqLongs.writeAsText(resultPath, WriteMode.OVERWRITE) + env.execute() + expected = "10000\n" + "20000\n" + "30000\n" + } }