http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/BinaryCustomPartitioningCompatibilityTest.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/BinaryCustomPartitioningCompatibilityTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/BinaryCustomPartitioningCompatibilityTest.java new file mode 100644 index 0000000..0273659 --- /dev/null +++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/BinaryCustomPartitioningCompatibilityTest.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.optimizer.custompartition; + +import static org.junit.Assert.*; + +import org.apache.flink.api.common.Plan; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.optimizer.CompilerTestBase; +import org.apache.flink.optimizer.plan.DualInputPlanNode; +import org.apache.flink.optimizer.plan.OptimizedPlan; +import org.apache.flink.optimizer.plan.SingleInputPlanNode; +import org.apache.flink.optimizer.plan.SinkPlanNode; +import org.apache.flink.optimizer.plantranslate.JobGraphGenerator; +import org.apache.flink.optimizer.testfunctions.DummyCoGroupFunction; +import org.apache.flink.runtime.operators.shipping.ShipStrategyType; +import org.junit.Test; + +@SuppressWarnings({"serial","unchecked"}) +public class BinaryCustomPartitioningCompatibilityTest extends CompilerTestBase { + + @Test + public void testCompatiblePartitioningJoin() { + try { + final Partitioner<Long> partitioner = new Partitioner<Long>() { + @Override + public int partition(Long key, int numPartitions) { + return 0; + } + }; + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L)); + DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L)); + + input1.partitionCustom(partitioner, 1) + .join(input2.partitionCustom(partitioner, 0)) + .where(1).equalTo(0) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + DualInputPlanNode join = (DualInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode partitioner1 = (SingleInputPlanNode) join.getInput1().getSource(); + SingleInputPlanNode partitioner2 = 
(SingleInputPlanNode) join.getInput2().getSource(); + + assertEquals(ShipStrategyType.FORWARD, join.getInput1().getShipStrategy()); + assertEquals(ShipStrategyType.FORWARD, join.getInput2().getShipStrategy()); + + assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner1.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner2.getInput().getShipStrategy()); + assertEquals(partitioner, partitioner1.getInput().getPartitioner()); + assertEquals(partitioner, partitioner2.getInput().getPartitioner()); + + new JobGraphGenerator().compileJobGraph(op); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testCompatiblePartitioningCoGroup() { + try { + final Partitioner<Long> partitioner = new Partitioner<Long>() { + @Override + public int partition(Long key, int numPartitions) { + return 0; + } + }; + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L)); + DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L)); + + input1.partitionCustom(partitioner, 1) + .coGroup(input2.partitionCustom(partitioner, 0)) + .where(1).equalTo(0) + .with(new DummyCoGroupFunction<Tuple2<Long, Long>, Tuple3<Long, Long, Long>>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + DualInputPlanNode coGroup = (DualInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode partitioner1 = (SingleInputPlanNode) coGroup.getInput1().getSource(); + SingleInputPlanNode partitioner2 = (SingleInputPlanNode) coGroup.getInput2().getSource(); + + assertEquals(ShipStrategyType.FORWARD, coGroup.getInput1().getShipStrategy()); + assertEquals(ShipStrategyType.FORWARD, coGroup.getInput2().getShipStrategy()); + + 
assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner1.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner2.getInput().getShipStrategy()); + assertEquals(partitioner, partitioner1.getInput().getPartitioner()); + assertEquals(partitioner, partitioner2.getInput().getPartitioner()); + + new JobGraphGenerator().compileJobGraph(op); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } +}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CoGroupCustomPartitioningTest.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CoGroupCustomPartitioningTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CoGroupCustomPartitioningTest.java new file mode 100644 index 0000000..08f7388 --- /dev/null +++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CoGroupCustomPartitioningTest.java @@ -0,0 +1,312 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.optimizer.custompartition; + +import static org.junit.Assert.*; + +import org.apache.flink.api.common.InvalidProgramException; +import org.apache.flink.api.common.Plan; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.common.operators.Order; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.optimizer.CompilerTestBase; +import org.apache.flink.optimizer.plan.DualInputPlanNode; +import org.apache.flink.optimizer.plan.OptimizedPlan; +import org.apache.flink.optimizer.plan.SinkPlanNode; +import org.apache.flink.optimizer.testfunctions.DummyCoGroupFunction; +import org.apache.flink.optimizer.testfunctions.IdentityGroupReducer; +import org.apache.flink.optimizer.testfunctions.IdentityMapper; +import org.apache.flink.runtime.operators.shipping.ShipStrategyType; +import org.junit.Test; + +@SuppressWarnings({"serial", "unchecked"}) +public class CoGroupCustomPartitioningTest extends CompilerTestBase { + + @Test + public void testCoGroupWithTuples() { + try { + final Partitioner<Long> partitioner = new TestPartitionerLong(); + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L)); + DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L)); + + input1 + .coGroup(input2) + .where(1).equalTo(0) + .withPartitioner(partitioner) + .with(new DummyCoGroupFunction<Tuple2<Long, Long>, Tuple3<Long, Long, Long>>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + DualInputPlanNode join = (DualInputPlanNode) sink.getInput().getSource(); 
+ + assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput1().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput2().getShipStrategy()); + assertEquals(partitioner, join.getInput1().getPartitioner()); + assertEquals(partitioner, join.getInput2().getPartitioner()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testCoGroupWithTuplesWrongType() { + try { + final Partitioner<Integer> partitioner = new TestPartitionerInt(); + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L)); + DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L)); + + try { + input1 + .coGroup(input2) + .where(1).equalTo(0) + .withPartitioner(partitioner); + fail("should throw an exception"); + } + catch (InvalidProgramException e) { + // expected + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testCoGroupWithPojos() { + try { + final Partitioner<Integer> partitioner = new TestPartitionerInt(); + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Pojo2> input1 = env.fromElements(new Pojo2()); + DataSet<Pojo3> input2 = env.fromElements(new Pojo3()); + + input1 + .coGroup(input2) + .where("b").equalTo("a") + .withPartitioner(partitioner) + .with(new DummyCoGroupFunction<Pojo2, Pojo3>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + DualInputPlanNode join = (DualInputPlanNode) sink.getInput().getSource(); + + assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput1().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput2().getShipStrategy()); + assertEquals(partitioner, join.getInput1().getPartitioner()); 
+ assertEquals(partitioner, join.getInput2().getPartitioner()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testCoGroupWithPojosWrongType() { + try { + final Partitioner<Long> partitioner = new TestPartitionerLong(); + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Pojo2> input1 = env.fromElements(new Pojo2()); + DataSet<Pojo3> input2 = env.fromElements(new Pojo3()); + + try { + input1 + .coGroup(input2) + .where("a").equalTo("b") + .withPartitioner(partitioner); + + fail("should throw an exception"); + } + catch (InvalidProgramException e) { + // expected + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testCoGroupWithKeySelectors() { + try { + final Partitioner<Integer> partitioner = new TestPartitionerInt(); + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Pojo2> input1 = env.fromElements(new Pojo2()); + DataSet<Pojo3> input2 = env.fromElements(new Pojo3()); + + input1 + .coGroup(input2) + .where(new Pojo2KeySelector()).equalTo(new Pojo3KeySelector()) + .withPartitioner(partitioner) + .with(new DummyCoGroupFunction<Pojo2, Pojo3>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + DualInputPlanNode join = (DualInputPlanNode) sink.getInput().getSource(); + + assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput1().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput2().getShipStrategy()); + assertEquals(partitioner, join.getInput1().getPartitioner()); + assertEquals(partitioner, join.getInput2().getPartitioner()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testCoGroupWithKeySelectorsWrongType() { + try { + final Partitioner<Long> partitioner = new 
TestPartitionerLong(); + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Pojo2> input1 = env.fromElements(new Pojo2()); + DataSet<Pojo3> input2 = env.fromElements(new Pojo3()); + + try { + input1 + .coGroup(input2) + .where(new Pojo2KeySelector()).equalTo(new Pojo3KeySelector()) + .withPartitioner(partitioner); + + fail("should throw an exception"); + } + catch (InvalidProgramException e) { + // expected + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testIncompatibleHashAndCustomPartitioning() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple3<Long, Long, Long>> input = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L)); + + DataSet<Tuple3<Long, Long, Long>> partitioned = input + .partitionCustom(new Partitioner<Long>() { + @Override + public int partition(Long key, int numPartitions) { return 0; } + }, 0) + .map(new IdentityMapper<Tuple3<Long,Long,Long>>()).withForwardedFields("0", "1", "2"); + + + DataSet<Tuple3<Long, Long, Long>> grouped = partitioned + .distinct(0, 1) + .groupBy(1) + .sortGroup(0, Order.ASCENDING) + .reduceGroup(new IdentityGroupReducer<Tuple3<Long,Long,Long>>()).withForwardedFields("0", "1"); + + grouped + .coGroup(partitioned).where(0).equalTo(0) + .with(new DummyCoGroupFunction<Tuple3<Long,Long,Long>, Tuple3<Long,Long,Long>>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + DualInputPlanNode coGroup = (DualInputPlanNode) sink.getInput().getSource(); + + assertEquals(ShipStrategyType.PARTITION_HASH, coGroup.getInput1().getShipStrategy()); + assertTrue(coGroup.getInput2().getShipStrategy() == ShipStrategyType.PARTITION_HASH || + coGroup.getInput2().getShipStrategy() == ShipStrategyType.FORWARD); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + 
} + + // -------------------------------------------------------------------------------------------- + + private static class TestPartitionerInt implements Partitioner<Integer> { + @Override + public int partition(Integer key, int numPartitions) { + return 0; + } + } + + private static class TestPartitionerLong implements Partitioner<Long> { + @Override + public int partition(Long key, int numPartitions) { + return 0; + } + } + + public static class Pojo2 { + public int a; + public int b; + } + + public static class Pojo3 { + public int a; + public int b; + public int c; + } + + private static class Pojo2KeySelector implements KeySelector<Pojo2, Integer> { + @Override + public Integer getKey(Pojo2 value) { + return value.a; + } + } + + private static class Pojo3KeySelector implements KeySelector<Pojo3, Integer> { + @Override + public Integer getKey(Pojo3 value) { + return value.b; + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningGlobalOptimizationTest.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningGlobalOptimizationTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningGlobalOptimizationTest.java new file mode 100644 index 0000000..9fd676f --- /dev/null +++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningGlobalOptimizationTest.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.optimizer.custompartition; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import org.apache.flink.api.common.Plan; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.optimizer.CompilerTestBase; +import org.apache.flink.optimizer.plan.DualInputPlanNode; +import org.apache.flink.optimizer.plan.OptimizedPlan; +import org.apache.flink.optimizer.plan.SingleInputPlanNode; +import org.apache.flink.optimizer.plan.SinkPlanNode; +import org.apache.flink.optimizer.testfunctions.IdentityGroupReducer; +import org.apache.flink.runtime.operators.shipping.ShipStrategyType; +import org.junit.Test; + +@SuppressWarnings({"serial", "unchecked"}) +public class CustomPartitioningGlobalOptimizationTest extends CompilerTestBase { + + @Test + public void testJoinReduceCombination() { + try { + final Partitioner<Long> partitioner = new TestPartitionerLong(); + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L)); + DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L)); + + DataSet<Tuple3<Long, Long, Long>> joined = input1.join(input2) + .where(1).equalTo(0) + .projectFirst(0, 1) 
+ .<Tuple3<Long, Long, Long>>projectSecond(2) + .withPartitioner(partitioner); + + joined.groupBy(1).withPartitioner(partitioner) + .reduceGroup(new IdentityGroupReducer<Tuple3<Long,Long,Long>>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource(); + + assertTrue("Reduce is not chained, property reuse does not happen", + reducer.getInput().getSource() instanceof DualInputPlanNode); + + DualInputPlanNode join = (DualInputPlanNode) reducer.getInput().getSource(); + + assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput1().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput2().getShipStrategy()); + assertEquals(partitioner, join.getInput1().getPartitioner()); + assertEquals(partitioner, join.getInput2().getPartitioner()); + + assertEquals(ShipStrategyType.FORWARD, reducer.getInput().getShipStrategy()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + // -------------------------------------------------------------------------------------------- + + private static class TestPartitionerLong implements Partitioner<Long> { + @Override + public int partition(Long key, int numPartitions) { + return 0; + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningTest.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningTest.java new file mode 100644 index 0000000..d397ea2 --- /dev/null +++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/CustomPartitioningTest.java @@ 
-0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.optimizer.custompartition; + +import static org.junit.Assert.*; + +import org.apache.flink.api.common.InvalidProgramException; +import org.apache.flink.api.common.Plan; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.optimizer.CompilerTestBase; +import org.apache.flink.optimizer.plan.OptimizedPlan; +import org.apache.flink.optimizer.plan.SingleInputPlanNode; +import org.apache.flink.optimizer.plan.SinkPlanNode; +import org.apache.flink.optimizer.testfunctions.IdentityPartitionerMapper; +import org.apache.flink.runtime.operators.shipping.ShipStrategyType; +import org.junit.Test; + +@SuppressWarnings({"serial", "unchecked"}) +public class CustomPartitioningTest extends CompilerTestBase { + + @Test + public void testPartitionTuples() { + try { + final Partitioner<Integer> part = new TestPartitionerInt(); + final int parallelism = 4; + + ExecutionEnvironment env = 
ExecutionEnvironment.getExecutionEnvironment(); + env.setDegreeOfParallelism(parallelism); + + DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer,Integer>(0, 0)) + .rebalance(); + + data + .partitionCustom(part, 0) + .mapPartition(new IdentityPartitionerMapper<Tuple2<Integer,Integer>>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode mapper = (SingleInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode partitioner = (SingleInputPlanNode) mapper.getInput().getSource(); + SingleInputPlanNode balancer = (SingleInputPlanNode) partitioner.getInput().getSource(); + + assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); + assertEquals(parallelism, sink.getParallelism()); + + assertEquals(ShipStrategyType.FORWARD, mapper.getInput().getShipStrategy()); + assertEquals(parallelism, mapper.getParallelism()); + + assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner.getInput().getShipStrategy()); + assertEquals(part, partitioner.getInput().getPartitioner()); + assertEquals(parallelism, partitioner.getParallelism()); + + assertEquals(ShipStrategyType.PARTITION_FORCED_REBALANCE, balancer.getInput().getShipStrategy()); + assertEquals(parallelism, balancer.getParallelism()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testPartitionTuplesInvalidType() { + try { + final int parallelism = 4; + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setDegreeOfParallelism(parallelism); + + DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer,Integer>(0, 0)) + .rebalance(); + + try { + data + .partitionCustom(new TestPartitionerLong(), 0); + fail("Should throw an exception"); + } + catch (InvalidProgramException e) { + // expected + } + } + catch (Exception e) { + e.printStackTrace(); + 
fail(e.getMessage()); + } + } + + @Test + public void testPartitionPojo() { + try { + final Partitioner<Integer> part = new TestPartitionerInt(); + final int parallelism = 4; + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setDegreeOfParallelism(parallelism); + + DataSet<Pojo> data = env.fromElements(new Pojo()) + .rebalance(); + + data + .partitionCustom(part, "a") + .mapPartition(new IdentityPartitionerMapper<Pojo>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode mapper = (SingleInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode partitioner = (SingleInputPlanNode) mapper.getInput().getSource(); + SingleInputPlanNode balancer = (SingleInputPlanNode) partitioner.getInput().getSource(); + + assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); + assertEquals(parallelism, sink.getParallelism()); + + assertEquals(ShipStrategyType.FORWARD, mapper.getInput().getShipStrategy()); + assertEquals(parallelism, mapper.getParallelism()); + + assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner.getInput().getShipStrategy()); + assertEquals(part, partitioner.getInput().getPartitioner()); + assertEquals(parallelism, partitioner.getParallelism()); + + assertEquals(ShipStrategyType.PARTITION_FORCED_REBALANCE, balancer.getInput().getShipStrategy()); + assertEquals(parallelism, balancer.getParallelism()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testPartitionPojoInvalidType() { + try { + final int parallelism = 4; + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setDegreeOfParallelism(parallelism); + + DataSet<Pojo> data = env.fromElements(new Pojo()) + .rebalance(); + + try { + data + .partitionCustom(new TestPartitionerLong(), "a"); + fail("Should throw an exception"); + } + 
catch (InvalidProgramException e) { + // expected + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testPartitionKeySelector() { + try { + final Partitioner<Integer> part = new TestPartitionerInt(); + final int parallelism = 4; + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setDegreeOfParallelism(parallelism); + + DataSet<Pojo> data = env.fromElements(new Pojo()) + .rebalance(); + + data + .partitionCustom(part, new TestKeySelectorInt<Pojo>()) + .mapPartition(new IdentityPartitionerMapper<Pojo>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode mapper = (SingleInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode keyRemover = (SingleInputPlanNode) mapper.getInput().getSource(); + SingleInputPlanNode partitioner = (SingleInputPlanNode) keyRemover.getInput().getSource(); + SingleInputPlanNode keyExtractor = (SingleInputPlanNode) partitioner.getInput().getSource(); + SingleInputPlanNode balancer = (SingleInputPlanNode) keyExtractor.getInput().getSource(); + + assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); + assertEquals(parallelism, sink.getParallelism()); + + assertEquals(ShipStrategyType.FORWARD, mapper.getInput().getShipStrategy()); + assertEquals(parallelism, mapper.getParallelism()); + + assertEquals(ShipStrategyType.FORWARD, keyRemover.getInput().getShipStrategy()); + assertEquals(parallelism, keyRemover.getParallelism()); + + assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitioner.getInput().getShipStrategy()); + assertEquals(part, partitioner.getInput().getPartitioner()); + assertEquals(parallelism, partitioner.getParallelism()); + + assertEquals(ShipStrategyType.FORWARD, keyExtractor.getInput().getShipStrategy()); + assertEquals(parallelism, keyExtractor.getParallelism()); + + 
assertEquals(ShipStrategyType.PARTITION_FORCED_REBALANCE, balancer.getInput().getShipStrategy()); + assertEquals(parallelism, balancer.getParallelism()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testPartitionKeySelectorInvalidType() { + try { + final Partitioner<Integer> part = (Partitioner<Integer>) (Partitioner<?>) new TestPartitionerLong(); + final int parallelism = 4; + + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + env.setDegreeOfParallelism(parallelism); + + DataSet<Pojo> data = env.fromElements(new Pojo()) + .rebalance(); + + try { + data + .partitionCustom(part, new TestKeySelectorInt<Pojo>()); + fail("Should throw an exception"); + } + catch (InvalidProgramException e) { + // expected + } + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + // -------------------------------------------------------------------------------------------- + + public static class Pojo { + public int a; + public int b; + } + + private static class TestPartitionerInt implements Partitioner<Integer> { + @Override + public int partition(Integer key, int numPartitions) { + return 0; + } + } + + private static class TestPartitionerLong implements Partitioner<Long> { + @Override + public int partition(Long key, int numPartitions) { + return 0; + } + } + + private static class TestKeySelectorInt<T> implements KeySelector<T, Integer> { + @Override + public Integer getKey(T value) { + return null; + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingKeySelectorTranslationTest.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingKeySelectorTranslationTest.java 
b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingKeySelectorTranslationTest.java new file mode 100644 index 0000000..360487b --- /dev/null +++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingKeySelectorTranslationTest.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.optimizer.custompartition; + +import static org.junit.Assert.*; + +import org.apache.flink.api.common.InvalidProgramException; +import org.apache.flink.api.common.Plan; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.common.operators.Order; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.optimizer.CompilerTestBase; +import org.apache.flink.optimizer.plan.OptimizedPlan; +import org.apache.flink.optimizer.plan.SingleInputPlanNode; +import org.apache.flink.optimizer.plan.SinkPlanNode; +import org.apache.flink.optimizer.testfunctions.DummyReducer; +import org.apache.flink.optimizer.testfunctions.IdentityGroupReducer; +import org.apache.flink.runtime.operators.shipping.ShipStrategyType; +import org.junit.Test; + /* Tests custom partitioners on groupings defined by a key selector: each test builds a small DataSet program and either checks the ship strategies in the compiled plan or expects the API call to throw InvalidProgramException. */ +@SuppressWarnings({"serial", "unchecked"}) +public class GroupingKeySelectorTranslationTest extends CompilerTestBase { + /* groupBy(key selector), withPartitioner, reduce: asserts FORWARD into the sink and into keyRemovingMapper, PARTITION_CUSTOM into reducer, FORWARD into combiner. */ + @Test + public void testCustomPartitioningKeySelectorReduce() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0)) + .rebalance().setParallelism(4); + + data.groupBy(new TestKeySelector<Tuple2<Integer,Integer>>()) + .withPartitioner(new TestPartitionerInt()) + .reduce(new DummyReducer<Tuple2<Integer,Integer>>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode keyRemovingMapper = (SingleInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode reducer = (SingleInputPlanNode) keyRemovingMapper.getInput().getSource(); + SingleInputPlanNode 
combiner = (SingleInputPlanNode) reducer.getInput().getSource(); + + assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.FORWARD, keyRemovingMapper.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* Same grouping with reduceGroup: asserts FORWARD into the sink, PARTITION_CUSTOM into reducer, FORWARD into combiner. */ + @Test + public void testCustomPartitioningKeySelectorGroupReduce() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0)) + .rebalance().setParallelism(4); + + data.groupBy(new TestKeySelector<Tuple2<Integer,Integer>>()) + .withPartitioner(new TestPartitionerInt()) + .reduceGroup(new IdentityGroupReducer<Tuple2<Integer,Integer>>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource(); + + assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* Adds a key-selector sortGroup before reduceGroup and asserts the same three ship strategies. */ + @Test + public void testCustomPartitioningKeySelectorGroupReduceSorted() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0)) + .rebalance().setParallelism(4); + + data.groupBy(new TestKeySelector<Tuple3<Integer,Integer,Integer>>()) + 
.withPartitioner(new TestPartitionerInt()) + .sortGroup(new TestKeySelector<Tuple3<Integer, Integer, Integer>>(), Order.ASCENDING) + .reduceGroup(new IdentityGroupReducer<Tuple3<Integer,Integer,Integer>>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource(); + + assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* A Partitioner<Long> applied to a grouping whose key selector yields Integer is expected to make withPartitioner throw InvalidProgramException. */ + @Test + public void testCustomPartitioningKeySelectorInvalidType() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0)) + .rebalance().setParallelism(4); + + try { + data + .groupBy(new TestKeySelector<Tuple2<Integer,Integer>>()) + .withPartitioner(new TestPartitionerLong()); + fail("Should throw an exception"); + } + catch (InvalidProgramException e) {/* expected */} + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* Same Long-vs-Integer mismatch after a sortGroup. NOTE(review): sortGroup(1, ...) uses a field position on a key-selector grouping; if that call itself throws InvalidProgramException, withPartitioner is never reached and the test passes for a different reason - confirm. */ + @Test + public void testCustomPartitioningKeySelectorInvalidTypeSorted() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0)) + .rebalance().setParallelism(4); + + try { + data + .groupBy(new TestKeySelector<Tuple3<Integer,Integer,Integer>>()) + .sortGroup(1, Order.ASCENDING) + .withPartitioner(new TestPartitionerLong()); + fail("Should throw an exception"); + } + catch (InvalidProgramException e) {/* expected */} + 
} + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* Groups by a key selector that returns a Tuple2 key and applies a Partitioner<Integer>; withPartitioner is expected to throw InvalidProgramException. */ + @Test + public void testCustomPartitioningTupleRejectCompositeKey() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0)) + .rebalance().setParallelism(4); + + try { + data + .groupBy(new TestBinaryKeySelector<Tuple3<Integer,Integer,Integer>>()) + .withPartitioner(new TestPartitionerInt()); + fail("Should throw an exception"); + } + catch (InvalidProgramException e) {/* expected */} + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + // -------------------------------------------------------------------------------------------- + /* Partitioner for Integer keys; always returns partition 0. */ + private static class TestPartitionerInt implements Partitioner<Integer> { + @Override + public int partition(Integer key, int numPartitions) { + return 0; + } + } + /* Partitioner for Long keys; always returns partition 0. */ + private static class TestPartitionerLong implements Partitioner<Long> { + @Override + public int partition(Long key, int numPartitions) { + return 0; + } + } + /* Key selector that returns tuple field 0 as the Integer key. */ + private static class TestKeySelector<T extends Tuple> implements KeySelector<T, Integer> { + @Override + public Integer getKey(T value) { + return value.getField(0); + } + } + /* Key selector that builds a Tuple2 key from tuple fields 0 and 1. */ + private static class TestBinaryKeySelector<T extends Tuple> implements KeySelector<T, Tuple2<Integer, Integer>> { + @Override + public Tuple2<Integer, Integer> getKey(T value) { + return new Tuple2<Integer, Integer>(value.<Integer>getField(0), value.<Integer>getField(1)); + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingPojoTranslationTest.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingPojoTranslationTest.java 
b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingPojoTranslationTest.java new file mode 100644 index 0000000..8cd4809 --- /dev/null +++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingPojoTranslationTest.java @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.optimizer.custompartition; + +import static org.junit.Assert.*; + +import org.apache.flink.api.common.InvalidProgramException; +import org.apache.flink.api.common.Plan; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.common.operators.Order; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.optimizer.CompilerTestBase; +import org.apache.flink.optimizer.plan.OptimizedPlan; +import org.apache.flink.optimizer.plan.SingleInputPlanNode; +import org.apache.flink.optimizer.plan.SinkPlanNode; +import org.apache.flink.optimizer.testfunctions.DummyReducer; +import org.apache.flink.optimizer.testfunctions.IdentityGroupReducer; +import org.apache.flink.runtime.operators.shipping.ShipStrategyType; +import org.junit.Test; + /* Tests custom partitioners on groupings of POJOs keyed by field name. NOTE(review): the test method names say "Tuple" although every test here groups Pojo2/Pojo3/Pojo4 instances. */ +@SuppressWarnings("serial") +public class GroupingPojoTranslationTest extends CompilerTestBase { + /* groupBy("a"), withPartitioner, reduce: asserts FORWARD into the sink, PARTITION_CUSTOM into reducer, FORWARD into combiner. */ + @Test + public void testCustomPartitioningTupleReduce() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Pojo2> data = env.fromElements(new Pojo2()) + .rebalance().setParallelism(4); + + data.groupBy("a").withPartitioner(new TestPartitionerInt()) + .reduce(new DummyReducer<Pojo2>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource(); + + assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* Same grouping with reduceGroup; asserts the same three ship strategies. */ + @Test + public void 
testCustomPartitioningTupleGroupReduce() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Pojo2> data = env.fromElements(new Pojo2()) + .rebalance().setParallelism(4); + + data.groupBy("a").withPartitioner(new TestPartitionerInt()) + .reduceGroup(new IdentityGroupReducer<Pojo2>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource(); + + assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* Adds sortGroup("b", ASCENDING) before reduceGroup; asserts the same three ship strategies. */ + @Test + public void testCustomPartitioningTupleGroupReduceSorted() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Pojo3> data = env.fromElements(new Pojo3()) + .rebalance().setParallelism(4); + + data.groupBy("a").withPartitioner(new TestPartitionerInt()) + .sortGroup("b", Order.ASCENDING) + .reduceGroup(new IdentityGroupReducer<Pojo3>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource(); + + assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy()); + } + catch (Exception e) { + 
e.printStackTrace(); + fail(e.getMessage()); + } + } + /* Two group-sort fields ("b" ascending, "c" descending) before reduceGroup; asserts the same three ship strategies. */ + @Test + public void testCustomPartitioningTupleGroupReduceSorted2() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Pojo4> data = env.fromElements(new Pojo4()) + .rebalance().setParallelism(4); + + data.groupBy("a").withPartitioner(new TestPartitionerInt()) + .sortGroup("b", Order.ASCENDING) + .sortGroup("c", Order.DESCENDING) + .reduceGroup(new IdentityGroupReducer<Pojo4>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource(); + + assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* A Partitioner<Long> on the int field "a" is expected to make withPartitioner throw InvalidProgramException. */ + @Test + public void testCustomPartitioningTupleInvalidType() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Pojo2> data = env.fromElements(new Pojo2()) + .rebalance().setParallelism(4); + + try { + data.groupBy("a").withPartitioner(new TestPartitionerLong()); + fail("Should throw an exception"); + } + catch (InvalidProgramException e) {/* expected */} + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* Same Long-vs-int mismatch, applied after sortGroup("b", ASCENDING). */ + @Test + public void testCustomPartitioningTupleInvalidTypeSorted() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Pojo3> data = env.fromElements(new Pojo3()) + .rebalance().setParallelism(4); + + try { + data.groupBy("a") + .sortGroup("b", Order.ASCENDING) + .withPartitioner(new TestPartitionerLong()); + fail("Should throw 
an exception"); + } + catch (InvalidProgramException e) {/* expected */} + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* Grouping on two fields ("a", "b") with a custom partitioner is expected to throw InvalidProgramException. */ + @Test + public void testCustomPartitioningTupleRejectCompositeKey() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Pojo2> data = env.fromElements(new Pojo2()) + .rebalance().setParallelism(4); + + try { + data.groupBy("a", "b") + .withPartitioner(new TestPartitionerInt()); + fail("Should throw an exception"); + } + catch (InvalidProgramException e) {/* expected */} + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + // -------------------------------------------------------------------------------------------- + /* Element types with two, three and four public int fields. */ + public static class Pojo2 { + public int a; + public int b; + + } + + public static class Pojo3 { + public int a; + public int b; + public int c; + } + + public static class Pojo4 { + public int a; + public int b; + public int c; + public int d; + } + /* Partitioner for Integer keys; always returns partition 0. */ + private static class TestPartitionerInt implements Partitioner<Integer> { + @Override + public int partition(Integer key, int numPartitions) { + return 0; + } + } + /* Partitioner for Long keys; always returns partition 0. */ + private static class TestPartitionerLong implements Partitioner<Long> { + @Override + public int partition(Long key, int numPartitions) { + return 0; + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingTupleTranslationTest.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingTupleTranslationTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingTupleTranslationTest.java new file mode 100644 index 0000000..779b8e5 --- /dev/null +++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/GroupingTupleTranslationTest.java @@ -0,0 +1,270 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.optimizer.custompartition; + +import static org.junit.Assert.*; + +import org.apache.flink.api.common.InvalidProgramException; +import org.apache.flink.api.common.Plan; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.common.operators.Order; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.optimizer.CompilerTestBase; +import org.apache.flink.optimizer.plan.OptimizedPlan; +import org.apache.flink.optimizer.plan.SingleInputPlanNode; +import org.apache.flink.optimizer.plan.SinkPlanNode; +import org.apache.flink.optimizer.testfunctions.DummyReducer; +import org.apache.flink.optimizer.testfunctions.IdentityGroupReducer; +import org.apache.flink.runtime.operators.shipping.ShipStrategyType; +import org.junit.Test; + /* Tests custom partitioners on groupings of tuples keyed by field position: each test either checks the ship strategies in the compiled plan or expects the API call to throw InvalidProgramException. */ +@SuppressWarnings({"serial", "unchecked"}) +public class GroupingTupleTranslationTest extends CompilerTestBase { + /* groupBy(0), withPartitioner, sum(1): asserts FORWARD into the sink, PARTITION_CUSTOM into reducer, FORWARD into combiner. */ + @Test + public void testCustomPartitioningTupleAgg() { + try { + 
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0)) + .rebalance().setParallelism(4); + + data.groupBy(0).withPartitioner(new TestPartitionerInt()) + .sum(1) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource(); + + assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* Same grouping with reduce; asserts the same three ship strategies. */ + @Test + public void testCustomPartitioningTupleReduce() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0)) + .rebalance().setParallelism(4); + + data.groupBy(0).withPartitioner(new TestPartitionerInt()) + .reduce(new DummyReducer<Tuple2<Integer,Integer>>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource(); + + assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* Same grouping with reduceGroup; asserts the same three ship strategies. */ + @Test + public 
void testCustomPartitioningTupleGroupReduce() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0)) + .rebalance().setParallelism(4); + + data.groupBy(0).withPartitioner(new TestPartitionerInt()) + .reduceGroup(new IdentityGroupReducer<Tuple2<Integer,Integer>>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource(); + + assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* Adds sortGroup(1, ASCENDING) before reduceGroup; asserts the same three ship strategies. */ + @Test + public void testCustomPartitioningTupleGroupReduceSorted() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0)) + .rebalance().setParallelism(4); + + data.groupBy(0).withPartitioner(new TestPartitionerInt()) + .sortGroup(1, Order.ASCENDING) + .reduceGroup(new IdentityGroupReducer<Tuple3<Integer,Integer,Integer>>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource(); + + assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, 
reducer.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* Two group-sort fields (1 ascending, 2 descending) before reduceGroup; asserts the same three ship strategies. */ + @Test + public void testCustomPartitioningTupleGroupReduceSorted2() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple4<Integer,Integer,Integer, Integer>> data = env.fromElements(new Tuple4<Integer,Integer,Integer,Integer>(0, 0, 0, 0)) + .rebalance().setParallelism(4); + + data.groupBy(0).withPartitioner(new TestPartitionerInt()) + .sortGroup(1, Order.ASCENDING) + .sortGroup(2, Order.DESCENDING) + .reduceGroup(new IdentityGroupReducer<Tuple4<Integer,Integer,Integer,Integer>>()) + .print(); + + Plan p = env.createProgramPlan(); + OptimizedPlan op = compileNoStats(p); + + SinkPlanNode sink = op.getDataSinks().iterator().next(); + SingleInputPlanNode reducer = (SingleInputPlanNode) sink.getInput().getSource(); + SingleInputPlanNode combiner = (SingleInputPlanNode) reducer.getInput().getSource(); + + assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.PARTITION_CUSTOM, reducer.getInput().getShipStrategy()); + assertEquals(ShipStrategyType.FORWARD, combiner.getInput().getShipStrategy()); + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* A Partitioner<Long> on the Integer field 0 is expected to make withPartitioner throw InvalidProgramException. */ + @Test + public void testCustomPartitioningTupleInvalidType() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple2<Integer, Integer>> data = env.fromElements(new Tuple2<Integer, Integer>(0, 0)) + .rebalance().setParallelism(4); + + try { + data.groupBy(0).withPartitioner(new TestPartitionerLong()); + fail("Should throw an exception"); + } + catch (InvalidProgramException e) {/* expected */} + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* Same Long-vs-Integer mismatch, applied after sortGroup(1, ASCENDING). */ + @Test + public void testCustomPartitioningTupleInvalidTypeSorted() { + try { + 
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0)) + .rebalance().setParallelism(4); + + try { + data.groupBy(0) + .sortGroup(1, Order.ASCENDING) + .withPartitioner(new TestPartitionerLong()); + fail("Should throw an exception"); + } + catch (InvalidProgramException e) {/* expected */} + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + /* Grouping on two fields (0, 1) with a custom partitioner is expected to throw InvalidProgramException. */ + @Test + public void testCustomPartitioningTupleRejectCompositeKey() { + try { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + + DataSet<Tuple3<Integer, Integer, Integer>> data = env.fromElements(new Tuple3<Integer, Integer, Integer>(0, 0, 0)) + .rebalance().setParallelism(4); + + try { + data.groupBy(0, 1) + .withPartitioner(new TestPartitionerInt()); + fail("Should throw an exception"); + } + catch (InvalidProgramException e) {/* expected */} + } + catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + // -------------------------------------------------------------------------------------------- + /* Partitioner for Integer keys; always returns partition 0. */ + private static class TestPartitionerInt implements Partitioner<Integer> { + @Override + public int partition(Integer key, int numPartitions) { + return 0; + } + } + /* Partitioner for Long keys; always returns partition 0. */ + private static class TestPartitionerLong implements Partitioner<Long> { + @Override + public int partition(Long key, int numPartitions) { + return 0; + } + } +} http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/JoinCustomPartitioningTest.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/JoinCustomPartitioningTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/JoinCustomPartitioningTest.java new file mode 100644 index 0000000..eae40cf --- 
/dev/null +++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/custompartition/JoinCustomPartitioningTest.java @@ -0,0 +1,309 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.optimizer.custompartition; + +import static org.junit.Assert.*; + +import org.apache.flink.api.common.InvalidProgramException; +import org.apache.flink.api.common.Plan; +import org.apache.flink.api.common.functions.Partitioner; +import org.apache.flink.api.common.operators.Order; +import org.apache.flink.api.common.operators.base.JoinOperatorBase.JoinHint; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.optimizer.CompilerTestBase; +import org.apache.flink.optimizer.plan.DualInputPlanNode; +import org.apache.flink.optimizer.plan.OptimizedPlan; +import org.apache.flink.optimizer.plan.SinkPlanNode; +import org.apache.flink.optimizer.testfunctions.DummyFlatJoinFunction; +import org.apache.flink.optimizer.testfunctions.IdentityGroupReducer; +import 
org.apache.flink.optimizer.testfunctions.IdentityMapper; +import org.apache.flink.runtime.operators.shipping.ShipStrategyType; +import org.junit.Test; +
/**
 * Optimizer tests for joins that are given a custom {@link Partitioner} through
 * {@code withPartitioner(...)}.
 *
 * <p>The positive tests build a small join program, compile it, and check that both join
 * inputs are shipped with {@link ShipStrategyType#PARTITION_CUSTOM} using exactly the
 * partitioner that was supplied. The {@code ...WrongType} tests check that attaching a
 * partitioner whose key type differs from the join key type is rejected with an
 * {@link InvalidProgramException} while the program is being assembled.
 */
@SuppressWarnings({"serial", "unchecked"})
public class JoinCustomPartitioningTest extends CompilerTestBase {

	/** Join on {@code Long} tuple fields with a {@code Partitioner<Long>}. */
	@Test
	public void testJoinWithTuples() {
		try {
			final Partitioner<Long> partitioner = new TestPartitionerLong();

			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

			DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L));
			DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));

			input1
				.join(input2, JoinHint.REPARTITION_HASH_FIRST).where(1).equalTo(0).withPartitioner(partitioner)
				.print();

			assertJoinUsesCustomPartitioner(env, partitioner);
		}
		catch (Exception e) {
			e.printStackTrace();
			fail(e.getMessage());
		}
	}

	/** The join keys are {@code Long} fields, so a {@code Partitioner<Integer>} must be rejected. */
	@Test
	public void testJoinWithTuplesWrongType() {
		try {
			final Partitioner<Integer> partitioner = new TestPartitionerInt();

			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

			DataSet<Tuple2<Long, Long>> input1 = env.fromElements(new Tuple2<Long, Long>(0L, 0L));
			DataSet<Tuple3<Long, Long, Long>> input2 = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));

			try {
				input1
					.join(input2, JoinHint.REPARTITION_HASH_FIRST).where(1).equalTo(0)
					.withPartitioner(partitioner);

				fail("should throw an exception");
			}
			catch (InvalidProgramException e) {
				// expected
			}
		}
		catch (Exception e) {
			e.printStackTrace();
			fail(e.getMessage());
		}
	}

	/** Join on {@code int} POJO fields (addressed by name) with a {@code Partitioner<Integer>}. */
	@Test
	public void testJoinWithPojos() {
		try {
			final Partitioner<Integer> partitioner = new TestPartitionerInt();

			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

			DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
			DataSet<Pojo3> input2 = env.fromElements(new Pojo3());

			input1
				.join(input2, JoinHint.REPARTITION_HASH_FIRST)
				.where("b").equalTo("a").withPartitioner(partitioner)
				.print();

			assertJoinUsesCustomPartitioner(env, partitioner);
		}
		catch (Exception e) {
			e.printStackTrace();
			fail(e.getMessage());
		}
	}

	/** The join keys are {@code int} POJO fields, so a {@code Partitioner<Long>} must be rejected. */
	@Test
	public void testJoinWithPojosWrongType() {
		try {
			final Partitioner<Long> partitioner = new TestPartitionerLong();

			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

			DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
			DataSet<Pojo3> input2 = env.fromElements(new Pojo3());

			try {
				input1
					.join(input2, JoinHint.REPARTITION_HASH_FIRST)
					.where("a").equalTo("b")
					.withPartitioner(partitioner);

				fail("should throw an exception");
			}
			catch (InvalidProgramException e) {
				// expected
			}
		}
		catch (Exception e) {
			e.printStackTrace();
			fail(e.getMessage());
		}
	}

	/** Join on {@code Integer} keys produced by key selectors, with a {@code Partitioner<Integer>}. */
	@Test
	public void testJoinWithKeySelectors() {
		try {
			final Partitioner<Integer> partitioner = new TestPartitionerInt();

			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

			DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
			DataSet<Pojo3> input2 = env.fromElements(new Pojo3());

			input1
				.join(input2, JoinHint.REPARTITION_HASH_FIRST)
				.where(new Pojo2KeySelector())
				.equalTo(new Pojo3KeySelector())
				.withPartitioner(partitioner)
				.print();

			assertJoinUsesCustomPartitioner(env, partitioner);
		}
		catch (Exception e) {
			e.printStackTrace();
			fail(e.getMessage());
		}
	}

	/** The key selectors produce {@code Integer} keys, so a {@code Partitioner<Long>} must be rejected. */
	@Test
	public void testJoinWithKeySelectorsWrongType() {
		try {
			final Partitioner<Long> partitioner = new TestPartitionerLong();

			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

			DataSet<Pojo2> input1 = env.fromElements(new Pojo2());
			DataSet<Pojo3> input2 = env.fromElements(new Pojo3());

			try {
				input1
					.join(input2, JoinHint.REPARTITION_HASH_FIRST)
					.where(new Pojo2KeySelector())
					.equalTo(new Pojo3KeySelector())
					.withPartitioner(partitioner);

				fail("should throw an exception");
			}
			catch (InvalidProgramException e) {
				// expected
			}
		}
		catch (Exception e) {
			e.printStackTrace();
			fail(e.getMessage());
		}
	}

	/**
	 * Joins a data set that was custom-partitioned on field 0 with a grouped and reduced
	 * derivative of itself, using a hash-repartitioning join hint and no join partitioner.
	 * The compiled join must hash-partition its first input, and its second input must be
	 * either hash-partitioned or forwarded.
	 */
	@Test
	public void testIncompatibleHashAndCustomPartitioning() {
		try {
			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

			DataSet<Tuple3<Long, Long, Long>> input = env.fromElements(new Tuple3<Long, Long, Long>(0L, 0L, 0L));

			DataSet<Tuple3<Long, Long, Long>> partitioned = input
				.partitionCustom(new TestPartitionerLong(), 0)
				.map(new IdentityMapper<Tuple3<Long,Long,Long>>()).withForwardedFields("0", "1", "2");

			DataSet<Tuple3<Long, Long, Long>> grouped = partitioned
				.distinct(0, 1)
				.groupBy(1)
				.sortGroup(0, Order.ASCENDING)
				.reduceGroup(new IdentityGroupReducer<Tuple3<Long,Long,Long>>()).withForwardedFields("0", "1");

			grouped
				.join(partitioned, JoinHint.REPARTITION_HASH_FIRST).where(0).equalTo(0)
				.with(new DummyFlatJoinFunction<Tuple3<Long,Long,Long>>())
				.print();

			DualInputPlanNode join = compileAndGetJoin(env);

			assertEquals(ShipStrategyType.PARTITION_HASH, join.getInput1().getShipStrategy());

			ShipStrategyType secondInputStrategy = join.getInput2().getShipStrategy();
			assertTrue("Unexpected ship strategy for the second join input: " + secondInputStrategy,
				secondInputStrategy == ShipStrategyType.PARTITION_HASH ||
				secondInputStrategy == ShipStrategyType.FORWARD);
		}
		catch (Exception e) {
			e.printStackTrace();
			fail(e.getMessage());
		}
	}

	// --------------------------------------------------------------------------------------------
	//  Shared verification helpers
	// --------------------------------------------------------------------------------------------

	/**
	 * Compiles the program defined on {@code env} and returns the node that feeds the first
	 * data sink of the optimized plan. Every program in this class has a single sink that
	 * directly follows the join, so that node is cast to a {@link DualInputPlanNode}.
	 */
	private DualInputPlanNode compileAndGetJoin(ExecutionEnvironment env) throws Exception {
		Plan p = env.createProgramPlan();
		OptimizedPlan op = compileNoStats(p);

		SinkPlanNode sink = op.getDataSinks().iterator().next();
		return (DualInputPlanNode) sink.getInput().getSource();
	}

	/**
	 * Compiles the program defined on {@code env} and asserts that both inputs of its join are
	 * shipped with {@link ShipStrategyType#PARTITION_CUSTOM} using {@code expectedPartitioner}.
	 */
	private void assertJoinUsesCustomPartitioner(ExecutionEnvironment env, Partitioner<?> expectedPartitioner)
			throws Exception {
		DualInputPlanNode join = compileAndGetJoin(env);

		assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput1().getShipStrategy());
		assertEquals(ShipStrategyType.PARTITION_CUSTOM, join.getInput2().getShipStrategy());
		assertEquals(expectedPartitioner, join.getInput1().getPartitioner());
		assertEquals(expectedPartitioner, join.getInput2().getPartitioner());
	}

	// --------------------------------------------------------------------------------------------
	//  Test fixtures
	// --------------------------------------------------------------------------------------------

	/** Partitioner for {@code Integer} keys that assigns every key to partition 0. */
	private static class TestPartitionerInt implements Partitioner<Integer> {
		@Override
		public int partition(Integer key, int numPartitions) {
			return 0;
		}
	}

	/** Partitioner for {@code Long} keys that assigns every key to partition 0. */
	private static class TestPartitionerLong implements Partitioner<Long> {
		@Override
		public int partition(Long key, int numPartitions) {
			return 0;
		}
	}

	/** Test type with two public {@code int} fields, addressed by field name in the POJO tests. */
	public static class Pojo2 {
		public int a;
		public int b;
	}

	/** Test type with three public {@code int} fields, addressed by field name in the POJO tests. */
	public static class Pojo3 {
		public int a;
		public int b;
		public int c;
	}

	/** Uses field {@code a} of a {@link Pojo2} as the key. */
	private static class Pojo2KeySelector implements KeySelector<Pojo2, Integer> {
		@Override
		public Integer getKey(Pojo2 value) {
			return value.a;
		}
	}

	/** Uses field {@code b} of a {@link Pojo3} as the key. */
	private static class Pojo3KeySelector implements KeySelector<Pojo3, Integer> {
		@Override
		public Integer getKey(Pojo3 value) {
			return value.b;
		}
	}
}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/DataExchangeModeClosedBranchingTest.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/DataExchangeModeClosedBranchingTest.java b/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/DataExchangeModeClosedBranchingTest.java new file mode 100644 index 0000000..cb4bd78 --- /dev/null +++ b/flink-optimizer/src/test/java/org/apache/flink/optimizer/dataexchange/DataExchangeModeClosedBranchingTest.java @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.optimizer.dataexchange; + +import org.apache.flink.api.common.ExecutionMode; +import org.apache.flink.api.common.functions.FilterFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.ExecutionEnvironment; +import org.apache.flink.api.java.io.DiscardingOutputFormat; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.optimizer.CompilerTestBase; +import org.apache.flink.optimizer.plan.DualInputPlanNode; +import org.apache.flink.optimizer.plan.OptimizedPlan; +import org.apache.flink.optimizer.plan.SingleInputPlanNode; +import org.apache.flink.optimizer.plan.SinkPlanNode; +import org.apache.flink.optimizer.testfunctions.DummyCoGroupFunction; +import org.apache.flink.optimizer.testfunctions.DummyFlatJoinFunction; +import org.apache.flink.optimizer.testfunctions.IdentityFlatMapper; +import org.apache.flink.optimizer.testfunctions.SelectOneReducer; +import org.apache.flink.optimizer.testfunctions.Top1GroupReducer; +import org.apache.flink.runtime.io.network.DataExchangeMode; +import org.junit.Test; + +import java.util.Collection; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +
/**
 * This test checks the correct assignment of the DataExchangeMode to
 * connections for programs that branch, and re-join those branches.
 *
 * <pre>
 *                                         /-> (sink)
 *                                        /
 *                         /-> (reduce) -+            /-> (flatmap) -> (sink)
 *                        /                \         /
 *     (source) -> (map) -                  (join) -+-----\
 *                        \                /               \
 *                         \-> (filter) -+                  \
 *                                        \                  (co group) -> (sink)
 *                                         \                  /
 *                                          \-> (reduce) -   /
 * </pre>
 *
 * <p>Each test fixes one {@link ExecutionMode} and lists, connection by connection, the
 * {@link DataExchangeMode} the compiled plan is expected to carry.
 */
@SuppressWarnings("serial")
public class DataExchangeModeClosedBranchingTest extends CompilerTestBase {

	@Test
	public void testPipelinedForced() {
		// PIPELINED_FORCED should result in pipelining all the way
		verifyBranchingJoiningPlan(ExecutionMode.PIPELINED_FORCED,
				DataExchangeMode.PIPELINED, DataExchangeMode.PIPELINED,
				DataExchangeMode.PIPELINED, DataExchangeMode.PIPELINED,
				DataExchangeMode.PIPELINED, DataExchangeMode.PIPELINED,
				DataExchangeMode.PIPELINED, DataExchangeMode.PIPELINED,
				DataExchangeMode.PIPELINED, DataExchangeMode.PIPELINED,
				DataExchangeMode.PIPELINED, DataExchangeMode.PIPELINED,
				DataExchangeMode.PIPELINED, DataExchangeMode.PIPELINED);
	}

	@Test
	public void testPipelined() {
		// PIPELINED does NOT pipeline every connection of this plan: the connections into the
		// reducers, into the filter, and into the second join input are expected to be BATCH.
		// NOTE(review): presumably because the branches re-join further downstream - confirm
		// against the optimizer's handling of closed branches.
		verifyBranchingJoiningPlan(ExecutionMode.PIPELINED,
				DataExchangeMode.PIPELINED,   // to map
				DataExchangeMode.PIPELINED,   // to reduce combiner (combiner connections are pipelined)
				DataExchangeMode.BATCH,       // to reduce
				DataExchangeMode.BATCH,       // to filter
				DataExchangeMode.PIPELINED,   // to sink after reduce
				DataExchangeMode.PIPELINED,   // to join (first input)
				DataExchangeMode.BATCH,       // to join (second input)
				DataExchangeMode.PIPELINED,   // to other reduce combiner (combiner connections are pipelined)
				DataExchangeMode.BATCH,       // to other reducer
				DataExchangeMode.PIPELINED,   // to flatMap
				DataExchangeMode.PIPELINED,   // to sink after flatMap
				DataExchangeMode.PIPELINED,   // to coGroup (first input)
				DataExchangeMode.PIPELINED,   // to coGroup (second input)
				DataExchangeMode.PIPELINED    // to sink after coGroup
		);
	}

	@Test
	public void testBatch() {
		// BATCH should result in batching the shuffle all the way
		verifyBranchingJoiningPlan(ExecutionMode.BATCH,
				DataExchangeMode.PIPELINED,   // to map
				DataExchangeMode.PIPELINED,   // to reduce combiner (combiner connections are pipelined)
				DataExchangeMode.BATCH,       // to reduce
				DataExchangeMode.BATCH,       // to filter
				DataExchangeMode.PIPELINED,   // to sink after reduce
				DataExchangeMode.BATCH,       // to join (first input)
				DataExchangeMode.BATCH,       // to join (second input)
				DataExchangeMode.PIPELINED,   // to other reduce combiner (combiner connections are pipelined)
				DataExchangeMode.BATCH,       // to other reducer
				DataExchangeMode.PIPELINED,   // to flatMap
				DataExchangeMode.PIPELINED,   // to sink after flatMap
				DataExchangeMode.BATCH,       // to coGroup (first input)
				DataExchangeMode.BATCH,       // to coGroup (second input)
				DataExchangeMode.PIPELINED    // to sink after coGroup
		);
	}

	@Test
	public void testBatchForced() {
		// BATCH_FORCED should result in batching everywhere except on the combiner connections
		verifyBranchingJoiningPlan(ExecutionMode.BATCH_FORCED,
				DataExchangeMode.BATCH,       // to map
				DataExchangeMode.PIPELINED,   // to reduce combiner (combiner connections are pipelined)
				DataExchangeMode.BATCH,       // to reduce
				DataExchangeMode.BATCH,       // to filter
				DataExchangeMode.BATCH,       // to sink after reduce
				DataExchangeMode.BATCH,       // to join (first input)
				DataExchangeMode.BATCH,       // to join (second input)
				DataExchangeMode.PIPELINED,   // to other reduce combiner (combiner connections are pipelined)
				DataExchangeMode.BATCH,       // to other reducer
				DataExchangeMode.BATCH,       // to flatMap
				DataExchangeMode.BATCH,       // to sink after flatMap
				DataExchangeMode.BATCH,       // to coGroup (first input)
				DataExchangeMode.BATCH,       // to coGroup (second input)
				DataExchangeMode.BATCH        // to sink after coGroup
		);
	}

	/**
	 * Builds the branching and re-joining program sketched in the class comment, compiles it
	 * under {@code execMode}, and asserts the data exchange mode of every connection.
	 *
	 * <p>Each {@code to...} parameter is the mode expected on the input connection of the
	 * operator it names. {@code toReduceCombiner} and {@code toOtherReduceCombiner} refer to
	 * the input of the node directly preceding the respective reducer in the compiled plan.
	 * Any exception raised along the way is reported as a test failure.
	 */
	private void verifyBranchingJoiningPlan(ExecutionMode execMode,
											DataExchangeMode toMap,
											DataExchangeMode toReduceCombiner,
											DataExchangeMode toReduce,
											DataExchangeMode toFilter,
											DataExchangeMode toReduceSink,
											DataExchangeMode toJoin1,
											DataExchangeMode toJoin2,
											DataExchangeMode toOtherReduceCombiner,
											DataExchangeMode toOtherReduce,
											DataExchangeMode toFlatMap,
											DataExchangeMode toFlatMapSink,
											DataExchangeMode toCoGroup1,
											DataExchangeMode toCoGroup2,
											DataExchangeMode toCoGroupSink)
	{
		try {
			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
			env.getConfig().setExecutionMode(execMode);

			// ---- build the program: (source) -> (map), which then branches ----

			DataSet<Tuple2<Long, Long>> data = env.fromElements(33L, 44L)
					.map(new MapFunction<Long, Tuple2<Long, Long>>() {
						@Override
						public Tuple2<Long, Long> map(Long value) {
							return new Tuple2<Long, Long>(value, value);
						}
					});

			// first branch: reduce, with its own sink
			DataSet<Tuple2<Long, Long>> reduced = data.groupBy(0).reduce(new SelectOneReducer<Tuple2<Long, Long>>());
			reduced.output(new DiscardingOutputFormat<Tuple2<Long, Long>>()).name("reduceSink");

			// second branch: filter
			DataSet<Tuple2<Long, Long>> filtered = data.filter(new FilterFunction<Tuple2<Long, Long>>() {
				@Override
				public boolean filter(Tuple2<Long, Long> value) throws Exception {
					return false;
				}
			});

			// the two branches re-join here
			DataSet<Tuple2<Long, Long>> joined = reduced.join(filtered)
					.where(1).equalTo(1)
					.with(new DummyFlatJoinFunction<Tuple2<Long, Long>>());

			joined.flatMap(new IdentityFlatMapper<Tuple2<Long, Long>>())
					.output(new DiscardingOutputFormat<Tuple2<Long, Long>>()).name("flatMapSink");

			// the join result is combined once more with a reduced version of the filter branch
			joined.coGroup(filtered.groupBy(1).reduceGroup(new Top1GroupReducer<Tuple2<Long, Long>>()))
					.where(0).equalTo(0)
					.with(new DummyCoGroupFunction<Tuple2<Long, Long>, Tuple2<Long, Long>>())
					.output(new DiscardingOutputFormat<Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>>()).name("cgSink");

			// ---- compile, then walk the plan backwards from the named sinks ----

			OptimizedPlan optPlan = compileNoStats(env.createProgramPlan());

			SinkPlanNode reduceSink = findSink(optPlan.getDataSinks(), "reduceSink");
			SinkPlanNode flatMapSink = findSink(optPlan.getDataSinks(), "flatMapSink");
			SinkPlanNode cgSink = findSink(optPlan.getDataSinks(), "cgSink");

			DualInputPlanNode coGroupNode = (DualInputPlanNode) cgSink.getPredecessor();

			DualInputPlanNode joinNode = (DualInputPlanNode) coGroupNode.getInput1().getSource();
			SingleInputPlanNode otherReduceNode = (SingleInputPlanNode) coGroupNode.getInput2().getSource();
			SingleInputPlanNode otherReduceCombinerNode = (SingleInputPlanNode) otherReduceNode.getPredecessor();

			SingleInputPlanNode reduceNode = (SingleInputPlanNode) joinNode.getInput1().getSource();
			SingleInputPlanNode reduceCombinerNode = (SingleInputPlanNode) reduceNode.getPredecessor();
			assertEquals(reduceNode, reduceSink.getPredecessor());

			SingleInputPlanNode filterNode = (SingleInputPlanNode) joinNode.getInput2().getSource();
			assertEquals(filterNode, otherReduceCombinerNode.getPredecessor());

			SingleInputPlanNode mapNode = (SingleInputPlanNode) filterNode.getPredecessor();
			assertEquals(mapNode, reduceCombinerNode.getPredecessor());

			SingleInputPlanNode flatMapNode = (SingleInputPlanNode) flatMapSink.getPredecessor();
			assertEquals(joinNode, flatMapNode.getPredecessor());

			// ---- verify the data exchange modes ----

			assertEquals(toReduceSink, reduceSink.getInput().getDataExchangeMode());
			assertEquals(toFlatMapSink, flatMapSink.getInput().getDataExchangeMode());
			assertEquals(toCoGroupSink, cgSink.getInput().getDataExchangeMode());

			assertEquals(toCoGroup1, coGroupNode.getInput1().getDataExchangeMode());
			assertEquals(toCoGroup2, coGroupNode.getInput2().getDataExchangeMode());

			assertEquals(toJoin1, joinNode.getInput1().getDataExchangeMode());
			assertEquals(toJoin2, joinNode.getInput2().getDataExchangeMode());

			assertEquals(toOtherReduce, otherReduceNode.getInput().getDataExchangeMode());
			assertEquals(toOtherReduceCombiner, otherReduceCombinerNode.getInput().getDataExchangeMode());

			assertEquals(toFlatMap, flatMapNode.getInput().getDataExchangeMode());

			assertEquals(toFilter, filterNode.getInput().getDataExchangeMode());
			assertEquals(toReduce, reduceNode.getInput().getDataExchangeMode());
			assertEquals(toReduceCombiner, reduceCombinerNode.getInput().getDataExchangeMode());

			assertEquals(toMap, mapNode.getInput().getDataExchangeMode());
		}
		catch (Exception e) {
			e.printStackTrace();
			fail(e.getMessage());
		}
	}

	/**
	 * Returns the first sink in {@code collection} whose operator carries the given name.
	 *
	 * @param collection the sinks of the optimized plan
	 * @param name the operator name to look for
	 * @return the matching sink node
	 * @throws IllegalArgumentException if no sink has that operator name
	 */
	private SinkPlanNode findSink(Collection<SinkPlanNode> collection, String name) {
		for (SinkPlanNode node : collection) {
			String nodeName = node.getOptimizerNode().getOperator().getName();
			// operators may be unnamed, so guard against a null name
			if (nodeName != null && nodeName.equals(name)) {
				return node;
			}
		}

		// name the missing sink so a failing lookup can be diagnosed from the message alone
		throw new IllegalArgumentException("No sink with the name '" + name + "' was found.");
	}
}