http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/GenericFlatTypePostPass.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/GenericFlatTypePostPass.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/GenericFlatTypePostPass.java new file mode 100644 index 0000000..2d8377e --- /dev/null +++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/GenericFlatTypePostPass.java @@ -0,0 +1,579 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + +package org.apache.flink.optimizer.postpass; + +import java.util.Map; + +import org.apache.flink.api.common.operators.SemanticProperties; +import org.apache.flink.api.common.operators.util.FieldList; +import org.apache.flink.api.common.typeutils.TypeComparatorFactory; +import org.apache.flink.api.common.typeutils.TypePairComparatorFactory; +import org.apache.flink.api.common.typeutils.TypeSerializerFactory; +import org.apache.flink.optimizer.CompilerException; +import org.apache.flink.optimizer.CompilerPostPassException; +import org.apache.flink.optimizer.dag.OptimizerNode; +import org.apache.flink.optimizer.dag.SingleInputNode; +import org.apache.flink.optimizer.dag.TwoInputNode; +import org.apache.flink.optimizer.dag.WorksetIterationNode; +import org.apache.flink.optimizer.plan.BulkIterationPlanNode; +import org.apache.flink.optimizer.plan.BulkPartialSolutionPlanNode; +import org.apache.flink.optimizer.plan.Channel; +import org.apache.flink.optimizer.plan.DualInputPlanNode; +import org.apache.flink.optimizer.plan.NAryUnionPlanNode; +import org.apache.flink.optimizer.plan.OptimizedPlan; +import org.apache.flink.optimizer.plan.PlanNode; +import org.apache.flink.optimizer.plan.SingleInputPlanNode; +import org.apache.flink.optimizer.plan.SinkPlanNode; +import org.apache.flink.optimizer.plan.SolutionSetPlanNode; +import org.apache.flink.optimizer.plan.SourcePlanNode; +import org.apache.flink.optimizer.plan.WorksetIterationPlanNode; +import org.apache.flink.optimizer.plan.WorksetPlanNode; + +/** + * + */ +public abstract class GenericFlatTypePostPass<X, T extends AbstractSchema<X>> implements OptimizerPostPass { + + private boolean propagateParentSchemaDown = true; + + + public boolean isPropagateParentSchemaDown() { + return propagateParentSchemaDown; + } + + public void setPropagateParentSchemaDown(boolean propagateParentSchemaDown) { + this.propagateParentSchemaDown = propagateParentSchemaDown; + } + + // 
-------------------------------------------------------------------------------------------- + // Generic schema inferring traversal + // -------------------------------------------------------------------------------------------- + + @Override + public void postPass(OptimizedPlan plan) { + for (SinkPlanNode sink : plan.getDataSinks()) { + traverse(sink, null, true); + } + } + + @SuppressWarnings("unchecked") + protected void traverse(PlanNode node, T parentSchema, boolean createUtilities) { + // distinguish the node types + if (node instanceof SinkPlanNode) { + SinkPlanNode sn = (SinkPlanNode) node; + Channel inchannel = sn.getInput(); + + T schema = createEmptySchema(); + sn.postPassHelper = schema; + + // add the sinks information to the schema + try { + getSinkSchema(sn, schema); + } + catch (ConflictingFieldTypeInfoException e) { + throw new CompilerPostPassException("Conflicting type infomation for the data sink '" + + sn.getSinkNode().getOperator().getName() + "'."); + } + + // descend to the input channel + try { + propagateToChannel(schema, inchannel, createUtilities); + } + catch (MissingFieldTypeInfoException ex) { + throw new CompilerPostPassException("Missing type infomation for the channel that inputs to the data sink '" + + sn.getSinkNode().getOperator().getName() + "'."); + } + } + else if (node instanceof SourcePlanNode) { + if (createUtilities) { + ((SourcePlanNode) node).setSerializer(createSerializer(parentSchema, node)); + // nothing else to be done here. 
the source has no input and no strategy itself + } + } + else if (node instanceof BulkIterationPlanNode) { + BulkIterationPlanNode iterationNode = (BulkIterationPlanNode) node; + + // get the nodes current schema + T schema; + if (iterationNode.postPassHelper == null) { + schema = createEmptySchema(); + iterationNode.postPassHelper = schema; + } else { + schema = (T) iterationNode.postPassHelper; + } + schema.increaseNumConnectionsThatContributed(); + + // add the parent schema to the schema + if (propagateParentSchemaDown) { + addSchemaToSchema(parentSchema, schema, iterationNode.getProgramOperator().getName()); + } + + // check whether all outgoing channels have not yet contributed. come back later if not. + if (schema.getNumConnectionsThatContributed() < iterationNode.getOutgoingChannels().size()) { + return; + } + + if (iterationNode.getRootOfStepFunction() instanceof NAryUnionPlanNode) { + throw new CompilerException("Optimizer cannot compile an iteration step function where next partial solution is created by a Union node."); + } + + // traverse the termination criterion for the first time. create schema only, no utilities. Needed in case of intermediate termination criterion + if (iterationNode.getRootOfTerminationCriterion() != null) { + SingleInputPlanNode addMapper = (SingleInputPlanNode) iterationNode.getRootOfTerminationCriterion(); + traverse(addMapper.getInput().getSource(), createEmptySchema(), false); + try { + addMapper.getInput().setSerializer(createSerializer(createEmptySchema())); + } catch (MissingFieldTypeInfoException e) { + throw new RuntimeException(e); + } + } + + // traverse the step function for the first time. 
create schema only, no utilities + traverse(iterationNode.getRootOfStepFunction(), schema, false); + + T pss = (T) iterationNode.getPartialSolutionPlanNode().postPassHelper; + if (pss == null) { + throw new CompilerException("Error in Optimizer Post Pass: Partial solution schema is null after first traversal of the step function."); + } + + // traverse the step function for the second time, taking the schema of the partial solution + traverse(iterationNode.getRootOfStepFunction(), pss, createUtilities); + + if (iterationNode.getRootOfTerminationCriterion() != null) { + SingleInputPlanNode addMapper = (SingleInputPlanNode) iterationNode.getRootOfTerminationCriterion(); + traverse(addMapper.getInput().getSource(), createEmptySchema(), createUtilities); + try { + addMapper.getInput().setSerializer(createSerializer(createEmptySchema())); + } catch (MissingFieldTypeInfoException e) { + throw new RuntimeException(e); + } + } + + // take the schema from the partial solution node and add its fields to the iteration result schema. + // input and output schema need to be identical, so this is essentially a sanity check + addSchemaToSchema(pss, schema, iterationNode.getProgramOperator().getName()); + + // set the serializer + if (createUtilities) { + iterationNode.setSerializerForIterationChannel(createSerializer(pss, iterationNode.getPartialSolutionPlanNode())); + } + + // done, we can now propagate our info down + try { + propagateToChannel(schema, iterationNode.getInput(), createUtilities); + } catch (MissingFieldTypeInfoException e) { + throw new CompilerPostPassException("Could not set up runtime strategy for input channel to node '" + + iterationNode.getProgramOperator().getName() + "'. 
Missing type information for key field " + + e.getFieldNumber()); + } + } + else if (node instanceof WorksetIterationPlanNode) { + WorksetIterationPlanNode iterationNode = (WorksetIterationPlanNode) node; + + // get the nodes current schema + T schema; + if (iterationNode.postPassHelper == null) { + schema = createEmptySchema(); + iterationNode.postPassHelper = schema; + } else { + schema = (T) iterationNode.postPassHelper; + } + schema.increaseNumConnectionsThatContributed(); + + // add the parent schema to the schema (which refers to the solution set schema) + if (propagateParentSchemaDown) { + addSchemaToSchema(parentSchema, schema, iterationNode.getProgramOperator().getName()); + } + + // check whether all outgoing channels have not yet contributed. come back later if not. + if (schema.getNumConnectionsThatContributed() < iterationNode.getOutgoingChannels().size()) { + return; + } + if (iterationNode.getNextWorkSetPlanNode() instanceof NAryUnionPlanNode) { + throw new CompilerException("Optimizer cannot compile a workset iteration step function where the next workset is produced by a Union node."); + } + if (iterationNode.getSolutionSetDeltaPlanNode() instanceof NAryUnionPlanNode) { + throw new CompilerException("Optimizer cannot compile a workset iteration step function where the solution set delta is produced by a Union node."); + } + + // traverse the step function + // pass an empty schema to the next workset and the parent schema to the solution set delta + // these first traversals are schema only + traverse(iterationNode.getNextWorkSetPlanNode(), createEmptySchema(), false); + traverse(iterationNode.getSolutionSetDeltaPlanNode(), schema, false); + + T wss = (T) iterationNode.getWorksetPlanNode().postPassHelper; + T sss = (T) iterationNode.getSolutionSetPlanNode().postPassHelper; + + if (wss == null) { + throw new CompilerException("Error in Optimizer Post Pass: Workset schema is null after first traversal of the step function."); + } + if (sss == null) { 
+ throw new CompilerException("Error in Optimizer Post Pass: Solution set schema is null after first traversal of the step function."); + } + + // make the second pass and instantiate the utilities + traverse(iterationNode.getNextWorkSetPlanNode(), wss, createUtilities); + traverse(iterationNode.getSolutionSetDeltaPlanNode(), sss, createUtilities); + + // add the types from the solution set schema to the iteration's own schema. since + // the solution set input and the result must have the same schema, this acts as a sanity check. + try { + for (Map.Entry<Integer, X> entry : sss) { + Integer pos = entry.getKey(); + schema.addType(pos, entry.getValue()); + } + } catch (ConflictingFieldTypeInfoException e) { + throw new CompilerPostPassException("Conflicting type information for field " + e.getFieldNumber() + + " in node '" + iterationNode.getProgramOperator().getName() + "'. Contradicting types between the " + + "result of the iteration and the solution set schema: " + e.getPreviousType() + + " and " + e.getNewType() + ". Most probable cause: Invalid constant field annotations."); + } + + // set the serializers and comparators + if (createUtilities) { + WorksetIterationNode optNode = iterationNode.getIterationNode(); + iterationNode.setWorksetSerializer(createSerializer(wss, iterationNode.getWorksetPlanNode())); + iterationNode.setSolutionSetSerializer(createSerializer(sss, iterationNode.getSolutionSetPlanNode())); + try { + iterationNode.setSolutionSetComparator(createComparator(optNode.getSolutionSetKeyFields(), null, sss)); + } catch (MissingFieldTypeInfoException ex) { + throw new CompilerPostPassException("Could not set up the solution set for workset iteration '" + + optNode.getOperator().getName() + "'. 
Missing type information for key field " + ex.getFieldNumber() + '.'); + } + } + + // done, we can now propagate our info down + try { + propagateToChannel(schema, iterationNode.getInitialSolutionSetInput(), createUtilities); + propagateToChannel(wss, iterationNode.getInitialWorksetInput(), createUtilities); + } catch (MissingFieldTypeInfoException ex) { + throw new CompilerPostPassException("Could not set up runtime strategy for input channel to node '" + + iterationNode.getProgramOperator().getName() + "'. Missing type information for key field " + + ex.getFieldNumber()); + } + } + else if (node instanceof SingleInputPlanNode) { + SingleInputPlanNode sn = (SingleInputPlanNode) node; + + // get the nodes current schema + T schema; + if (sn.postPassHelper == null) { + schema = createEmptySchema(); + sn.postPassHelper = schema; + } else { + schema = (T) sn.postPassHelper; + } + schema.increaseNumConnectionsThatContributed(); + SingleInputNode optNode = sn.getSingleInputNode(); + + // add the parent schema to the schema + if (propagateParentSchemaDown) { + addSchemaToSchema(parentSchema, schema, optNode, 0); + } + + // check whether all outgoing channels have not yet contributed. come back later if not. + if (schema.getNumConnectionsThatContributed() < sn.getOutgoingChannels().size()) { + return; + } + + // add the nodes local information + try { + getSingleInputNodeSchema(sn, schema); + } catch (ConflictingFieldTypeInfoException e) { + throw new CompilerPostPassException(getConflictingTypeErrorMessage(e, optNode.getOperator().getName())); + } + + if (createUtilities) { + // parameterize the node's driver strategy + for(int i=0;i<sn.getDriverStrategy().getNumRequiredComparators();i++) { + try { + sn.setComparator(createComparator(sn.getKeys(i), sn.getSortOrders(i), schema),i); + } catch (MissingFieldTypeInfoException e) { + throw new CompilerPostPassException("Could not set up runtime strategy for node '" + + optNode.getOperator().getName() + "'. 
Missing type information for key field " + + e.getFieldNumber()); + } + } + } + + // done, we can now propagate our info down + try { + propagateToChannel(schema, sn.getInput(), createUtilities); + } catch (MissingFieldTypeInfoException e) { + throw new CompilerPostPassException("Could not set up runtime strategy for input channel to node '" + + optNode.getOperator().getName() + "'. Missing type information for field " + e.getFieldNumber()); + } + + // don't forget the broadcast inputs + for (Channel c: sn.getBroadcastInputs()) { + try { + propagateToChannel(createEmptySchema(), c, createUtilities); + } catch (MissingFieldTypeInfoException e) { + throw new CompilerPostPassException("Could not set up runtime strategy for broadcast channel in node '" + + optNode.getOperator().getName() + "'. Missing type information for field " + e.getFieldNumber()); + } + } + } + else if (node instanceof DualInputPlanNode) { + DualInputPlanNode dn = (DualInputPlanNode) node; + + // get the nodes current schema + T schema1; + T schema2; + if (dn.postPassHelper1 == null) { + schema1 = createEmptySchema(); + schema2 = createEmptySchema(); + dn.postPassHelper1 = schema1; + dn.postPassHelper2 = schema2; + } else { + schema1 = (T) dn.postPassHelper1; + schema2 = (T) dn.postPassHelper2; + } + + schema1.increaseNumConnectionsThatContributed(); + schema2.increaseNumConnectionsThatContributed(); + TwoInputNode optNode = dn.getTwoInputNode(); + + // add the parent schema to the schema + if (propagateParentSchemaDown) { + addSchemaToSchema(parentSchema, schema1, optNode, 0); + addSchemaToSchema(parentSchema, schema2, optNode, 1); + } + + // check whether all outgoing channels have not yet contributed. come back later if not. 
+ if (schema1.getNumConnectionsThatContributed() < dn.getOutgoingChannels().size()) { + return; + } + + // add the nodes local information + try { + getDualInputNodeSchema(dn, schema1, schema2); + } catch (ConflictingFieldTypeInfoException e) { + throw new CompilerPostPassException(getConflictingTypeErrorMessage(e, optNode.getOperator().getName())); + } + + // parameterize the node's driver strategy + if (createUtilities) { + if (dn.getDriverStrategy().getNumRequiredComparators() > 0) { + // set the individual comparators + try { + dn.setComparator1(createComparator(dn.getKeysForInput1(), dn.getSortOrders(), schema1)); + dn.setComparator2(createComparator(dn.getKeysForInput2(), dn.getSortOrders(), schema2)); + } catch (MissingFieldTypeInfoException e) { + throw new CompilerPostPassException("Could not set up runtime strategy for node '" + + optNode.getOperator().getName() + "'. Missing type information for field " + e.getFieldNumber()); + } + + // set the pair comparator + try { + dn.setPairComparator(createPairComparator(dn.getKeysForInput1(), dn.getKeysForInput2(), + dn.getSortOrders(), schema1, schema2)); + } catch (MissingFieldTypeInfoException e) { + throw new CompilerPostPassException("Could not set up runtime strategy for node '" + + optNode.getOperator().getName() + "'. Missing type information for field " + e.getFieldNumber()); + } + + } + } + + // done, we can now propagate our info down + try { + propagateToChannel(schema1, dn.getInput1(), createUtilities); + } catch (MissingFieldTypeInfoException e) { + throw new CompilerPostPassException("Could not set up runtime strategy for the first input channel to node '" + + optNode.getOperator().getName() + "'. 
Missing type information for field " + e.getFieldNumber()); + } + try { + propagateToChannel(schema2, dn.getInput2(), createUtilities); + } catch (MissingFieldTypeInfoException e) { + throw new CompilerPostPassException("Could not set up runtime strategy for the second input channel to node '" + + optNode.getOperator().getName() + "'. Missing type information for field " + e.getFieldNumber()); + } + + // don't forget the broadcast inputs + for (Channel c: dn.getBroadcastInputs()) { + try { + propagateToChannel(createEmptySchema(), c, createUtilities); + } catch (MissingFieldTypeInfoException e) { + throw new CompilerPostPassException("Could not set up runtime strategy for broadcast channel in node '" + + optNode.getOperator().getName() + "'. Missing type information for field " + e.getFieldNumber()); + } + } + } + else if (node instanceof NAryUnionPlanNode) { + // only propagate the info down + try { + for (Channel channel : node.getInputs()) { + propagateToChannel(parentSchema, channel, createUtilities); + } + } catch (MissingFieldTypeInfoException ex) { + throw new CompilerPostPassException("Could not set up runtime strategy for the input channel to " + + " a union node. 
Missing type information for field " + ex.getFieldNumber()); + } + } + // catch the sources of the iterative step functions + else if (node instanceof BulkPartialSolutionPlanNode || + node instanceof SolutionSetPlanNode || + node instanceof WorksetPlanNode) + { + // get the nodes current schema + T schema; + String name; + if (node instanceof BulkPartialSolutionPlanNode) { + BulkPartialSolutionPlanNode psn = (BulkPartialSolutionPlanNode) node; + if (psn.postPassHelper == null) { + schema = createEmptySchema(); + psn.postPassHelper = schema; + } else { + schema = (T) psn.postPassHelper; + } + name = "partial solution of bulk iteration '" + + psn.getPartialSolutionNode().getIterationNode().getOperator().getName() + "'"; + } + else if (node instanceof SolutionSetPlanNode) { + SolutionSetPlanNode ssn = (SolutionSetPlanNode) node; + if (ssn.postPassHelper == null) { + schema = createEmptySchema(); + ssn.postPassHelper = schema; + } else { + schema = (T) ssn.postPassHelper; + } + name = "solution set of workset iteration '" + + ssn.getSolutionSetNode().getIterationNode().getOperator().getName() + "'"; + } + else if (node instanceof WorksetPlanNode) { + WorksetPlanNode wsn = (WorksetPlanNode) node; + if (wsn.postPassHelper == null) { + schema = createEmptySchema(); + wsn.postPassHelper = schema; + } else { + schema = (T) wsn.postPassHelper; + } + name = "workset of workset iteration '" + + wsn.getWorksetNode().getIterationNode().getOperator().getName() + "'"; + } else { + throw new CompilerException(); + } + + schema.increaseNumConnectionsThatContributed(); + + // add the parent schema to the schema + addSchemaToSchema(parentSchema, schema, name); + } + else { + throw new CompilerPostPassException("Unknown node type encountered: " + node.getClass().getName()); + } + } + + private void propagateToChannel(T schema, Channel channel, boolean createUtilities) throws MissingFieldTypeInfoException { + if (createUtilities) { + // the serializer always exists + 
channel.setSerializer(createSerializer(schema)); + + // parameterize the ship strategy + if (channel.getShipStrategy().requiresComparator()) { + channel.setShipStrategyComparator( + createComparator(channel.getShipStrategyKeys(), channel.getShipStrategySortOrder(), schema)); + } + + // parameterize the local strategy + if (channel.getLocalStrategy().requiresComparator()) { + channel.setLocalStrategyComparator( + createComparator(channel.getLocalStrategyKeys(), channel.getLocalStrategySortOrder(), schema)); + } + } + + // propagate the channel's source model + traverse(channel.getSource(), schema, createUtilities); + } + + private void addSchemaToSchema(T sourceSchema, T targetSchema, String opName) { + try { + for (Map.Entry<Integer, X> entry : sourceSchema) { + Integer pos = entry.getKey(); + targetSchema.addType(pos, entry.getValue()); + } + } catch (ConflictingFieldTypeInfoException e) { + throw new CompilerPostPassException("Conflicting type information for field " + e.getFieldNumber() + + " in node '" + opName + "' propagated from successor node. " + + "Conflicting types: " + e.getPreviousType() + " and " + e.getNewType() + + ". Most probable cause: Invalid constant field annotations."); + } + } + + private void addSchemaToSchema(T sourceSchema, T targetSchema, OptimizerNode optNode, int input) { + try { + for (Map.Entry<Integer, X> entry : sourceSchema) { + Integer pos = entry.getKey(); + SemanticProperties sprops = optNode.getSemanticProperties(); + + if (sprops != null && sprops.getForwardingTargetFields(input, pos) != null && sprops.getForwardingTargetFields(input, pos).contains(pos)) { + targetSchema.addType(pos, entry.getValue()); + } + } + } catch (ConflictingFieldTypeInfoException e) { + throw new CompilerPostPassException("Conflicting type information for field " + e.getFieldNumber() + + " in node '" + optNode.getOperator().getName() + "' propagated from successor node. 
" + + "Conflicting types: " + e.getPreviousType() + " and " + e.getNewType() + + ". Most probable cause: Invalid constant field annotations."); + } + } + + private String getConflictingTypeErrorMessage(ConflictingFieldTypeInfoException e, String operatorName) { + return "Conflicting type information for field " + e.getFieldNumber() + + " in node '" + operatorName + "' between types declared in the node's " + + "contract and types inferred from successor contracts. Conflicting types: " + + e.getPreviousType() + " and " + e.getNewType() + + ". Most probable cause: Invalid constant field annotations."; + } + + private TypeSerializerFactory<?> createSerializer(T schema, PlanNode node) { + try { + return createSerializer(schema); + } catch (MissingFieldTypeInfoException e) { + throw new CompilerPostPassException("Missing type information while creating serializer for '" + + node.getProgramOperator().getName() + "'."); + } + } + + // -------------------------------------------------------------------------------------------- + // Type specific methods that extract schema information + // -------------------------------------------------------------------------------------------- + + protected abstract T createEmptySchema(); + + protected abstract void getSinkSchema(SinkPlanNode sink, T schema) throws CompilerPostPassException, ConflictingFieldTypeInfoException; + + protected abstract void getSingleInputNodeSchema(SingleInputPlanNode node, T schema) throws CompilerPostPassException, ConflictingFieldTypeInfoException; + + protected abstract void getDualInputNodeSchema(DualInputPlanNode node, T input1Schema, T input2Schema) throws CompilerPostPassException, ConflictingFieldTypeInfoException; + + // -------------------------------------------------------------------------------------------- + // Methods to create serializers and comparators + // -------------------------------------------------------------------------------------------- + + protected abstract 
TypeSerializerFactory<?> createSerializer(T schema) throws MissingFieldTypeInfoException; + + protected abstract TypeComparatorFactory<?> createComparator(FieldList fields, boolean[] directions, T schema) throws MissingFieldTypeInfoException; + + protected abstract TypePairComparatorFactory<?, ?> createPairComparator(FieldList fields1, FieldList fields2, boolean[] sortDirections, + T schema1, T schema2) throws MissingFieldTypeInfoException; +}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/JavaApiPostPass.java ---------------------------------------------------------------------- diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/JavaApiPostPass.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/JavaApiPostPass.java new file mode 100644 index 0000000..5fdf3dd --- /dev/null +++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/JavaApiPostPass.java @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.optimizer.postpass; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.operators.DualInputOperator; +import org.apache.flink.api.common.operators.GenericDataSourceBase; +import org.apache.flink.api.common.operators.Operator; +import org.apache.flink.api.common.operators.SingleInputOperator; +import org.apache.flink.api.common.operators.base.BulkIterationBase; +import org.apache.flink.api.common.operators.base.DeltaIterationBase; +import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase; +import org.apache.flink.api.common.operators.util.FieldList; +import org.apache.flink.api.common.typeinfo.AtomicType; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.CompositeType; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypeComparatorFactory; +import org.apache.flink.api.common.typeutils.TypePairComparatorFactory; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.api.common.typeutils.TypeSerializerFactory; +import org.apache.flink.api.java.operators.translation.PlanUnwrappingReduceGroupOperator; +import org.apache.flink.api.java.tuple.Tuple; +import org.apache.flink.api.java.typeutils.PojoTypeInfo; +import org.apache.flink.api.java.typeutils.runtime.RuntimeComparatorFactory; +import org.apache.flink.api.java.typeutils.runtime.RuntimePairComparatorFactory; +import org.apache.flink.api.java.typeutils.runtime.RuntimeSerializerFactory; +import org.apache.flink.optimizer.CompilerException; +import org.apache.flink.optimizer.CompilerPostPassException; +import org.apache.flink.optimizer.plan.BulkIterationPlanNode; +import org.apache.flink.optimizer.plan.BulkPartialSolutionPlanNode; +import org.apache.flink.optimizer.plan.Channel; +import 
org.apache.flink.optimizer.plan.DualInputPlanNode;
+import org.apache.flink.optimizer.plan.NAryUnionPlanNode;
+import org.apache.flink.optimizer.plan.OptimizedPlan;
+import org.apache.flink.optimizer.plan.PlanNode;
+import org.apache.flink.optimizer.plan.SingleInputPlanNode;
+import org.apache.flink.optimizer.plan.SinkPlanNode;
+import org.apache.flink.optimizer.plan.SolutionSetPlanNode;
+import org.apache.flink.optimizer.plan.SourcePlanNode;
+import org.apache.flink.optimizer.plan.WorksetIterationPlanNode;
+import org.apache.flink.optimizer.plan.WorksetPlanNode;
+import org.apache.flink.optimizer.util.NoOpUnaryUdfOp;
+import org.apache.flink.runtime.operators.DriverStrategy;
+
+/**
+ * The post-optimizer plan traversal. This traversal fills in the API specific utilities (serializers and
+ * comparators).
+ *
+ * <p>The traversal starts at every data sink and walks depth-first through the input channels,
+ * broadcast inputs and iteration step functions. Each plan node is processed at most once.
+ */
+public class JavaApiPostPass implements OptimizerPostPass {
+
+	// nodes that were already processed; a node reachable over several channels is handled only once.
+	// NOTE(review): the set is never cleared, so calling postPass() twice on the same instance skips
+	// every node seen before - use a fresh instance per plan.
+	private final Set<PlanNode> alreadyDone = new HashSet<PlanNode>();
+
+	// set from the original program in postPass(); used for every serializer and comparator created below
+	private ExecutionConfig executionConfig = null;
+
+	/**
+	 * Takes the execution config from the plan's original program and processes the plan starting
+	 * from each of its data sinks.
+	 *
+	 * @param plan the optimized plan whose nodes and channels receive serializers and comparators
+	 */
+	@Override
+	public void postPass(OptimizedPlan plan) {
+
+		executionConfig = plan.getOriginalPactPlan().getExecutionConfig();
+
+		for (SinkPlanNode sink : plan.getDataSinks()) {
+			traverse(sink);
+		}
+	}
+
+
+	/**
+	 * Sets up the serializers and comparators of a single plan node, depending on its concrete type,
+	 * and then descends into the node's inputs. Nodes that were processed before are skipped.
+	 *
+	 * @param node the plan node to process
+	 * @throws CompilerException if an iteration's step function result is produced by a union node
+	 * @throws CompilerPostPassException if the node type is not handled by this pass
+	 * @throws RuntimeException if a node's program operator has an unexpected type
+	 */
+	protected void traverse(PlanNode node) {
+		if (!alreadyDone.add(node)) {
+			// already worked on that one
+			return;
+		}
+
+		// distinguish the node types
+		if (node instanceof SinkPlanNode) {
+			// descend to the input channel
+			SinkPlanNode sn = (SinkPlanNode) node;
+			Channel inchannel = sn.getInput();
+			traverseChannel(inchannel);
+		}
+		else if (node instanceof SourcePlanNode) {
+			// sources only get a serializer for their produced type; no further traversal
+			TypeInformation<?> typeInfo = getTypeInfoFromSource((SourcePlanNode) node);
+			((SourcePlanNode) node).setSerializer(createSerializer(typeInfo));
+		}
+		else if (node instanceof BulkIterationPlanNode) {
+			BulkIterationPlanNode iterationNode = (BulkIterationPlanNode) node;
+
+			if (iterationNode.getRootOfStepFunction() instanceof NAryUnionPlanNode) {
+				throw new CompilerException("Optimizer cannot compile an iteration step function where next partial solution is created by a Union node.");
+			}
+
+			// traverse the input of the termination criterion first (its root is cast to a single-input node).
+			// Note: in this pass traverseChannel() already sets serializers and comparators on the way.
+			if (iterationNode.getRootOfTerminationCriterion() != null) {
+				SingleInputPlanNode addMapper = (SingleInputPlanNode) iterationNode.getRootOfTerminationCriterion();
+				traverseChannel(addMapper.getInput());
+			}
+
+			BulkIterationBase<?> operator = (BulkIterationBase<?>) iterationNode.getProgramOperator();
+
+			// set the serializer for the iteration channel from the operator's output type
+			iterationNode.setSerializerForIterationChannel(createSerializer(operator.getOperatorInfo().getOutputType()));
+
+			// done, we can now propagate our info down
+			traverseChannel(iterationNode.getInput());
+			traverse(iterationNode.getRootOfStepFunction());
+		}
+		else if (node instanceof WorksetIterationPlanNode) {
+			WorksetIterationPlanNode iterationNode = (WorksetIterationPlanNode) node;
+
+			if (iterationNode.getNextWorkSetPlanNode() instanceof NAryUnionPlanNode) {
+				throw new CompilerException("Optimizer cannot compile a workset iteration step function where the next workset is produced by a Union node.");
+			}
+			if (iterationNode.getSolutionSetDeltaPlanNode() instanceof NAryUnionPlanNode) {
+				throw new CompilerException("Optimizer cannot compile a workset iteration step function where the solution set delta is produced by a Union node.");
+			}
+
+			DeltaIterationBase<?, ?> operator = (DeltaIterationBase<?, ?>) iterationNode.getProgramOperator();
+
+			// set the serializers and comparators for the workset iteration
+			// (the first input type is used for the solution set, the second one for the workset)
+			iterationNode.setSolutionSetSerializer(createSerializer(operator.getOperatorInfo().getFirstInputType()));
+			iterationNode.setWorksetSerializer(createSerializer(operator.getOperatorInfo().getSecondInputType()));
+			iterationNode.setSolutionSetComparator(createComparator(operator.getOperatorInfo().getFirstInputType(),
+					iterationNode.getSolutionSetKeyFields(), getSortOrders(iterationNode.getSolutionSetKeyFields(), null)));
+
+			// traverse the inputs
+			traverseChannel(iterationNode.getInput1());
+			traverseChannel(iterationNode.getInput2());
+
+			// traverse the step function
+			traverse(iterationNode.getSolutionSetDeltaPlanNode());
+			traverse(iterationNode.getNextWorkSetPlanNode());
+		}
+		else if (node instanceof SingleInputPlanNode) {
+			SingleInputPlanNode sn = (SingleInputPlanNode) node;
+
+			if (!(sn.getOptimizerNode().getOperator() instanceof SingleInputOperator)) {
+
+				// Special case for delta iterations: only the input channel is processed
+				if(sn.getOptimizerNode().getOperator() instanceof NoOpUnaryUdfOp) {
+					traverseChannel(sn.getInput());
+					return;
+				} else {
+					throw new RuntimeException("Wrong operator type found in post pass.");
+				}
+			}
+
+			SingleInputOperator<?, ?, ?> singleInputOperator = (SingleInputOperator<?, ?, ?>) sn.getOptimizerNode().getOperator();
+
+			// parameterize the node's driver strategy: one comparator per comparator the strategy requires
+			for(int i=0;i<sn.getDriverStrategy().getNumRequiredComparators();i++) {
+				sn.setComparator(createComparator(singleInputOperator.getOperatorInfo().getInputType(), sn.getKeys(i),
+						getSortOrders(sn.getKeys(i), sn.getSortOrders(i))), i);
+			}
+			// done, we can now propagate our info down
+			traverseChannel(sn.getInput());
+
+			// don't forget the broadcast inputs
+			for (Channel c: sn.getBroadcastInputs()) {
+				traverseChannel(c);
+			}
+		}
+		else if (node instanceof DualInputPlanNode) {
+			DualInputPlanNode dn = (DualInputPlanNode) node;
+
+			if (!(dn.getOptimizerNode().getOperator() instanceof DualInputOperator)) {
+				throw new RuntimeException("Wrong operator type found in post pass.");
+			}
+
+			DualInputOperator<?, ?, ?, ?> dualInputOperator = (DualInputOperator<?, ?, ?, ?>) dn.getOptimizerNode().getOperator();
+
+			// parameterize the node's driver strategy: both comparators use dn.getSortOrders(),
+			// and a pair comparator is created for the two input types
+			if (dn.getDriverStrategy().getNumRequiredComparators() > 0) {
+				dn.setComparator1(createComparator(dualInputOperator.getOperatorInfo().getFirstInputType(), dn.getKeysForInput1(),
+						getSortOrders(dn.getKeysForInput1(), dn.getSortOrders())));
+				dn.setComparator2(createComparator(dualInputOperator.getOperatorInfo().getSecondInputType(), dn.getKeysForInput2(),
+						getSortOrders(dn.getKeysForInput2(), dn.getSortOrders())));
+
+				dn.setPairComparator(createPairComparator(dualInputOperator.getOperatorInfo().getFirstInputType(),
+						dualInputOperator.getOperatorInfo().getSecondInputType()));
+
+			}
+
+			traverseChannel(dn.getInput1());
+			traverseChannel(dn.getInput2());
+
+			// don't forget the broadcast inputs
+			for (Channel c: dn.getBroadcastInputs()) {
+				traverseChannel(c);
+			}
+
+		}
+		// catch the sources of the iterative step functions
+		else if (node instanceof BulkPartialSolutionPlanNode ||
+				node instanceof SolutionSetPlanNode ||
+				node instanceof WorksetPlanNode)
+		{
+			// intentionally nothing: no utilities are set here and the traversal does not descend further
+		}
+		else if (node instanceof NAryUnionPlanNode){
+			// Traverse to all child channels
+			for (Channel channel : node.getInputs()) {
+				traverseChannel(channel);
+			}
+		}
+		else {
+			throw new CompilerPostPassException("Unknown node type encountered: " + node.getClass().getName());
+		}
+	}
+
+	/**
+	 * Sets the serializer and the ship/local strategy comparators of a channel, then processes the
+	 * channel's source node.
+	 *
+	 * <p>The type used is the output type of the source's program operator, except for group-reduce
+	 * sources running one of the combine driver strategies checked below, where the output type of
+	 * the operator's input is used instead.
+	 *
+	 * @param channel the channel to parameterize
+	 */
+	private void traverseChannel(Channel channel) {
+
+		PlanNode source = channel.getSource();
+		Operator<?> javaOp = source.getProgramOperator();
+
+// if (!(javaOp instanceof BulkIteration) && !(javaOp instanceof JavaPlanNode)) {
+// throw new RuntimeException("Wrong operator type found in post pass: " + javaOp);
+// }
+
+		TypeInformation<?> type = javaOp.getOperatorInfo().getOutputType();
+
+		// NOTE(review): for these combine strategies the channel uses the operator's input type,
+		// presumably because the combiner emits records of the reduce input type - confirm.
+		if(javaOp instanceof GroupReduceOperatorBase &&
+				(source.getDriverStrategy() == DriverStrategy.SORTED_GROUP_COMBINE || source.getDriverStrategy() == DriverStrategy.ALL_GROUP_REDUCE_COMBINE)) {
+			GroupReduceOperatorBase<?, ?, ?> groupNode = (GroupReduceOperatorBase<?, ?, ?>) javaOp;
+			type = groupNode.getInput().getOperatorInfo().getOutputType();
+		}
+		else if(javaOp instanceof PlanUnwrappingReduceGroupOperator &&
+				source.getDriverStrategy().equals(DriverStrategy.SORTED_GROUP_COMBINE)) {
+			PlanUnwrappingReduceGroupOperator<?, ?, ?> groupNode = (PlanUnwrappingReduceGroupOperator<?, ?, ?>) javaOp;
+			type = groupNode.getInput().getOperatorInfo().getOutputType();
+		}
+
+		// the serializer always exists
+		channel.setSerializer(createSerializer(type));
+
+		// parameterize the ship strategy
+		if (channel.getShipStrategy().requiresComparator()) {
+			channel.setShipStrategyComparator(createComparator(type, channel.getShipStrategyKeys(),
+					getSortOrders(channel.getShipStrategyKeys(), channel.getShipStrategySortOrder())));
+		}
+
+		// parameterize the local strategy
+		if (channel.getLocalStrategy().requiresComparator()) {
+			channel.setLocalStrategyComparator(createComparator(type, channel.getLocalStrategyKeys(),
+					getSortOrders(channel.getLocalStrategyKeys(), channel.getLocalStrategySortOrder())));
+		}
+
+		// descend to the channel's source
+		traverse(channel.getSource());
+	}
+
+
+	/**
+	 * Returns the output type of a source node's operator.
+	 *
+	 * @throws RuntimeException if the operator is not a {@code GenericDataSourceBase}
+	 */
+	@SuppressWarnings("unchecked")
+	private static <T> TypeInformation<T> getTypeInfoFromSource(SourcePlanNode node) {
+		Operator<?> op = node.getOptimizerNode().getOperator();
+
+		if (op instanceof GenericDataSourceBase) {
+			return ((GenericDataSourceBase<T, ?>) op).getOperatorInfo().getOutputType();
+		} else {
+			throw new RuntimeException("Wrong operator type found in post pass.");
+		}
+	}
+
+	/**
+	 * Creates a serializer factory for the given type, using the execution config of the program.
+	 */
+	private <T> TypeSerializerFactory<?> createSerializer(TypeInformation<T> typeInfo) {
+		TypeSerializer<T> serializer = typeInfo.createSerializer(executionConfig);
+
+		return new RuntimeSerializerFactory<T>(serializer, typeInfo.getTypeClass());
+	}
+
+	/**
+	 * Creates a comparator factory for the given type and key fields.
+	 *
+	 * <p>Composite types compare on all {@code keys}; for atomic types the keys are not used and only
+	 * {@code sortOrder[0]} is taken into account.
+	 *
+	 * @throws RuntimeException if the type is neither a composite nor an atomic type
+	 */
+	@SuppressWarnings("unchecked")
+	private <T> TypeComparatorFactory<?> createComparator(TypeInformation<T> typeInfo, FieldList keys, boolean[] sortOrder) {
+
+		TypeComparator<T> comparator;
+		if (typeInfo instanceof CompositeType) {
+			comparator = ((CompositeType<T>) typeInfo).createComparator(keys.toArray(), sortOrder, 0, executionConfig);
+		}
+		else if (typeInfo instanceof AtomicType) {
+			// handle grouping of atomic types
+			comparator = ((AtomicType<T>) typeInfo).createComparator(sortOrder[0], executionConfig);
+		}
+		else {
+			throw new RuntimeException("Unrecognized type: " + typeInfo);
+		}
+
+		return new RuntimeComparatorFactory<T>(comparator);
+	}
+
+	/**
+	 * Creates the pair comparator factory for a binary operator with the given input types.
+	 *
+	 * <p>NOTE(review): as written, the check below only fails when the first type is neither a tuple
+	 * nor a POJO while the second one is. If "both inputs must be tuples or POJOs" was intended, the
+	 * condition would have to be negated as a whole - confirm with the callers before changing it,
+	 * since a stricter check may reject plans that pass today.
+	 */
+	private static <T1 extends Tuple, T2 extends Tuple> TypePairComparatorFactory<T1,T2> createPairComparator(TypeInformation<?> typeInfo1, TypeInformation<?> typeInfo2) {
+		if (!(typeInfo1.isTupleType() || typeInfo1 instanceof PojoTypeInfo) && (typeInfo2.isTupleType() || typeInfo2 instanceof PojoTypeInfo)) {
+			throw new RuntimeException("The runtime currently supports only keyed binary operations (such as joins) on tuples and POJO types.");
+		}
+
+// @SuppressWarnings("unchecked")
+// TupleTypeInfo<T1> info1 = (TupleTypeInfo<T1>) typeInfo1;
+// @SuppressWarnings("unchecked")
+// TupleTypeInfo<T2> info2 = (TupleTypeInfo<T2>) typeInfo2;
+
+		return new RuntimePairComparatorFactory<T1,T2>();
+	}
+
+	/**
+	 * Returns {@code orders} if it is not {@code null}; otherwise a new array with one {@code true}
+	 * entry per key field (presumably meaning "ascending" - confirm against the comparators).
+	 */
+	private static final boolean[] getSortOrders(FieldList keys, boolean[] orders) {
+		if (orders == null) {
+			orders = new boolean[keys.size()];
+			Arrays.fill(orders, true);
+		}
+		return orders;
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/MissingFieldTypeInfoException.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/MissingFieldTypeInfoException.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/MissingFieldTypeInfoException.java
new file mode 100644
index 0000000..b9f6bfa
--- /dev/null
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/MissingFieldTypeInfoException.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.postpass;
+
+/**
+ * Signals that no type information is known for a field that a post pass needs, for example a
+ * field used as a comparator key.
+ */
+public final class MissingFieldTypeInfoException extends Exception {
+
+	private static final long serialVersionUID = 8749941961302509358L;
+
+	// the field number supplied by the code that detected the missing type
+	private final int fieldNumber;
+
+	/**
+	 * Creates a new exception for the given field.
+	 *
+	 * @param fieldNumber the number of the field whose type information is missing
+	 */
+	public MissingFieldTypeInfoException(int fieldNumber) {
+		// give the exception a detail message so that logs and stack traces say which field is affected
+		super("Missing type information for field " + fieldNumber);
+		this.fieldNumber = fieldNumber;
+	}
+
+	/**
+	 * @return the number of the field whose type information is missing
+	 */
+	public int getFieldNumber() {
+		return fieldNumber;
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/OptimizerPostPass.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/OptimizerPostPass.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/OptimizerPostPass.java
new file mode 100644
index 0000000..ba0b7c7
--- /dev/null
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/OptimizerPostPass.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+package org.apache.flink.optimizer.postpass;
+
+import org.apache.flink.optimizer.plan.OptimizedPlan;
+
+/**
+ * A processing step applied to the optimizer's final plan. Typical implementations finalize
+ * schemas, or create and parameterize the utilities (serializers, comparators) that the
+ * concrete data model needs at runtime.
+ */
+public interface OptimizerPostPass {
+
+	/**
+	 * Processes the given plan. The optimizer invokes this after it has chosen the best plan.
+	 *
+	 * @param plan the optimized plan to process
+	 */
+	void postPass(OptimizedPlan plan);
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/PostPassUtils.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/PostPassUtils.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/PostPassUtils.java
new file mode 100644
index 0000000..1fc4c34
--- /dev/null
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/PostPassUtils.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.postpass;
+
+import org.apache.flink.optimizer.CompilerException;
+import org.apache.flink.types.Key;
+
+
+/**
+ * Static helpers shared by the post passes.
+ */
+public class PostPassUtils {
+
+	/**
+	 * Looks up the type of each requested field in the schema and returns them as key classes.
+	 *
+	 * @param schema the schema to read the field types from
+	 * @param fields the field positions to look up
+	 * @return the key class of each field, in the order of {@code fields}
+	 * @throws MissingFieldTypeInfoException if the schema has no type for a field; the exception
+	 *         carries the index into {@code fields}, not the field position itself
+	 * @throws CompilerException if a field's type does not implement {@link Key}
+	 */
+	public static <X> Class<? extends Key<?>>[] getKeys(AbstractSchema<Class<? extends X>> schema, int[] fields) throws MissingFieldTypeInfoException {
+		@SuppressWarnings("unchecked")
+		final Class<? extends Key<?>>[] keyClasses = new Class[fields.length];
+
+		int index = 0;
+		for (int field : fields) {
+			final Class<? extends X> fieldType = schema.getType(field);
+			if (fieldType == null) {
+				throw new MissingFieldTypeInfoException(index);
+			}
+			if (!Key.class.isAssignableFrom(fieldType)) {
+				throw new CompilerException("The field type " + fieldType.getName() +
+						" cannot be used as a key because it does not implement the interface 'Key'");
+			}
+
+			@SuppressWarnings("unchecked")
+			final Class<? extends Key<?>> keyClass = (Class<? extends Key<?>>) fieldType;
+			keyClasses[index] = keyClass;
+			index++;
+		}
+		return keyClasses;
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/RecordModelPostPass.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/RecordModelPostPass.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/RecordModelPostPass.java
new file mode 100644
index 0000000..8a2d006
--- /dev/null
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/RecordModelPostPass.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.postpass;
+
+import org.apache.flink.api.common.operators.DualInputOperator;
+import org.apache.flink.api.common.operators.GenericDataSinkBase;
+import org.apache.flink.api.common.operators.Ordering;
+import org.apache.flink.api.common.operators.RecordOperator;
+import org.apache.flink.api.common.operators.SingleInputOperator;
+import org.apache.flink.api.common.operators.base.CoGroupOperatorBase;
+import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase;
+import org.apache.flink.api.common.operators.util.FieldList;
+import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
+import org.apache.flink.api.common.typeutils.record.RecordComparatorFactory;
+import org.apache.flink.api.common.typeutils.record.RecordPairComparatorFactory;
+import org.apache.flink.api.common.typeutils.record.RecordSerializerFactory;
+import org.apache.flink.optimizer.CompilerException;
+import org.apache.flink.optimizer.CompilerPostPassException;
+import org.apache.flink.optimizer.plan.DualInputPlanNode;
+import org.apache.flink.optimizer.plan.SingleInputPlanNode;
+import org.apache.flink.optimizer.plan.SinkPlanNode;
+import org.apache.flink.types.Key;
+
+/**
+ * Post pass for the Record data model. It only infers the key types of fields (from the operators'
+ * key classes and from orderings) and creates the record serializers and comparators.
+ */
+public class RecordModelPostPass extends GenericFlatTypePostPass<Class<? extends Key<?>>, SparseKeySchema> {
+
+	// --------------------------------------------------------------------------------------------
+	//  Schema extraction
+	// --------------------------------------------------------------------------------------------
+
+	@Override
+	protected SparseKeySchema createEmptySchema() {
+		return new SparseKeySchema();
+	}
+
+	/**
+	 * Adds the key types of the sink's partition ordering and local ordering (when set) to the schema.
+	 */
+	@Override
+	protected void getSinkSchema(SinkPlanNode sinkPlanNode, SparseKeySchema schema) throws CompilerPostPassException {
+		final GenericDataSinkBase<?> sinkOperator = sinkPlanNode.getSinkNode().getOperator();
+		final Ordering partitionOrder = sinkOperator.getPartitionOrdering();
+		final Ordering localOrder = sinkOperator.getLocalOrder();
+
+		try {
+			addOrderingIfPresent(partitionOrder, schema);
+			addOrderingIfPresent(localOrder, schema);
+		} catch (ConflictingFieldTypeInfoException conflict) {
+			throw new CompilerPostPassException("Conflicting information found when adding data sink types. " +
+					"Probable reason is contradicting type infos for partitioning and sorting ordering.");
+		}
+	}
+
+	/**
+	 * Adds the key types of a single-input Record operator (and, for group reduces, of its group
+	 * order) to the schema.
+	 */
+	@Override
+	protected void getSingleInputNodeSchema(SingleInputPlanNode node, SparseKeySchema schema)
+			throws CompilerPostPassException, ConflictingFieldTypeInfoException
+	{
+		final SingleInputOperator<?, ?, ?> operator = (SingleInputOperator<?, ?, ?>) node.getSingleInputNode().getOperator();
+		if (!(operator instanceof RecordOperator)) {
+			throw new CompilerPostPassException("Error: Operator is not a Record based contract. Wrong compiler invokation.");
+		}
+
+		final int[] keyPositions = operator.getKeyColumns(0);
+		final Class<? extends Key<?>>[] keyClasses = ((RecordOperator) operator).getKeyClasses();
+		addKeyTypes(keyPositions, keyClasses, schema);
+
+		// temporary fix: group orders are only taken into account for group reduce operators
+		if (operator instanceof GroupReduceOperatorBase) {
+			addOrderingIfPresent(((GroupReduceOperatorBase<?, ?, ?>) operator).getGroupOrder(), schema);
+		}
+	}
+
+	/**
+	 * Adds the key types of a two-input Record operator to the schemas of both inputs; both inputs
+	 * use the operator's key classes. For co-groups, the group orders of the inputs are added too.
+	 */
+	@Override
+	protected void getDualInputNodeSchema(DualInputPlanNode node, SparseKeySchema input1Schema, SparseKeySchema input2Schema)
+			throws CompilerPostPassException, ConflictingFieldTypeInfoException
+	{
+		final DualInputOperator<?, ?, ?, ?> operator = node.getTwoInputNode().getOperator();
+		if (!(operator instanceof RecordOperator)) {
+			throw new CompilerPostPassException("Error: Operator is not a Pact Record based contract. Wrong compiler invokation.");
+		}
+
+		final int[] keyPositions1 = operator.getKeyColumns(0);
+		final int[] keyPositions2 = operator.getKeyColumns(1);
+		final Class<? extends Key<?>>[] keyClasses = ((RecordOperator) operator).getKeyClasses();
+
+		if (keyPositions1.length != keyPositions2.length) {
+			throw new CompilerException("Error: The keys for the first and second input have a different number of fields.");
+		}
+
+		// adding to the schemas also checks the types for consistency
+		addKeyTypes(keyPositions1, keyClasses, input1Schema);
+		addKeyTypes(keyPositions2, keyClasses, input2Schema);
+
+		// temporary fix: group orders are only taken into account for co-group operators
+		if (operator instanceof CoGroupOperatorBase) {
+			final CoGroupOperatorBase<?, ?, ?, ?> coGroup = (CoGroupOperatorBase<?, ?, ?, ?>) operator;
+			final Ordering groupOrder1 = coGroup.getGroupOrderForInputOne();
+			final Ordering groupOrder2 = coGroup.getGroupOrderForInputTwo();
+
+			addOrderingIfPresent(groupOrder1, input1Schema);
+			addOrderingIfPresent(groupOrder2, input2Schema);
+		}
+	}
+
+	/** Records {@code types[i]} as the type of column {@code positions[i]}, for every position. */
+	private static void addKeyTypes(int[] positions, Class<? extends Key<?>>[] types, SparseKeySchema target)
+			throws ConflictingFieldTypeInfoException
+	{
+		for (int idx = 0; idx < positions.length; idx++) {
+			target.addType(positions[idx], types[idx]);
+		}
+	}
+
+	/** Adds the field types of the ordering to the schema; a {@code null} ordering is ignored. */
+	private void addOrderingIfPresent(Ordering ordering, SparseKeySchema schema) throws ConflictingFieldTypeInfoException {
+		if (ordering != null) {
+			addOrderingToSchema(ordering, schema);
+		}
+	}
+
+	/** Adds the type of every field of the ordering to the schema. */
+	private void addOrderingToSchema(Ordering o, SparseKeySchema schema) throws ConflictingFieldTypeInfoException {
+		for (int idx = 0; idx < o.getNumberOfFields(); idx++) {
+			final Integer fieldPosition = o.getFieldNumber(idx);
+			final Class<? extends Key<?>> fieldType = o.getType(idx);
+			schema.addType(fieldPosition, fieldType);
+		}
+	}
+
+	// --------------------------------------------------------------------------------------------
+	//  Serializer and comparator creation
+	// --------------------------------------------------------------------------------------------
+
+	/** Returns the record serializer factory; the schema is not consulted. */
+	@Override
+	protected TypeSerializerFactory<?> createSerializer(SparseKeySchema schema) {
+		return RecordSerializerFactory.get();
+	}
+
+	/**
+	 * Creates a record comparator factory for the given key fields, with their key classes taken from
+	 * the schema.
+	 *
+	 * @throws MissingFieldTypeInfoException if the schema lacks the type of one of the key fields
+	 */
+	@Override
+	protected RecordComparatorFactory createComparator(FieldList fields, boolean[] directions, SparseKeySchema schema)
+			throws MissingFieldTypeInfoException
+	{
+		final int[] keyPositions = fields.toArray();
+		final Class<? extends Key<?>>[] keyClasses = PostPassUtils.getKeys(schema, keyPositions);
+		return new RecordComparatorFactory(keyPositions, keyClasses, directions);
+	}
+
+	/** Returns the record pair comparator factory; none of the arguments are consulted. */
+	@Override
+	protected RecordPairComparatorFactory createPairComparator(FieldList fields1, FieldList fields2, boolean[] sortDirections,
+			SparseKeySchema schema1, SparseKeySchema schema2)
+	{
+		return RecordPairComparatorFactory.get();
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/SparseKeySchema.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/SparseKeySchema.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/SparseKeySchema.java
new file mode 100644
index 0000000..e14888e
--- /dev/null
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/postpass/SparseKeySchema.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.postpass;
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Map.Entry;
+
+import org.apache.flink.types.Key;
+
+/**
+ * Class encapsulating a schema map (int column position -> column type) and a reference counter.
+ * The counter comes from {@link AbstractSchema} ({@code getNumConnectionsThatContributed()}) and
+ * takes part in {@link #equals(Object)}, {@link #hashCode()} and {@link #toString()}.
+ */
+public class SparseKeySchema extends AbstractSchema<Class<? extends Key<?>>> {
+
+	// column position -> key type recorded for that column
+	private final Map<Integer, Class<? extends Key<?>>> schema;
+
+
+	public SparseKeySchema() {
+		this.schema = new HashMap<Integer, Class<? extends Key<?>>>();
+	}
+
+	// --------------------------------------------------------------------------------------------
+
+	/**
+	 * Records the type of a column. Adding the type that is already recorded is a no-op.
+	 *
+	 * @param key the column position
+	 * @param type the key type of the column
+	 * @throws ConflictingFieldTypeInfoException if a different type is already recorded for the
+	 *         column; the schema keeps the previously recorded type in that case
+	 */
+	@Override
+	public void addType(int key, Class<? extends Key<?>> type) throws ConflictingFieldTypeInfoException {
+		// look before writing: a conflicting type must not overwrite the type that was recorded first
+		Class<? extends Key<?>> previous = this.schema.get(key);
+		if (previous == null) {
+			this.schema.put(key, type);
+		} else if (previous != type) {
+			throw new ConflictingFieldTypeInfoException(key, previous, type);
+		}
+	}
+
+	/**
+	 * @return the type recorded for the column, or {@code null} if none is known
+	 */
+	@Override
+	public Class<? extends Key<?>> getType(int field) {
+		return this.schema.get(field);
+	}
+
+	@Override
+	public Iterator<Entry<Integer, Class<? extends Key<?>>>> iterator() {
+		return this.schema.entrySet().iterator();
+	}
+
+	/**
+	 * @return the number of columns with a recorded type
+	 */
+	public int getNumTypes() {
+		return this.schema.size();
+	}
+
+	// --------------------------------------------------------------------------------------------
+
+	@Override
+	public int hashCode() {
+		return this.schema.hashCode() ^ getNumConnectionsThatContributed();
+	}
+
+	@Override
+	public boolean equals(Object obj) {
+		if (obj instanceof SparseKeySchema) {
+			SparseKeySchema other = (SparseKeySchema) obj;
+			return this.schema.equals(other.schema) &&
+					this.getNumConnectionsThatContributed() == other.getNumConnectionsThatContributed();
+		} else {
+			return false;
+		}
+	}
+
+	@Override
+	public String toString() {
+		return "<" + getNumConnectionsThatContributed() + "> : " + this.schema.toString();
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/BinaryUnionReplacer.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/BinaryUnionReplacer.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/BinaryUnionReplacer.java
new file mode 100644
index 0000000..bd35b5c
--- /dev/null
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/BinaryUnionReplacer.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.traversals;
+
+import org.apache.flink.optimizer.CompilerException;
+import org.apache.flink.optimizer.plan.BinaryUnionPlanNode;
+import org.apache.flink.optimizer.plan.Channel;
+import org.apache.flink.optimizer.plan.IterationPlanNode;
+import org.apache.flink.optimizer.plan.NAryUnionPlanNode;
+import org.apache.flink.optimizer.plan.PlanNode;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.apache.flink.runtime.operators.util.LocalStrategy;
+import org.apache.flink.util.Visitor;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * A traversal that collects cascading binary unions into a single n-ary
+ * union operator. The exception is when one of the union inputs is materialized, such as in the
+ * static-code-path-cache in iterations.
+ *
+ * <p>Binary unions whose inputs are both on the static path or both on the dynamic path are
+ * replaced by an {@link NAryUnionPlanNode}. Unions between the static and the dynamic path remain
+ * binary union nodes.
+ */
+public class BinaryUnionReplacer implements Visitor<PlanNode> {
+
+	// plan nodes already visited, so that nodes reachable over several paths are handled once
+	private final Set<PlanNode> seenBefore = new HashSet<PlanNode>();
+
+	/**
+	 * Descends into a node only on its first visit. For iteration nodes, the step function is also
+	 * visited with this replacer.
+	 */
+	@Override
+	public boolean preVisit(PlanNode visitable) {
+		if (this.seenBefore.add(visitable)) {
+			if (visitable instanceof IterationPlanNode) {
+				((IterationPlanNode) visitable).acceptForStepFunction(this);
+			}
+			return true;
+		} else {
+			return false;
+		}
+	}
+
+	/**
+	 * Replaces a binary union node either by an n-ary union (inputs on the same path) or, for a
+	 * union between the static and the dynamic path whose first input is dynamic, by a new binary
+	 * union node. All other nodes are left untouched.
+	 */
+	@Override
+	public void postVisit(PlanNode visitable) {
+
+		if (visitable instanceof BinaryUnionPlanNode) {
+
+			final BinaryUnionPlanNode unionNode = (BinaryUnionPlanNode) visitable;
+			final Channel in1 = unionNode.getInput1();
+			final Channel in2 = unionNode.getInput2();
+
+			if (!unionNode.unionsStaticAndDynamicPath()) {
+
+				// both on static path, or both on dynamic path. we can collapse them
+				NAryUnionPlanNode newUnionNode;
+
+				// flatten: an input coming from an n-ary union contributes that union's inputs instead
+				List<Channel> inputs = new ArrayList<Channel>();
+				collect(in1, inputs);
+				collect(in2, inputs);
+
+				newUnionNode = new NAryUnionPlanNode(unionNode.getOptimizerNode(), inputs,
+						unionNode.getGlobalProperties(), unionNode.getCumulativeCosts());
+
+				newUnionNode.setParallelism(unionNode.getParallelism());
+
+				// rewire: every collected input now targets the new node ...
+				for (Channel c : inputs) {
+					c.setTarget(newUnionNode);
+				}
+
+				// ... and the old node's outgoing channels are moved over to the new node
+				for (Channel channel : unionNode.getOutgoingChannels()) {
+					channel.swapUnionNodes(newUnionNode);
+					newUnionNode.addOutgoingChannel(channel);
+				}
+			}
+			else {
+				// union between the static and the dynamic path. we need to handle this for now
+				// through a special union operator
+
+				// make sure that the first input is the cached (static) and the second input is the dynamic
+				// NOTE(review): the inputs are not swapped here; presumably the BinaryUnionPlanNode copy
+				// constructor puts the static input first - confirm. If in1 is not on the dynamic path,
+				// the union node is left unchanged.
+				if (in1.isOnDynamicPath()) {
+					BinaryUnionPlanNode newUnionNode = new BinaryUnionPlanNode(unionNode);
+
+					in1.setTarget(newUnionNode);
+					in2.setTarget(newUnionNode);
+
+					for (Channel channel : unionNode.getOutgoingChannels()) {
+						channel.swapUnionNodes(newUnionNode);
+						newUnionNode.addOutgoingChannel(channel);
+					}
+				}
+			}
+		}
+	}
+
+	/**
+	 * Adds the channels that feed a union to {@code inputs}. If {@code in} comes from an n-ary union
+	 * (a union collapsed earlier), that union's inputs are added instead of {@code in} itself.
+	 *
+	 * @param in the channel to collect
+	 * @param inputs the list the collected channels are appended to
+	 * @throws CompilerException if a channel from an n-ary union has a ship strategy other than
+	 *         FORWARD, or a local strategy other than null / NONE
+	 */
+	public void collect(Channel in, List<Channel> inputs) {
+		if (in.getSource() instanceof NAryUnionPlanNode) {
+			// sanity check
+			if (in.getShipStrategy() != ShipStrategyType.FORWARD) {
+				throw new CompilerException("Bug: Plan generation for Unions picked a ship strategy between binary plan operators.");
+			}
+			if (!(in.getLocalStrategy() == null || in.getLocalStrategy() == LocalStrategy.NONE)) {
+				throw new CompilerException("Bug: Plan generation for Unions picked a local strategy between binary plan operators.");
+			}
+
+			inputs.addAll(((NAryUnionPlanNode) in.getSource()).getListOfInputs());
+		} else {
+			// is not a collapsed union node, so we take the channel directly
+			inputs.add(in);
+		}
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/BranchesVisitor.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/BranchesVisitor.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/BranchesVisitor.java
new file mode 100644
index 0000000..4730546
--- /dev/null
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/BranchesVisitor.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.
See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.traversals;
+
+import org.apache.flink.optimizer.dag.IterationNode;
+import org.apache.flink.optimizer.dag.OptimizerNode;
+import org.apache.flink.util.Visitor;
+
+/**
+ * Traversal over the optimizer DAG that computes the information needed to track branches and
+ * joins in the data flow. This is required for plans that are not trees, i.e. plans in which at
+ * least one node feeds its output into more than one other node.
+ *
+ * <p>A node is only descended into while its open branches are still unset. On the way back up,
+ * iteration nodes first have their step function visited, and then every node computes its
+ * unclosed branch stack.
+ */
+public final class BranchesVisitor implements Visitor<OptimizerNode> {
+
+	@Override
+	public boolean preVisit(OptimizerNode node) {
+		// presumably set by computeUnclosedBranchStack(), so a non-null value marks a finished node
+		final boolean alreadyComputed = node.getOpenBranches() != null;
+		return !alreadyComputed;
+	}
+
+	@Override
+	public void postVisit(OptimizerNode node) {
+		// for iterations, handle the step function before the iteration node itself
+		if (node instanceof IterationNode) {
+			final IterationNode iterationNode = (IterationNode) node;
+			iterationNode.acceptForStepFunction(this);
+		}
+		node.computeUnclosedBranchStack();
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/GraphCreatingVisitor.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/GraphCreatingVisitor.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/GraphCreatingVisitor.java
new file mode 100644
index 0000000..160ef95
--- /dev/null
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/GraphCreatingVisitor.java
@@ -0,0 +1,392 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.traversals;
+
+import org.apache.flink.api.common.ExecutionMode;
+import org.apache.flink.api.common.InvalidProgramException;
+import org.apache.flink.api.common.operators.GenericDataSinkBase;
+import org.apache.flink.api.common.operators.GenericDataSourceBase;
+import org.apache.flink.api.common.operators.Operator;
+import org.apache.flink.api.common.operators.Union;
+import org.apache.flink.api.common.operators.base.BulkIterationBase;
+import org.apache.flink.api.common.operators.base.CoGroupOperatorBase;
+import org.apache.flink.api.common.operators.base.CrossOperatorBase;
+import org.apache.flink.api.common.operators.base.DeltaIterationBase;
+import org.apache.flink.api.common.operators.base.FilterOperatorBase;
+import org.apache.flink.api.common.operators.base.FlatMapOperatorBase;
+import org.apache.flink.api.common.operators.base.GroupCombineOperatorBase;
+import org.apache.flink.api.common.operators.base.GroupReduceOperatorBase;
+import org.apache.flink.api.common.operators.base.JoinOperatorBase;
+import org.apache.flink.api.common.operators.base.MapOperatorBase;
+import org.apache.flink.api.common.operators.base.MapPartitionOperatorBase;
+import org.apache.flink.api.common.operators.base.PartitionOperatorBase;
+import org.apache.flink.api.common.operators.base.ReduceOperatorBase;
+import org.apache.flink.api.common.operators.base.SortPartitionOperatorBase;
+import org.apache.flink.optimizer.CompilerException;
+import org.apache.flink.optimizer.Optimizer;
+import org.apache.flink.optimizer.dag.BinaryUnionNode;
+import org.apache.flink.optimizer.dag.BulkIterationNode;
+import org.apache.flink.optimizer.dag.BulkPartialSolutionNode;
+import org.apache.flink.optimizer.dag.CoGroupNode;
+import org.apache.flink.optimizer.dag.CollectorMapNode;
+import org.apache.flink.optimizer.dag.CrossNode;
+import org.apache.flink.optimizer.dag.DagConnection;
+import org.apache.flink.optimizer.dag.DataSinkNode;
+import org.apache.flink.optimizer.dag.DataSourceNode;
+import org.apache.flink.optimizer.dag.FilterNode;
+import org.apache.flink.optimizer.dag.FlatMapNode;
+import org.apache.flink.optimizer.dag.GroupCombineNode;
+import org.apache.flink.optimizer.dag.GroupReduceNode;
+import org.apache.flink.optimizer.dag.JoinNode;
+import org.apache.flink.optimizer.dag.MapNode;
+import org.apache.flink.optimizer.dag.MapPartitionNode;
+import org.apache.flink.optimizer.dag.OptimizerNode;
+import org.apache.flink.optimizer.dag.PartitionNode;
+import org.apache.flink.optimizer.dag.ReduceNode;
+import org.apache.flink.optimizer.dag.SolutionSetNode;
+import org.apache.flink.optimizer.dag.SortPartitionNode;
+import org.apache.flink.optimizer.dag.WorksetIterationNode;
+import org.apache.flink.optimizer.dag.WorksetNode;
+import org.apache.flink.util.Visitor;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * This traversal creates the optimizer DAG from a program.
+ * It works as a visitor that walks the program's flow in a depth-first fashion, starting from the data sinks.
+ * During the descent, it creates an optimizer node for each operator, respectively data source or -sink.
+ * During the ascent, it connects the nodes to the full graph.
+ */
+public class GraphCreatingVisitor implements Visitor<Operator<?>> {
+
+	/** Map from the operator objects to their corresponding optimizer nodes. */
+	private final Map<Operator<?>, OptimizerNode> con2node;
+
+	/** All data sink nodes in the optimizer plan. */
+	private final List<DataSinkNode> sinks;
+
+	/** The default degree of parallelism. */
+	private final int defaultParallelism;
+
+	/** Reference to the enclosing creator in case of a recursive translation; null at the top level. */
+	private final GraphCreatingVisitor parent;
+
+	/** Data exchange mode passed to the nodes when their inputs are connected. */
+	private final ExecutionMode defaultDataExchangeMode;
+
+	/** If true, an explicitly set operator parallelism is overridden with {@link #defaultParallelism}. */
+	private final boolean forceDOP;
+
+
+	/**
+	 * Creates a visitor for a top-level program: no enclosing creator, parallelism not forced.
+	 *
+	 * @param defaultParallelism parallelism assigned to operators that do not specify one
+	 * @param defaultDataExchangeMode data exchange mode used when connecting node inputs
+	 */
+	public GraphCreatingVisitor(int defaultParallelism, ExecutionMode defaultDataExchangeMode) {
+		this(null, false, defaultParallelism, defaultDataExchangeMode, null);
+	}
+
+	/**
+	 * Creates a visitor, possibly nested inside another one for the translation of a step function.
+	 *
+	 * @param parent the enclosing creator, or null at the top level
+	 * @param forceDOP whether explicitly set operator parallelism is overridden with the default
+	 * @param defaultParallelism parallelism assigned to operators (see {@code forceDOP})
+	 * @param dataExchangeMode data exchange mode used when connecting node inputs
+	 * @param closure operator-to-node map to use (and extend) directly, or null to start with an empty one
+	 */
+	private GraphCreatingVisitor(GraphCreatingVisitor parent, boolean forceDOP, int defaultParallelism,
+			ExecutionMode dataExchangeMode, Map<Operator<?>, OptimizerNode> closure) {
+		if (closure == null) {
+			this.con2node = new HashMap<Operator<?>, OptimizerNode>();
+		} else {
+			this.con2node = closure;
+		}
+
+		this.sinks = new ArrayList<DataSinkNode>(2);
+		this.defaultParallelism = defaultParallelism;
+		this.parent = parent;
+		this.defaultDataExchangeMode = dataExchangeMode;
+		this.forceDOP = forceDOP;
+	}
+
+	/**
+	 * @return the data sink nodes created by this visitor, in the order they were first visited
+	 */
+	public List<DataSinkNode> getSinks() {
+		return sinks;
+	}
+
+	/**
+	 * Creates the optimizer node for an operator that has not been visited before and assigns its parallelism.
+	 *
+	 * @return false if a node for this operator already exists (do not descend again), true otherwise
+	 * @throws InvalidProgramException if an iteration placeholder is reached while this visitor has no parent
+	 * @throws IllegalArgumentException if the operator type is not known
+	 */
+	@SuppressWarnings("deprecation")
+	@Override
+	public boolean preVisit(Operator<?> c) {
+		// check if we have been here before
+		if (this.con2node.containsKey(c)) {
+			return false;
+		}
+
+		final OptimizerNode n;
+
+		// create a node for the operator (or sink or source) if we have not been here before
+		if (c instanceof GenericDataSinkBase) {
+			DataSinkNode dsn = new DataSinkNode((GenericDataSinkBase<?>) c);
+			this.sinks.add(dsn);
+			n = dsn;
+		}
+		else if (c instanceof GenericDataSourceBase) {
+			n = new DataSourceNode((GenericDataSourceBase<?, ?>) c);
+		}
+		else if (c instanceof MapOperatorBase) {
+			n = new MapNode((MapOperatorBase<?, ?, ?>) c);
+		}
+		else if (c instanceof MapPartitionOperatorBase) {
+			n = new MapPartitionNode((MapPartitionOperatorBase<?, ?, ?>) c);
+		}
+		else if (c instanceof org.apache.flink.api.common.operators.base.CollectorMapOperatorBase) {
+			n = new CollectorMapNode((org.apache.flink.api.common.operators.base.CollectorMapOperatorBase<?, ?, ?>) c);
+		}
+		else if (c instanceof FlatMapOperatorBase) {
+			n = new FlatMapNode((FlatMapOperatorBase<?, ?, ?>) c);
+		}
+		else if (c instanceof FilterOperatorBase) {
+			n = new FilterNode((FilterOperatorBase<?, ?>) c);
+		}
+		else if (c instanceof ReduceOperatorBase) {
+			n = new ReduceNode((ReduceOperatorBase<?, ?>) c);
+		}
+		else if (c instanceof GroupCombineOperatorBase) {
+			n = new GroupCombineNode((GroupCombineOperatorBase<?, ?, ?>) c);
+		}
+		else if (c instanceof GroupReduceOperatorBase) {
+			n = new GroupReduceNode((GroupReduceOperatorBase<?, ?, ?>) c);
+		}
+		else if (c instanceof JoinOperatorBase) {
+			n = new JoinNode((JoinOperatorBase<?, ?, ?, ?>) c);
+		}
+		else if (c instanceof CoGroupOperatorBase) {
+			n = new CoGroupNode((CoGroupOperatorBase<?, ?, ?, ?>) c);
+		}
+		else if (c instanceof CrossOperatorBase) {
+			n = new CrossNode((CrossOperatorBase<?, ?, ?, ?>) c);
+		}
+		else if (c instanceof BulkIterationBase) {
+			n = new BulkIterationNode((BulkIterationBase<?>) c);
+		}
+		else if (c instanceof DeltaIterationBase) {
+			n = new WorksetIterationNode((DeltaIterationBase<?, ?>) c);
+		}
+		else if (c instanceof Union){
+			n = new BinaryUnionNode((Union<?>) c);
+		}
+		else if (c instanceof PartitionOperatorBase) {
+			n = new PartitionNode((PartitionOperatorBase<?>) c);
+		}
+		else if (c instanceof SortPartitionOperatorBase) {
+			n = new SortPartitionNode((SortPartitionOperatorBase<?>) c);
+		}
+		else if (c instanceof BulkIterationBase.PartialSolutionPlaceHolder) {
+			// placeholders are only expected during the recursive step-function translation (parent != null);
+			// presumably reaching one at the top level means a sink was defined inside an iteration
+			if (this.parent == null) {
+				throw new InvalidProgramException("It is currently not supported to create data sinks inside iterations.");
+			}
+
+			final BulkIterationBase.PartialSolutionPlaceHolder<?> holder = (BulkIterationBase.PartialSolutionPlaceHolder<?>) c;
+			final BulkIterationBase<?> enclosingIteration = holder.getContainingBulkIteration();
+			final BulkIterationNode containingIterationNode =
+						(BulkIterationNode) this.parent.con2node.get(enclosingIteration);
+
+			// catch this for the recursive translation of step functions
+			BulkPartialSolutionNode p = new BulkPartialSolutionNode(holder, containingIterationNode);
+			p.setDegreeOfParallelism(containingIterationNode.getParallelism());
+			n = p;
+		}
+		else if (c instanceof DeltaIterationBase.WorksetPlaceHolder) {
+			if (this.parent == null) {
+				throw new InvalidProgramException("It is currently not supported to create data sinks inside iterations.");
+			}
+
+			final DeltaIterationBase.WorksetPlaceHolder<?> holder = (DeltaIterationBase.WorksetPlaceHolder<?>) c;
+			final DeltaIterationBase<?, ?> enclosingIteration = holder.getContainingWorksetIteration();
+			final WorksetIterationNode containingIterationNode =
+						(WorksetIterationNode) this.parent.con2node.get(enclosingIteration);
+
+			// catch this for the recursive translation of step functions
+			WorksetNode p = new WorksetNode(holder, containingIterationNode);
+			p.setDegreeOfParallelism(containingIterationNode.getParallelism());
+			n = p;
+		}
+		else if (c instanceof DeltaIterationBase.SolutionSetPlaceHolder) {
+			if (this.parent == null) {
+				throw new InvalidProgramException("It is currently not supported to create data sinks inside iterations.");
+			}
+
+			final DeltaIterationBase.SolutionSetPlaceHolder<?> holder = (DeltaIterationBase.SolutionSetPlaceHolder<?>) c;
+			final DeltaIterationBase<?, ?> enclosingIteration = holder.getContainingWorksetIteration();
+			final WorksetIterationNode containingIterationNode =
+						(WorksetIterationNode) this.parent.con2node.get(enclosingIteration);
+
+			// catch this for the recursive translation of step functions
+			SolutionSetNode p = new SolutionSetNode(holder, containingIterationNode);
+			p.setDegreeOfParallelism(containingIterationNode.getParallelism());
+			n = p;
+		}
+		else {
+			throw new IllegalArgumentException("Unknown operator type: " + c);
+		}
+
+		this.con2node.put(c, n);
+
+		// set the parallelism only if it has not been set before. some nodes have a fixed DOP, such as the
+		// key-less reducer (all-reduce)
+		if (n.getParallelism() < 1) {
+			// set the degree of parallelism
+			int par = c.getDegreeOfParallelism();
+			if (par > 0) {
+				if (this.forceDOP && par != this.defaultParallelism) {
+					par = this.defaultParallelism;
+					Optimizer.LOG.warn("The parallelism of nested dataflows (such as step functions in iterations) is " +
+						"currently fixed to the parallelism of the surrounding operator (the iteration).");
+				}
+			} else {
+				par = this.defaultParallelism;
+			}
+			n.setDegreeOfParallelism(par);
+		}
+
+		return true;
+	}
+
+	/**
+	 * Connects the operator's node to its (broadcast) inputs and, for iteration nodes,
+	 * recursively translates the iteration's step function.
+	 */
+	@Override
+	public void postVisit(Operator<?> c) {
+
+		OptimizerNode n = this.con2node.get(c);
+
+		// first connect to the predecessors
+		n.setInput(this.con2node, this.defaultDataExchangeMode);
+		n.setBroadcastInputs(this.con2node, this.defaultDataExchangeMode);
+
+		// if the node represents an iteration, we recursively translate the data flow now
+		if (n instanceof BulkIterationNode) {
+			translateBulkIteration((BulkIterationNode) n);
+		}
+		else if (n instanceof WorksetIterationNode) {
+			translateWorksetIteration((WorksetIterationNode) n);
+		}
+	}
+
+	/**
+	 * Recursively builds the step function of a bulk iteration and wires its partial solution,
+	 * next partial solution and optional termination criterion into the iteration node.
+	 *
+	 * @throws CompilerException if the step function's result does not depend on the partial solution
+	 */
+	private void translateBulkIteration(BulkIterationNode iterNode) {
+		final BulkIterationBase<?> iter = iterNode.getIterationContract();
+
+		// pass a copy of the non-iterative part into the iteration translation,
+		// in case the iteration references its closure
+		Map<Operator<?>, OptimizerNode> closure = new HashMap<Operator<?>, OptimizerNode>(this.con2node);
+
+		// first, recursively build the data flow for the step function
+		final GraphCreatingVisitor recursiveCreator = new GraphCreatingVisitor(this, true,
+			iterNode.getParallelism(), this.defaultDataExchangeMode, closure);
+
+		iter.getNextPartialSolution().accept(recursiveCreator);
+
+		BulkPartialSolutionNode partialSolution =
+				(BulkPartialSolutionNode) recursiveCreator.con2node.get(iter.getPartialSolution());
+		OptimizerNode rootOfStepFunction = recursiveCreator.con2node.get(iter.getNextPartialSolution());
+		if (partialSolution == null) {
+			throw new CompilerException("Error: The step functions result does not depend on the partial solution.");
+		}
+
+		OptimizerNode terminationCriterion = null;
+
+		if (iter.getTerminationCriterion() != null) {
+			terminationCriterion = recursiveCreator.con2node.get(iter.getTerminationCriterion());
+
+			// no intermediate node yet, traverse from the termination criterion to build the missing parts
+			if (terminationCriterion == null) {
+				iter.getTerminationCriterion().accept(recursiveCreator);
+				terminationCriterion = recursiveCreator.con2node.get(iter.getTerminationCriterion());
+			}
+		}
+
+		iterNode.setPartialSolution(partialSolution);
+		iterNode.setNextPartialSolution(rootOfStepFunction, terminationCriterion);
+
+		// go over the contained data flow and mark the dynamic path nodes
+		StaticDynamicPathIdentifier identifier = new StaticDynamicPathIdentifier(iterNode.getCostWeight());
+		iterNode.acceptForStepFunction(identifier);
+	}
+
+	/**
+	 * Recursively builds the step function of a delta (workset) iteration and wires the solution set,
+	 * workset, next workset and solution set delta into the iteration node.
+	 *
+	 * @throws CompilerException if the next workset or the solution set delta does not depend on the workset
+	 * @throws InvalidProgramException if the solution set is consumed by something other than Join or CoGroup
+	 */
+	private void translateWorksetIteration(WorksetIterationNode iterNode) {
+		final DeltaIterationBase<?, ?> iter = iterNode.getIterationContract();
+
+		// we need to ensure that both the next-workset and the solution-set-delta depend on the workset.
+		// One check is for free during the translation, we do the other check here as a pre-condition
+		StepFunctionValidator worksetValidator = new StepFunctionValidator();
+		iter.getNextWorkset().accept(worksetValidator);
+		if (!worksetValidator.hasFoundWorkset()) {
+			throw new CompilerException("In the given program, the next workset does not depend on the workset. " +
+				"This is a prerequisite in delta iterations.");
+		}
+
+		// calculate the closure of the anonymous function
+		Map<Operator<?>, OptimizerNode> closure = new HashMap<Operator<?>, OptimizerNode>(this.con2node);
+
+		// first, recursively build the data flow for the step function
+		final GraphCreatingVisitor recursiveCreator = new GraphCreatingVisitor(
+			this, true, iterNode.getParallelism(), this.defaultDataExchangeMode, closure);
+
+		// descend from the solution set delta. check that it depends on both the workset
+		// and the solution set. If it does depend on both, this descend should create both nodes
+		iter.getSolutionSetDelta().accept(recursiveCreator);
+
+		final WorksetNode worksetNode = (WorksetNode) recursiveCreator.con2node.get(iter.getWorkset());
+
+		if (worksetNode == null) {
+			// the trailing space keeps the two sentences of the message apart
+			throw new CompilerException("In the given program, the solution set delta does not depend on the workset. " +
+				"This is a prerequisite in delta iterations.");
+		}
+
+		iter.getNextWorkset().accept(recursiveCreator);
+
+		SolutionSetNode solutionSetNode = (SolutionSetNode) recursiveCreator.con2node.get(iter.getSolutionSet());
+
+		if (solutionSetNode == null || solutionSetNode.getOutgoingConnections() == null
+				|| solutionSetNode.getOutgoingConnections().isEmpty()) {
+			// the step function does not consume the solution set; create a detached node for it
+			solutionSetNode = new SolutionSetNode((DeltaIterationBase.SolutionSetPlaceHolder<?>) iter.getSolutionSet(), iterNode);
+		}
+		else {
+			markSolutionSetConsumers(solutionSetNode);
+		}
+
+		final OptimizerNode nextWorksetNode = recursiveCreator.con2node.get(iter.getNextWorkset());
+		final OptimizerNode solutionSetDeltaNode = recursiveCreator.con2node.get(iter.getSolutionSetDelta());
+
+		// set the step function nodes to the iteration node
+		iterNode.setPartialSolution(solutionSetNode, worksetNode);
+		iterNode.setNextPartialSolution(solutionSetDeltaNode, nextWorksetNode, this.defaultDataExchangeMode);
+
+		// go over the contained data flow and mark the dynamic path nodes
+		StaticDynamicPathIdentifier pathIdentifier = new StaticDynamicPathIdentifier(iterNode.getCostWeight());
+		iterNode.acceptForStepFunction(pathIdentifier);
+	}
+
+	/**
+	 * Tells every Join or CoGroup consuming the solution set which of its inputs (0 or 1) is the solution set.
+	 *
+	 * @throws CompilerException if a consumer is not connected to the solution set on either input
+	 * @throws InvalidProgramException if a consumer is neither a Join nor a CoGroup
+	 */
+	private static void markSolutionSetConsumers(SolutionSetNode solutionSetNode) {
+		for (DagConnection conn : solutionSetNode.getOutgoingConnections()) {
+			OptimizerNode successor = conn.getTarget();
+
+			if (successor.getClass() == JoinNode.class) {
+				// find out which input to the join the solution set is
+				JoinNode mn = (JoinNode) successor;
+				if (mn.getFirstPredecessorNode() == solutionSetNode) {
+					mn.makeJoinWithSolutionSet(0);
+				} else if (mn.getSecondPredecessorNode() == solutionSetNode) {
+					mn.makeJoinWithSolutionSet(1);
+				} else {
+					throw new CompilerException();
+				}
+			}
+			else if (successor.getClass() == CoGroupNode.class) {
+				CoGroupNode cg = (CoGroupNode) successor;
+				if (cg.getFirstPredecessorNode() == solutionSetNode) {
+					cg.makeCoGroupWithSolutionSet(0);
+				} else if (cg.getSecondPredecessorNode() == solutionSetNode) {
+					cg.makeCoGroupWithSolutionSet(1);
+				} else {
+					throw new CompilerException();
+				}
+			}
+			else {
+				throw new InvalidProgramException(
+					"Error: The only operations allowed on the solution set are Join and CoGroup.");
+			}
+		}
+	}
+}
http://git-wip-us.apache.org/repos/asf/flink/blob/633b0d6a/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/IdAndEstimatesVisitor.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/IdAndEstimatesVisitor.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/IdAndEstimatesVisitor.java
new file mode 100644
index 0000000..b5c09e5
--- /dev/null
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/traversals/IdAndEstimatesVisitor.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.optimizer.traversals;
+
+import org.apache.flink.optimizer.DataStatistics;
+import org.apache.flink.optimizer.dag.DagConnection;
+import org.apache.flink.optimizer.dag.IterationNode;
+import org.apache.flink.optimizer.dag.OptimizerNode;
+import org.apache.flink.util.Visitor;
+
+/**
+ * This traversal of the optimizer DAG assigns IDs to each node and calls each node to compute its
+ * estimates. Both happen in the postVisit function, so IDs are handed out in post-order (not
+ * pre-order), starting at 1. Computing the estimates in postVisit ensures that all predecessors
+ * have already computed their estimates. The step function of an iteration node is traversed
+ * right after the iteration node itself, so its nodes receive higher IDs than the iteration node.
+ */
+public class IdAndEstimatesVisitor implements Visitor<OptimizerNode> {
+
+	/** Statistics passed to every node's {@code computeOutputEstimates}. */
+	private final DataStatistics statistics;
+
+	/** The ID assigned to the next node reaching postVisit. */
+	private int id = 1;
+
+	/**
+	 * @param statistics the data statistics handed to each node when it computes its estimates
+	 */
+	public IdAndEstimatesVisitor(DataStatistics statistics) {
+		this.statistics = statistics;
+	}
+
+	/** Descends only into nodes that have not received an ID yet (their ID is still -1). */
+	@Override
+	public boolean preVisit(OptimizerNode visitable) {
+		return visitable.getId() == -1;
+	}
+
+	/**
+	 * Assigns the next ID, initializes the maximum path depth of the incoming and broadcast
+	 * connections, computes the node's output estimates, and finally recurses into the
+	 * step function if the node is an iteration.
+	 */
+	@Override
+	public void postVisit(OptimizerNode visitable) {
+		// the node ids (assigned here, after descending, hence in post-order)
+		visitable.initId(this.id++);
+
+		// connections need to figure out their maximum path depths
+		for (DagConnection conn : visitable.getIncomingConnections()) {
+			conn.initMaxDepth();
+		}
+		for (DagConnection conn : visitable.getBroadcastConnections()) {
+			conn.initMaxDepth();
+		}
+
+		// the estimates
+		visitable.computeOutputEstimates(this.statistics);
+
+		// if required, recurse into the step function
+		if (visitable instanceof IterationNode) {
+			((IterationNode) visitable).acceptForStepFunction(this);
+		}
+	}
+}