http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/edff0785/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkBatchTransformTranslators.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkBatchTransformTranslators.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkBatchTransformTranslators.java new file mode 100644 index 0000000..8f64730 --- /dev/null +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkBatchTransformTranslators.java @@ -0,0 +1,593 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.dataartisans.flink.dataflow.translation; + +import com.dataartisans.flink.dataflow.io.ConsoleIO; +import com.dataartisans.flink.dataflow.translation.functions.FlinkCoGroupKeyedListAggregator; +import com.dataartisans.flink.dataflow.translation.functions.FlinkCreateFunction; +import com.dataartisans.flink.dataflow.translation.functions.FlinkDoFnFunction; +import com.dataartisans.flink.dataflow.translation.functions.FlinkKeyedListAggregationFunction; +import com.dataartisans.flink.dataflow.translation.functions.FlinkMultiOutputDoFnFunction; +import com.dataartisans.flink.dataflow.translation.functions.FlinkMultiOutputPruningFunction; +import com.dataartisans.flink.dataflow.translation.functions.FlinkPartialReduceFunction; +import com.dataartisans.flink.dataflow.translation.functions.FlinkReduceFunction; +import com.dataartisans.flink.dataflow.translation.functions.UnionCoder; +import com.dataartisans.flink.dataflow.translation.types.CoderTypeInformation; +import com.dataartisans.flink.dataflow.translation.types.KvCoderTypeInformation; +import com.dataartisans.flink.dataflow.translation.wrappers.SinkOutputFormat; +import com.dataartisans.flink.dataflow.translation.wrappers.SourceInputFormat; +import com.google.api.client.util.Maps; +import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.io.AvroIO; +import com.google.cloud.dataflow.sdk.io.BoundedSource; +import com.google.cloud.dataflow.sdk.io.Read; +import com.google.cloud.dataflow.sdk.io.TextIO; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.Create; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.Flatten; +import com.google.cloud.dataflow.sdk.transforms.GroupByKey; +import com.google.cloud.dataflow.sdk.transforms.PTransform; 
+import com.google.cloud.dataflow.sdk.transforms.ParDo; +import com.google.cloud.dataflow.sdk.transforms.View; +import com.google.cloud.dataflow.sdk.transforms.Write; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; +import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResultSchema; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; +import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.cloud.dataflow.sdk.values.PCollection; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.PValue; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.Lists; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.DataSet; +import org.apache.flink.api.java.io.AvroInputFormat; +import org.apache.flink.api.java.io.AvroOutputFormat; +import org.apache.flink.api.java.io.TextInputFormat; +import org.apache.flink.api.java.operators.CoGroupOperator; +import org.apache.flink.api.java.operators.DataSink; +import org.apache.flink.api.java.operators.DataSource; +import org.apache.flink.api.java.operators.FlatMapOperator; +import org.apache.flink.api.java.operators.GroupCombineOperator; +import org.apache.flink.api.java.operators.GroupReduceOperator; +import org.apache.flink.api.java.operators.Grouping; +import org.apache.flink.api.java.operators.Keys; +import org.apache.flink.api.java.operators.MapPartitionOperator; +import org.apache.flink.api.java.operators.UnsortedGrouping; +import org.apache.flink.core.fs.Path; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.lang.reflect.Field; +import java.util.HashMap; 
+import java.util.List; +import java.util.Map; + +/** + * Translators for transforming + * Dataflow {@link com.google.cloud.dataflow.sdk.transforms.PTransform}s to + * Flink {@link org.apache.flink.api.java.DataSet}s + */ +public class FlinkBatchTransformTranslators { + + // -------------------------------------------------------------------------------------------- + // Transform Translator Registry + // -------------------------------------------------------------------------------------------- + + @SuppressWarnings("rawtypes") + private static final Map<Class<? extends PTransform>, FlinkBatchPipelineTranslator.BatchTransformTranslator> TRANSLATORS = new HashMap<>(); + + // register the known translators + static { + TRANSLATORS.put(View.CreatePCollectionView.class, new CreatePCollectionViewTranslatorBatch()); + + TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslatorBatch()); + // we don't need this because we translate the Combine.PerKey directly + //TRANSLATORS.put(Combine.GroupedValues.class, new CombineGroupedValuesTranslator()); + + TRANSLATORS.put(Create.Values.class, new CreateTranslatorBatch()); + + TRANSLATORS.put(Flatten.FlattenPCollectionList.class, new FlattenPCollectionTranslatorBatch()); + + TRANSLATORS.put(GroupByKey.GroupByKeyOnly.class, new GroupByKeyOnlyTranslatorBatch()); + // TODO we're currently ignoring windows here but that has to change in the future + TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslatorBatch()); + + TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoBoundMultiTranslatorBatch()); + TRANSLATORS.put(ParDo.Bound.class, new ParDoBoundTranslatorBatch()); + + TRANSLATORS.put(CoGroupByKey.class, new CoGroupByKeyTranslatorBatch()); + + TRANSLATORS.put(AvroIO.Read.Bound.class, new AvroIOReadTranslatorBatch()); + TRANSLATORS.put(AvroIO.Write.Bound.class, new AvroIOWriteTranslatorBatch()); + + TRANSLATORS.put(Read.Bounded.class, new ReadSourceTranslatorBatch()); + TRANSLATORS.put(Write.Bound.class, new 
WriteSinkTranslatorBatch()); + + TRANSLATORS.put(TextIO.Read.Bound.class, new TextIOReadTranslatorBatch()); + TRANSLATORS.put(TextIO.Write.Bound.class, new TextIOWriteTranslatorBatch()); + + // Flink-specific + TRANSLATORS.put(ConsoleIO.Write.Bound.class, new ConsoleIOWriteTranslatorBatch()); + + } + + + public static FlinkBatchPipelineTranslator.BatchTransformTranslator<?> getTranslator(PTransform<?, ?> transform) { + return TRANSLATORS.get(transform.getClass()); + } + + private static class ReadSourceTranslatorBatch<T> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<Read.Bounded<T>> { + + @Override + public void translateNode(Read.Bounded<T> transform, FlinkBatchTranslationContext context) { + String name = transform.getName(); + BoundedSource<T> source = transform.getSource(); + PCollection<T> output = context.getOutput(transform); + Coder<T> coder = output.getCoder(); + + TypeInformation<T> typeInformation = context.getTypeInfo(output); + + DataSource<T> dataSource = new DataSource<>(context.getExecutionEnvironment(), new SourceInputFormat<>(source, context.getPipelineOptions(), coder), typeInformation, name); + + context.setOutputDataSet(output, dataSource); + } + } + + private static class AvroIOReadTranslatorBatch<T> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<AvroIO.Read.Bound<T>> { + private static final Logger LOG = LoggerFactory.getLogger(AvroIOReadTranslatorBatch.class); + + @Override + public void translateNode(AvroIO.Read.Bound<T> transform, FlinkBatchTranslationContext context) { + String path = transform.getFilepattern(); + String name = transform.getName(); +// Schema schema = transform.getSchema(); + PValue output = context.getOutput(transform); + + TypeInformation<T> typeInformation = context.getTypeInfo(output); + + // This is super hacky, but unfortunately we cannot get the type otherwise + Class<T> extractedAvroType; + try { + Field typeField = transform.getClass().getDeclaredField("type"); + 
typeField.setAccessible(true); + @SuppressWarnings("unchecked") + Class<T> avroType = (Class<T>) typeField.get(transform); + extractedAvroType = avroType; + } catch (NoSuchFieldException | IllegalAccessException e) { + // we know that the field is there and it is accessible + throw new RuntimeException("Could not access type from AvroIO.Bound", e); + } + + DataSource<T> source = new DataSource<>(context.getExecutionEnvironment(), + new AvroInputFormat<>(new Path(path), extractedAvroType), + typeInformation, name); + + context.setOutputDataSet(output, source); + } + } + + private static class AvroIOWriteTranslatorBatch<T> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<AvroIO.Write.Bound<T>> { + private static final Logger LOG = LoggerFactory.getLogger(AvroIOWriteTranslatorBatch.class); + + @Override + public void translateNode(AvroIO.Write.Bound<T> transform, FlinkBatchTranslationContext context) { + DataSet<T> inputDataSet = context.getInputDataSet(context.getInput(transform)); + String filenamePrefix = transform.getFilenamePrefix(); + String filenameSuffix = transform.getFilenameSuffix(); + int numShards = transform.getNumShards(); + String shardNameTemplate = transform.getShardNameTemplate(); + + // TODO: Implement these. We need Flink support for this. + LOG.warn("Translation of TextIO.Write.filenameSuffix not yet supported. Is: {}.", + filenameSuffix); + LOG.warn("Translation of TextIO.Write.shardNameTemplate not yet supported. 
Is: {}.", shardNameTemplate); + + // This is super hacky, but unfortunately we cannot get the type otherwise + Class<T> extractedAvroType; + try { + Field typeField = transform.getClass().getDeclaredField("type"); + typeField.setAccessible(true); + @SuppressWarnings("unchecked") + Class<T> avroType = (Class<T>) typeField.get(transform); + extractedAvroType = avroType; + } catch (NoSuchFieldException | IllegalAccessException e) { + // we know that the field is there and it is accessible + throw new RuntimeException("Could not access type from AvroIO.Bound", e); + } + + DataSink<T> dataSink = inputDataSet.output(new AvroOutputFormat<>(new Path + (filenamePrefix), extractedAvroType)); + + if (numShards > 0) { + dataSink.setParallelism(numShards); + } + } + } + + private static class TextIOReadTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator<TextIO.Read.Bound<String>> { + private static final Logger LOG = LoggerFactory.getLogger(TextIOReadTranslatorBatch.class); + + @Override + public void translateNode(TextIO.Read.Bound<String> transform, FlinkBatchTranslationContext context) { + String path = transform.getFilepattern(); + String name = transform.getName(); + + TextIO.CompressionType compressionType = transform.getCompressionType(); + boolean needsValidation = transform.needsValidation(); + + // TODO: Implement these. We need Flink support for this. + LOG.warn("Translation of TextIO.CompressionType not yet supported. Is: {}.", compressionType); + LOG.warn("Translation of TextIO.Read.needsValidation not yet supported. 
Is: {}.", needsValidation); + + PValue output = context.getOutput(transform); + + TypeInformation<String> typeInformation = context.getTypeInfo(output); + DataSource<String> source = new DataSource<>(context.getExecutionEnvironment(), new TextInputFormat(new Path(path)), typeInformation, name); + + context.setOutputDataSet(output, source); + } + } + + private static class TextIOWriteTranslatorBatch<T> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<TextIO.Write.Bound<T>> { + private static final Logger LOG = LoggerFactory.getLogger(TextIOWriteTranslatorBatch.class); + + @Override + public void translateNode(TextIO.Write.Bound<T> transform, FlinkBatchTranslationContext context) { + PValue input = context.getInput(transform); + DataSet<T> inputDataSet = context.getInputDataSet(input); + + String filenamePrefix = transform.getFilenamePrefix(); + String filenameSuffix = transform.getFilenameSuffix(); + boolean needsValidation = transform.needsValidation(); + int numShards = transform.getNumShards(); + String shardNameTemplate = transform.getShardNameTemplate(); + + // TODO: Implement these. We need Flink support for this. + LOG.warn("Translation of TextIO.Write.needsValidation not yet supported. Is: {}.", needsValidation); + LOG.warn("Translation of TextIO.Write.filenameSuffix not yet supported. Is: {}.", filenameSuffix); + LOG.warn("Translation of TextIO.Write.shardNameTemplate not yet supported. 
Is: {}.", shardNameTemplate); + + //inputDataSet.print(); + DataSink<T> dataSink = inputDataSet.writeAsText(filenamePrefix); + + if (numShards > 0) { + dataSink.setParallelism(numShards); + } + } + } + + private static class ConsoleIOWriteTranslatorBatch implements FlinkBatchPipelineTranslator.BatchTransformTranslator<ConsoleIO.Write.Bound> { + @Override + public void translateNode(ConsoleIO.Write.Bound transform, FlinkBatchTranslationContext context) { + PValue input = (PValue) context.getInput(transform); + DataSet<?> inputDataSet = context.getInputDataSet(input); + inputDataSet.printOnTaskManager(transform.getName()); + } + } + + private static class WriteSinkTranslatorBatch<T> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<Write.Bound<T>> { + + @Override + public void translateNode(Write.Bound<T> transform, FlinkBatchTranslationContext context) { + String name = transform.getName(); + PValue input = context.getInput(transform); + DataSet<T> inputDataSet = context.getInputDataSet(input); + + inputDataSet.output(new SinkOutputFormat<>(transform, context.getPipelineOptions())).name(name); + } + } + + private static class GroupByKeyOnlyTranslatorBatch<K, V> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<GroupByKey.GroupByKeyOnly<K, V>> { + + @Override + public void translateNode(GroupByKey.GroupByKeyOnly<K, V> transform, FlinkBatchTranslationContext context) { + DataSet<KV<K, V>> inputDataSet = context.getInputDataSet(context.getInput(transform)); + GroupReduceFunction<KV<K, V>, KV<K, Iterable<V>>> groupReduceFunction = new FlinkKeyedListAggregationFunction<>(); + + TypeInformation<KV<K, Iterable<V>>> typeInformation = context.getTypeInfo(context.getOutput(transform)); + + Grouping<KV<K, V>> grouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet.getType())); + + GroupReduceOperator<KV<K, V>, KV<K, Iterable<V>>> outputDataSet = + new GroupReduceOperator<>(grouping, 
typeInformation, groupReduceFunction, transform.getName()); + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + } + + /** + * Translates a GroupByKey while ignoring window assignments. This is identical to the {@link GroupByKeyOnlyTranslatorBatch} + */ + private static class GroupByKeyTranslatorBatch<K, V> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<GroupByKey<K, V>> { + + @Override + public void translateNode(GroupByKey<K, V> transform, FlinkBatchTranslationContext context) { + DataSet<KV<K, V>> inputDataSet = context.getInputDataSet(context.getInput(transform)); + GroupReduceFunction<KV<K, V>, KV<K, Iterable<V>>> groupReduceFunction = new FlinkKeyedListAggregationFunction<>(); + + TypeInformation<KV<K, Iterable<V>>> typeInformation = context.getTypeInfo(context.getOutput(transform)); + + Grouping<KV<K, V>> grouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet.getType())); + + GroupReduceOperator<KV<K, V>, KV<K, Iterable<V>>> outputDataSet = + new GroupReduceOperator<>(grouping, typeInformation, groupReduceFunction, transform.getName()); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + } + + private static class CombinePerKeyTranslatorBatch<K, VI, VA, VO> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<Combine.PerKey<K, VI, VO>> { + + @Override + public void translateNode(Combine.PerKey<K, VI, VO> transform, FlinkBatchTranslationContext context) { + DataSet<KV<K, VI>> inputDataSet = context.getInputDataSet(context.getInput(transform)); + + @SuppressWarnings("unchecked") + Combine.KeyedCombineFn<K, VI, VA, VO> keyedCombineFn = (Combine.KeyedCombineFn<K, VI, VA, VO>) transform.getFn(); + + KvCoder<K, VI> inputCoder = (KvCoder<K, VI>) context.getInput(transform).getCoder(); + + Coder<VA> accumulatorCoder = + null; + try { + accumulatorCoder = 
keyedCombineFn.getAccumulatorCoder(context.getInput(transform).getPipeline().getCoderRegistry(), inputCoder.getKeyCoder(), inputCoder.getValueCoder()); + } catch (CannotProvideCoderException e) { + e.printStackTrace(); + // TODO + } + + TypeInformation<KV<K, VI>> kvCoderTypeInformation = new KvCoderTypeInformation<>(inputCoder); + TypeInformation<KV<K, VA>> partialReduceTypeInfo = new KvCoderTypeInformation<>(KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder)); + + Grouping<KV<K, VI>> inputGrouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{"key"}, kvCoderTypeInformation)); + + FlinkPartialReduceFunction<K, VI, VA> partialReduceFunction = new FlinkPartialReduceFunction<>(keyedCombineFn); + + // Partially GroupReduce the values into the intermediate format VA (combine) + GroupCombineOperator<KV<K, VI>, KV<K, VA>> groupCombine = + new GroupCombineOperator<>(inputGrouping, partialReduceTypeInfo, partialReduceFunction, + "GroupCombine: " + transform.getName()); + + // Reduce fully to VO + GroupReduceFunction<KV<K, VA>, KV<K, VO>> reduceFunction = new FlinkReduceFunction<>(keyedCombineFn); + + TypeInformation<KV<K, VO>> reduceTypeInfo = context.getTypeInfo(context.getOutput(transform)); + + Grouping<KV<K, VA>> intermediateGrouping = new UnsortedGrouping<>(groupCombine, new Keys.ExpressionKeys<>(new String[]{"key"}, groupCombine.getType())); + + // Fully reduce the values and create output format VO + GroupReduceOperator<KV<K, VA>, KV<K, VO>> outputDataSet = + new GroupReduceOperator<>(intermediateGrouping, reduceTypeInfo, reduceFunction, transform.getName()); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + } + +// private static class CombineGroupedValuesTranslator<K, VI, VO> implements FlinkPipelineTranslator.TransformTranslator<Combine.GroupedValues<K, VI, VO>> { +// +// @Override +// public void translateNode(Combine.GroupedValues<K, VI, VO> transform, TranslationContext context) { +// 
DataSet<KV<K, VI>> inputDataSet = context.getInputDataSet(transform.getInput()); +// +// Combine.KeyedCombineFn<? super K, ? super VI, ?, VO> keyedCombineFn = transform.getFn(); +// +// GroupReduceFunction<KV<K, VI>, KV<K, VO>> groupReduceFunction = new FlinkCombineFunction<>(keyedCombineFn); +// +// TypeInformation<KV<K, VO>> typeInformation = context.getTypeInfo(transform.getOutput()); +// +// Grouping<KV<K, VI>> grouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{""}, inputDataSet.getType())); +// +// GroupReduceOperator<KV<K, VI>, KV<K, VO>> outputDataSet = +// new GroupReduceOperator<>(grouping, typeInformation, groupReduceFunction, transform.getName()); +// context.setOutputDataSet(transform.getOutput(), outputDataSet); +// } +// } + + private static class ParDoBoundTranslatorBatch<IN, OUT> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<ParDo.Bound<IN, OUT>> { + private static final Logger LOG = LoggerFactory.getLogger(ParDoBoundTranslatorBatch.class); + + @Override + public void translateNode(ParDo.Bound<IN, OUT> transform, FlinkBatchTranslationContext context) { + DataSet<IN> inputDataSet = context.getInputDataSet(context.getInput(transform)); + + final DoFn<IN, OUT> doFn = transform.getFn(); + + TypeInformation<OUT> typeInformation = context.getTypeInfo(context.getOutput(transform)); + + FlinkDoFnFunction<IN, OUT> doFnWrapper = new FlinkDoFnFunction<>(doFn, context.getPipelineOptions()); + MapPartitionOperator<IN, OUT> outputDataSet = new MapPartitionOperator<>(inputDataSet, typeInformation, doFnWrapper, transform.getName()); + + transformSideInputs(transform.getSideInputs(), outputDataSet, context); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + } + + private static class ParDoBoundMultiTranslatorBatch<IN, OUT> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<ParDo.BoundMulti<IN, OUT>> { + private static final Logger LOG = 
LoggerFactory.getLogger(ParDoBoundMultiTranslatorBatch.class); + + @Override + public void translateNode(ParDo.BoundMulti<IN, OUT> transform, FlinkBatchTranslationContext context) { + DataSet<IN> inputDataSet = context.getInputDataSet(context.getInput(transform)); + + final DoFn<IN, OUT> doFn = transform.getFn(); + + Map<TupleTag<?>, PCollection<?>> outputs = context.getOutput(transform).getAll(); + + Map<TupleTag<?>, Integer> outputMap = Maps.newHashMap(); + // put the main output at index 0, FlinkMultiOutputDoFnFunction also expects this + outputMap.put(transform.getMainOutputTag(), 0); + int count = 1; + for (TupleTag<?> tag: outputs.keySet()) { + if (!outputMap.containsKey(tag)) { + outputMap.put(tag, count++); + } + } + + // collect all output Coders and create a UnionCoder for our tagged outputs + List<Coder<?>> outputCoders = Lists.newArrayList(); + for (PCollection<?> coll: outputs.values()) { + outputCoders.add(coll.getCoder()); + } + + UnionCoder unionCoder = UnionCoder.of(outputCoders); + + @SuppressWarnings("unchecked") + TypeInformation<RawUnionValue> typeInformation = new CoderTypeInformation<>(unionCoder); + + @SuppressWarnings("unchecked") + FlinkMultiOutputDoFnFunction<IN, OUT> doFnWrapper = new FlinkMultiOutputDoFnFunction(doFn, context.getPipelineOptions(), outputMap); + MapPartitionOperator<IN, RawUnionValue> outputDataSet = new MapPartitionOperator<>(inputDataSet, typeInformation, doFnWrapper, transform.getName()); + + transformSideInputs(transform.getSideInputs(), outputDataSet, context); + + for (Map.Entry<TupleTag<?>, PCollection<?>> output: outputs.entrySet()) { + TypeInformation<Object> outputType = context.getTypeInfo(output.getValue()); + int outputTag = outputMap.get(output.getKey()); + FlinkMultiOutputPruningFunction<Object> pruningFunction = new FlinkMultiOutputPruningFunction<>(outputTag); + FlatMapOperator<RawUnionValue, Object> pruningOperator = new + FlatMapOperator<>(outputDataSet, outputType, + pruningFunction, 
output.getValue().getName()); + context.setOutputDataSet(output.getValue(), pruningOperator); + + } + } + } + + private static class FlattenPCollectionTranslatorBatch<T> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<Flatten.FlattenPCollectionList<T>> { + + @Override + public void translateNode(Flatten.FlattenPCollectionList<T> transform, FlinkBatchTranslationContext context) { + List<PCollection<T>> allInputs = context.getInput(transform).getAll(); + DataSet<T> result = null; + for(PCollection<T> collection : allInputs) { + DataSet<T> current = context.getInputDataSet(collection); + if (result == null) { + result = current; + } else { + result = result.union(current); + } + } + context.setOutputDataSet(context.getOutput(transform), result); + } + } + + private static class CreatePCollectionViewTranslatorBatch<R, T> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<View.CreatePCollectionView<R, T>> { + @Override + public void translateNode(View.CreatePCollectionView<R, T> transform, FlinkBatchTranslationContext context) { + DataSet<T> inputDataSet = context.getInputDataSet(context.getInput(transform)); + PCollectionView<T> input = transform.apply(null); + context.setSideInputDataSet(input, inputDataSet); + } + } + + private static class CreateTranslatorBatch<OUT> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<Create.Values<OUT>> { + + @Override + public void translateNode(Create.Values<OUT> transform, FlinkBatchTranslationContext context) { + TypeInformation<OUT> typeInformation = context.getOutputTypeInfo(); + Iterable<OUT> elements = transform.getElements(); + + // we need to serialize the elements to byte arrays, since they might contain + // elements that are not serializable by Java serialization. We deserialize them + // in the FlatMap function using the Coder. 
+ + List<byte[]> serializedElements = Lists.newArrayList(); + Coder<OUT> coder = context.getOutput(transform).getCoder(); + for (OUT element: elements) { + ByteArrayOutputStream bao = new ByteArrayOutputStream(); + try { + coder.encode(element, bao, Coder.Context.OUTER); + serializedElements.add(bao.toByteArray()); + } catch (IOException e) { + throw new RuntimeException("Could not serialize Create elements using Coder: " + e); + } + } + + DataSet<Integer> initDataSet = context.getExecutionEnvironment().fromElements(1); + FlinkCreateFunction<Integer, OUT> flatMapFunction = new FlinkCreateFunction<>(serializedElements, coder); + FlatMapOperator<Integer, OUT> outputDataSet = new FlatMapOperator<>(initDataSet, typeInformation, flatMapFunction, transform.getName()); + + context.setOutputDataSet(context.getOutput(transform), outputDataSet); + } + } + + private static void transformSideInputs(List<PCollectionView<?>> sideInputs, + MapPartitionOperator<?, ?> outputDataSet, + FlinkBatchTranslationContext context) { + // get corresponding Flink broadcast DataSets + for(PCollectionView<?> input : sideInputs) { + DataSet<?> broadcastSet = context.getSideInputDataSet(input); + outputDataSet.withBroadcastSet(broadcastSet, input.getTagInternal().getId()); + } + } + +// Disabled because it depends on a pending pull request to the DataFlowSDK + /** + * Special composite transform translator. Only called if the CoGroup is two dimensional. 
+ * @param <K> + */ + private static class CoGroupByKeyTranslatorBatch<K, V1, V2> implements FlinkBatchPipelineTranslator.BatchTransformTranslator<CoGroupByKey<K>> { + + @Override + public void translateNode(CoGroupByKey<K> transform, FlinkBatchTranslationContext context) { + KeyedPCollectionTuple<K> input = context.getInput(transform); + + CoGbkResultSchema schema = input.getCoGbkResultSchema(); + List<KeyedPCollectionTuple.TaggedKeyedPCollection<K, ?>> keyedCollections = input.getKeyedCollections(); + + KeyedPCollectionTuple.TaggedKeyedPCollection<K, ?> taggedCollection1 = keyedCollections.get(0); + KeyedPCollectionTuple.TaggedKeyedPCollection<K, ?> taggedCollection2 = keyedCollections.get(1); + + TupleTag<?> tupleTag1 = taggedCollection1.getTupleTag(); + TupleTag<?> tupleTag2 = taggedCollection2.getTupleTag(); + + PCollection<? extends KV<K, ?>> collection1 = taggedCollection1.getCollection(); + PCollection<? extends KV<K, ?>> collection2 = taggedCollection2.getCollection(); + + DataSet<KV<K,V1>> inputDataSet1 = context.getInputDataSet(collection1); + DataSet<KV<K,V2>> inputDataSet2 = context.getInputDataSet(collection2); + + TypeInformation<KV<K,CoGbkResult>> typeInfo = context.getOutputTypeInfo(); + + FlinkCoGroupKeyedListAggregator<K,V1,V2> aggregator = new FlinkCoGroupKeyedListAggregator<>(schema, tupleTag1, tupleTag2); + + Keys.ExpressionKeys<KV<K,V1>> keySelector1 = new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet1.getType()); + Keys.ExpressionKeys<KV<K,V2>> keySelector2 = new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet2.getType()); + + DataSet<KV<K, CoGbkResult>> out = new CoGroupOperator<>(inputDataSet1, inputDataSet2, + keySelector1, keySelector2, + aggregator, typeInfo, null, transform.getName()); + context.setOutputDataSet(context.getOutput(transform), out); + } + } + + // -------------------------------------------------------------------------------------------- + // Miscellaneous + // 
-------------------------------------------------------------------------------------------- + + private FlinkBatchTransformTranslators() {} +}
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/edff0785/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkBatchTranslationContext.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkBatchTranslationContext.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkBatchTranslationContext.java new file mode 100644 index 0000000..1072fa3 --- /dev/null +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkBatchTranslationContext.java @@ -0,0 +1,129 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package com.dataartisans.flink.dataflow.translation;

import com.dataartisans.flink.dataflow.translation.types.CoderTypeInformation;
import com.dataartisans.flink.dataflow.translation.types.KvCoderTypeInformation;
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.KvCoder;
import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.values.PCollectionView;
import com.google.cloud.dataflow.sdk.values.PInput;
import com.google.cloud.dataflow.sdk.values.POutput;
import com.google.cloud.dataflow.sdk.values.PValue;
import com.google.cloud.dataflow.sdk.values.TypedPValue;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.typeutils.GenericTypeInfo;

import java.util.HashMap;
import java.util.Map;

/**
 * Mutable state shared across the batch translation pass: it maps every translated
 * {@link PValue} to the Flink {@link DataSet} that produces it, keeps the broadcast
 * data sets backing side inputs, and tracks the transform currently being translated.
 */
public class FlinkBatchTranslationContext {

	private final ExecutionEnvironment env;
	private final PipelineOptions options;

	// DataSet that materializes each translated PValue; only the first producer
	// of a value is ever recorded.
	private final Map<PValue, DataSet<?>> dataSets;

	// Broadcast data sets that back PCollectionView side inputs.
	private final Map<PCollectionView<?>, DataSet<?>> broadcastDataSets;

	// The AppliedPTransform whose translation is currently in progress; it carries
	// the input/output values resolved by getInput()/getOutput().
	private AppliedPTransform<?, ?, ?> currentTransform;

	// ------------------------------------------------------------------------

	public FlinkBatchTranslationContext(ExecutionEnvironment env, PipelineOptions options) {
		this.env = env;
		this.options = options;
		this.dataSets = new HashMap<>();
		this.broadcastDataSets = new HashMap<>();
	}

	// ------------------------------------------------------------------------

	public ExecutionEnvironment getExecutionEnvironment() {
		return env;
	}

	public PipelineOptions getPipelineOptions() {
		return options;
	}

	/** Returns the Flink DataSet previously registered for {@code value}, cast to the caller's element type. */
	@SuppressWarnings("unchecked")
	public <T> DataSet<T> getInputDataSet(PValue value) {
		return (DataSet<T>) dataSets.get(value);
	}

	/** Registers {@code set} as the producer of {@code value}; a later re-registration is ignored. */
	public void setOutputDataSet(PValue value, DataSet<?> set) {
		if (dataSets.containsKey(value)) {
			return;
		}
		dataSets.put(value, set);
	}

	/**
	 * Sets the AppliedPTransform which carries input/output.
	 * @param currentTransform the transform about to be translated
	 */
	public void setCurrentTransform(AppliedPTransform<?, ?, ?> currentTransform) {
		this.currentTransform = currentTransform;
	}

	/** Returns the broadcast DataSet registered for the given side-input view. */
	@SuppressWarnings("unchecked")
	public <T> DataSet<T> getSideInputDataSet(PCollectionView<?> value) {
		return (DataSet<T>) broadcastDataSets.get(value);
	}

	/** Registers the broadcast DataSet backing a side-input view; first registration wins. */
	public void setSideInputDataSet(PCollectionView<?> value, DataSet<?> set) {
		if (broadcastDataSets.containsKey(value)) {
			return;
		}
		broadcastDataSets.put(value, set);
	}

	/**
	 * Derives Flink type information from the value's Dataflow coder, falling back
	 * to a generic Object type when the value carries no coder.
	 */
	@SuppressWarnings("unchecked")
	public <T> TypeInformation<T> getTypeInfo(PInput output) {
		if (!(output instanceof TypedPValue)) {
			// No coder available; Flink will treat elements as generic objects.
			return new GenericTypeInfo<>((Class<T>) Object.class);
		}
		Coder<?> outputCoder = ((TypedPValue) output).getCoder();
		return (outputCoder instanceof KvCoder)
				? new KvCoderTypeInformation((KvCoder) outputCoder)
				: new CoderTypeInformation(outputCoder);
	}

	public <T> TypeInformation<T> getInputTypeInfo() {
		return getTypeInfo(currentTransform.getInput());
	}

	public <T> TypeInformation<T> getOutputTypeInfo() {
		return getTypeInfo((PValue) currentTransform.getOutput());
	}

	/** Input of the transform currently being translated, cast to the translator's expected type. */
	@SuppressWarnings("unchecked")
	<I extends PInput> I getInput(PTransform<I, ?> transform) {
		return (I) currentTransform.getInput();
	}

	/** Output of the transform currently being translated, cast to the translator's expected type. */
	@SuppressWarnings("unchecked")
	<O extends POutput> O getOutput(PTransform<?, O> transform) {
		return (O) currentTransform.getOutput();
	}
}
a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkPipelineTranslator.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkPipelineTranslator.java index 92b9135..e5c8545 100644 --- a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkPipelineTranslator.java +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkPipelineTranslator.java @@ -7,8 +7,6 @@ * * http://www.apache.org/licenses/LICENSE-2.0 * - * - * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -18,151 +16,10 @@ package com.dataartisans.flink.dataflow.translation; import com.google.cloud.dataflow.sdk.Pipeline; -import com.google.cloud.dataflow.sdk.Pipeline.PipelineVisitor; -import com.google.cloud.dataflow.sdk.options.PipelineOptions; -import com.google.cloud.dataflow.sdk.runners.TransformTreeNode; -import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; -import com.google.cloud.dataflow.sdk.transforms.PTransform; -import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; -import com.google.cloud.dataflow.sdk.values.PValue; -import org.apache.flink.api.java.ExecutionEnvironment; - -/** - * FlinkPipelineTranslator knows how to translate Pipeline objects into Flink Jobs. 
- * - * This is based on {@link com.google.cloud.dataflow.sdk.runners.DataflowPipelineTranslator} - */ -public class FlinkPipelineTranslator implements PipelineVisitor { - - private final TranslationContext context; - - private int depth = 0; - - /** - * Composite transform that we want to translate before proceeding with other transforms - */ - private PTransform<?, ?> currentCompositeTransform; - - public FlinkPipelineTranslator(ExecutionEnvironment env, PipelineOptions options) { - this.context = new TranslationContext(env, options); - } +public abstract class FlinkPipelineTranslator implements Pipeline.PipelineVisitor { public void translate(Pipeline pipeline) { pipeline.traverseTopologically(this); } - - - // -------------------------------------------------------------------------------------------- - // Pipeline Visitor Methods - // -------------------------------------------------------------------------------------------- - - private static String genSpaces(int n) { - String s = ""; - for(int i = 0; i < n; i++) { - s += "| "; - } - return s; - } - - private static String formatNodeName(TransformTreeNode node) { - return node.toString().split("@")[1] + node.getTransform(); - } - - @Override - public void enterCompositeTransform(TransformTreeNode node) { - System.out.println(genSpaces(this.depth) + "enterCompositeTransform- " + formatNodeName(node)); - PTransform<?, ?> transform = node.getTransform(); - - if (transform != null && currentCompositeTransform == null) { - TransformTranslator<?> translator = FlinkTransformTranslators.getTranslator(transform); - - if (translator != null) { - currentCompositeTransform = transform; - - if (transform instanceof CoGroupByKey && node.getInput().expand().size() != 2) { - // we can only optimize CoGroupByKey for input size 2 - currentCompositeTransform = null; - } - } - } - - this.depth++; - } - - @Override - public void leaveCompositeTransform(TransformTreeNode node) { - PTransform<?, ?> transform = node.getTransform(); 
- - if (transform != null) { - TransformTranslator<?> translator = FlinkTransformTranslators.getTranslator(transform); - - if (currentCompositeTransform == transform) { - if (translator != null) { - System.out.println(genSpaces(this.depth) + "doingCompositeTransform- " + formatNodeName(node)); - applyTransform(transform, node, translator); - currentCompositeTransform = null; - } else { - throw new IllegalStateException("Attempted to translate composite transform " + - "but no translator was found: " + currentCompositeTransform); - } - } - } - - this.depth--; - System.out.println(genSpaces(this.depth) + "leaveCompositeTransform- " + formatNodeName(node)); - } - - @Override - public void visitTransform(TransformTreeNode node) { - System.out.println(genSpaces(this.depth) + "visitTransform- " + formatNodeName(node)); - if (currentCompositeTransform != null) { - // ignore it - return; - } - - // the transformation applied in this node - PTransform<?, ?> transform = node.getTransform(); - - // the translator to the Flink operation(s) - TransformTranslator<?> translator = FlinkTransformTranslators.getTranslator(transform); - - if (translator == null) { - System.out.println(node.getTransform().getClass()); - throw new UnsupportedOperationException("The transform " + transform + " is currently not supported."); - } - - applyTransform(transform, node, translator); - } - - @Override - public void visitValue(PValue value, TransformTreeNode producer) { - // do nothing here - } - - /** - * Utility method to define a generic variable to cast the translator and the transform to. 
- */ - private <T extends PTransform<?, ?>> void applyTransform(PTransform<?, ?> transform, TransformTreeNode node, TransformTranslator<?> translator) { - - @SuppressWarnings("unchecked") - T typedTransform = (T) transform; - - @SuppressWarnings("unchecked") - TransformTranslator<T> typedTranslator = (TransformTranslator<T>) translator; - - // create the applied PTransform on the context - context.setCurrentTransform(AppliedPTransform.of( - node.getFullName(), node.getInput(), node.getOutput(), (PTransform) transform)); - - typedTranslator.translateNode(typedTransform, context); - } - - /** - * A translator of a {@link PTransform}. - */ - public interface TransformTranslator<Type extends PTransform> { - - void translateNode(Type transform, TranslationContext context); - } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/edff0785/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkStreamingPipelineTranslator.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkStreamingPipelineTranslator.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkStreamingPipelineTranslator.java new file mode 100644 index 0000000..c8760c7 --- /dev/null +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkStreamingPipelineTranslator.java @@ -0,0 +1,138 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.dataartisans.flink.dataflow.translation; + +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.runners.TransformTreeNode; +import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; +import com.google.cloud.dataflow.sdk.transforms.PTransform; +import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; +import com.google.cloud.dataflow.sdk.values.PValue; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; + +public class FlinkStreamingPipelineTranslator extends FlinkPipelineTranslator { + + /** The necessary context in the case of a straming job. */ + private final FlinkStreamingTranslationContext streamingContext; + + private int depth = 0; + + /** Composite transform that we want to translate before proceeding with other transforms. */ + private PTransform<?, ?> currentCompositeTransform; + + public FlinkStreamingPipelineTranslator(StreamExecutionEnvironment env, PipelineOptions options) { + this.streamingContext = new FlinkStreamingTranslationContext(env, options); + } + + // -------------------------------------------------------------------------------------------- + // Pipeline Visitor Methods + // -------------------------------------------------------------------------------------------- + + @Override + public void enterCompositeTransform(TransformTreeNode node) { + System.out.println(genSpaces(this.depth) + "enterCompositeTransform- " + formatNodeName(node)); + + PTransform<?, ?> transform = node.getTransform(); + if (transform != null && currentCompositeTransform == null) { + + StreamTransformTranslator<?> translator = FlinkStreamingTransformTranslators.getTranslator(transform); + if (translator != null) { + currentCompositeTransform = transform; + } + } + this.depth++; + } + + @Override + public void leaveCompositeTransform(TransformTreeNode 
node) { + PTransform<?, ?> transform = node.getTransform(); + if (transform != null && currentCompositeTransform == transform) { + + StreamTransformTranslator<?> translator = FlinkStreamingTransformTranslators.getTranslator(transform); + if (translator != null) { + System.out.println(genSpaces(this.depth) + "doingCompositeTransform- " + formatNodeName(node)); + applyStreamingTransform(transform, node, translator); + currentCompositeTransform = null; + } else { + throw new IllegalStateException("Attempted to translate composite transform " + + "but no translator was found: " + currentCompositeTransform); + } + } + this.depth--; + System.out.println(genSpaces(this.depth) + "leaveCompositeTransform- " + formatNodeName(node)); + } + + @Override + public void visitTransform(TransformTreeNode node) { + System.out.println(genSpaces(this.depth) + "visitTransform- " + formatNodeName(node)); + if (currentCompositeTransform != null) { + // ignore it + return; + } + + // get the transformation corresponding to hte node we are + // currently visiting and translate it into its Flink alternative. 
+ + PTransform<?, ?> transform = node.getTransform(); + StreamTransformTranslator<?> translator = FlinkStreamingTransformTranslators.getTranslator(transform); + if (translator == null) { + System.out.println(node.getTransform().getClass()); + throw new UnsupportedOperationException("The transform " + transform + " is currently not supported."); + } + applyStreamingTransform(transform, node, translator); + } + + @Override + public void visitValue(PValue value, TransformTreeNode producer) { + // do nothing here + } + + private <T extends PTransform<?, ?>> void applyStreamingTransform(PTransform<?, ?> transform, TransformTreeNode node, StreamTransformTranslator<?> translator) { + if (this.streamingContext == null) { + throw new IllegalStateException("The FlinkPipelineTranslator is not yet initialized."); + } + + @SuppressWarnings("unchecked") + T typedTransform = (T) transform; + + @SuppressWarnings("unchecked") + StreamTransformTranslator<T> typedTranslator = (StreamTransformTranslator<T>) translator; + + // create the applied PTransform on the batchContext + streamingContext.setCurrentTransform(AppliedPTransform.of( + node.getFullName(), node.getInput(), node.getOutput(), (PTransform) transform)); + typedTranslator.translateNode(typedTransform, streamingContext); + } + + /** + * A translator of a {@link PTransform}. 
+ */ + public interface StreamTransformTranslator<Type extends PTransform> { + void translateNode(Type transform, FlinkStreamingTranslationContext context); + } + + private static String genSpaces(int n) { + String s = ""; + for (int i = 0; i < n; i++) { + s += "| "; + } + return s; + } + + private static String formatNodeName(TransformTreeNode node) { + return node.toString().split("@")[1] + node.getTransform(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/edff0785/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkStreamingTransformTranslators.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkStreamingTransformTranslators.java new file mode 100644 index 0000000..4c8cd4b --- /dev/null +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkStreamingTransformTranslators.java @@ -0,0 +1,356 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package com.dataartisans.flink.dataflow.translation;

import com.dataartisans.flink.dataflow.translation.functions.UnionCoder;
import com.dataartisans.flink.dataflow.translation.types.CoderTypeInformation;
import com.dataartisans.flink.dataflow.translation.wrappers.streaming.*;
import com.dataartisans.flink.dataflow.translation.wrappers.streaming.io.UnboundedFlinkSource;
import com.dataartisans.flink.dataflow.translation.wrappers.streaming.io.UnboundedSourceWrapper;
import com.google.api.client.util.Maps;
import com.google.cloud.dataflow.sdk.coders.Coder;
import com.google.cloud.dataflow.sdk.coders.KvCoder;
import com.google.cloud.dataflow.sdk.io.Read;
import com.google.cloud.dataflow.sdk.io.TextIO;
import com.google.cloud.dataflow.sdk.transforms.*;
import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue;
import com.google.cloud.dataflow.sdk.transforms.windowing.*;
import com.google.cloud.dataflow.sdk.util.WindowedValue;
import com.google.cloud.dataflow.sdk.util.WindowingStrategy;
import com.google.cloud.dataflow.sdk.values.KV;
import com.google.cloud.dataflow.sdk.values.PCollection;
import com.google.cloud.dataflow.sdk.values.PValue;
import com.google.cloud.dataflow.sdk.values.TupleTag;
import com.google.common.collect.Lists;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.core.fs.FileSystem;
import org.apache.flink.streaming.api.datastream.*;
import org.apache.flink.util.Collector;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.*;

/**
 * Registry mapping Dataflow {@link PTransform} classes to translators that turn
 * them into Flink DataStream operations, together with the translator
 * implementations themselves. Looked up by
 * {@link FlinkStreamingPipelineTranslator} while walking the transform tree.
 */
public class FlinkStreamingTransformTranslators {

	// --------------------------------------------------------------------------------------------
	//  Transform Translator Registry
	// --------------------------------------------------------------------------------------------

	// Keyed by the concrete PTransform subclass; raw types because the registry
	// mixes translators of unrelated element types.
	@SuppressWarnings("rawtypes")
	private static final Map<Class<? extends PTransform>, FlinkStreamingPipelineTranslator.StreamTransformTranslator> TRANSLATORS = new HashMap<>();

	// here you can find all the available translators.
	static {
		TRANSLATORS.put(Read.Unbounded.class, new UnboundedReadSourceTranslator());
		TRANSLATORS.put(ParDo.Bound.class, new ParDoBoundStreamingTranslator());
		TRANSLATORS.put(TextIO.Write.Bound.class, new TextIOWriteBoundStreamingTranslator());
		TRANSLATORS.put(Window.Bound.class, new WindowBoundTranslator());
		TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslator());
		TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslator());
		TRANSLATORS.put(Flatten.FlattenPCollectionList.class, new FlattenPCollectionTranslator());
		TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoBoundMultiStreamingTranslator());

	}

	/** Returns the registered translator for the transform's exact class, or null if unsupported. */
	public static FlinkStreamingPipelineTranslator.StreamTransformTranslator<?> getTranslator(PTransform<?, ?> transform) {
		FlinkStreamingPipelineTranslator.StreamTransformTranslator<?> translator = TRANSLATORS.get(transform.getClass());
		return translator;
	}

	// --------------------------------------------------------------------------------------------
	//  Transformation Implementations
	// --------------------------------------------------------------------------------------------

	/**
	 * Translates {@code TextIO.Write.Bound} into a Flink text sink. Elements are
	 * written via {@code toString()}; suffix, validation and shard-name template
	 * are not yet supported and are only logged.
	 */
	private static class TextIOWriteBoundStreamingTranslator<T> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<TextIO.Write.Bound<T>> {
		private static final Logger LOG = LoggerFactory.getLogger(TextIOWriteBoundStreamingTranslator.class);

		@Override
		public void translateNode(TextIO.Write.Bound<T> transform, FlinkStreamingTranslationContext context) {
			PValue input = context.getInput(transform);
			DataStream<WindowedValue<T>> inputDataStream = context.getInputDataStream(input);

			String filenamePrefix = transform.getFilenamePrefix();
			String filenameSuffix = transform.getFilenameSuffix();
			boolean needsValidation = transform.needsValidation();
			int numShards = transform.getNumShards();
			String shardNameTemplate = transform.getShardNameTemplate();

			// TODO: Implement these. We need Flink support for this.
			LOG.warn("Translation of TextIO.Write.needsValidation not yet supported. Is: {}.", needsValidation);
			LOG.warn("Translation of TextIO.Write.filenameSuffix not yet supported. Is: {}.", filenameSuffix);
			LOG.warn("Translation of TextIO.Write.shardNameTemplate not yet supported. Is: {}.", shardNameTemplate);

			// Strip windowing metadata and emit the element's string form.
			DataStream<String> dataSink = inputDataStream.flatMap(new FlatMapFunction<WindowedValue<T>, String>() {
				@Override
				public void flatMap(WindowedValue<T> value, Collector<String> out) throws Exception {
					out.collect(value.getValue().toString());
				}
			});
			DataStreamSink<String> output = dataSink.writeAsText(filenamePrefix, FileSystem.WriteMode.OVERWRITE);

			if (numShards > 0) {
				// numShards maps onto the sink's parallelism; 0 keeps the default.
				output.setParallelism(numShards);
			}
		}
	}

	/**
	 * Translates {@code Read.Unbounded} into a Flink source. A wrapped native
	 * Flink source ({@link UnboundedFlinkSource}) is added directly; any other
	 * unbounded source goes through {@link UnboundedSourceWrapper}.
	 */
	private static class UnboundedReadSourceTranslator<T> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<Read.Unbounded<T>> {

		@Override
		public void translateNode(Read.Unbounded<T> transform, FlinkStreamingTranslationContext context) {
			PCollection<T> output = context.getOutput(transform);

			DataStream<WindowedValue<T>> source = null;
			if (transform.getSource().getClass().equals(UnboundedFlinkSource.class)) {
				UnboundedFlinkSource flinkSource = (UnboundedFlinkSource) transform.getSource();
				// NOTE(review): this branch fixes the element type to String (the
				// FlatMapFunction is <String, WindowedValue<String>>), so it assumes
				// the wrapped Flink source emits Strings -- confirm against
				// UnboundedFlinkSource before reusing for other element types.
				source = context.getExecutionEnvironment()
						.addSource(flinkSource.getFlinkSource())
						.flatMap(new FlatMapFunction<String, WindowedValue<String>>() {
							@Override
							public void flatMap(String s, Collector<WindowedValue<String>> collector) throws Exception {
								// Stamp each element with processing time, the global window and no pane firing.
								collector.collect(WindowedValue.<String>of(s, Instant.now(), GlobalWindow.INSTANCE, PaneInfo.NO_FIRING));
							}
						});
			} else {
				source = context.getExecutionEnvironment()
						.addSource(new UnboundedSourceWrapper<>(context.getPipelineOptions(), transform));
			}
			context.setOutputDataStream(output, source);
		}
	}

	/**
	 * Translates {@code ParDo.Bound} into a flatMap over windowed values,
	 * executing the user's DoFn via {@link FlinkParDoBoundWrapper}.
	 */
	private static class ParDoBoundStreamingTranslator<IN, OUT> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<ParDo.Bound<IN, OUT>> {

		@Override
		public void translateNode(ParDo.Bound<IN, OUT> transform, FlinkStreamingTranslationContext context) {
			PCollection<OUT> output = context.getOutput(transform);

			final WindowingStrategy<OUT, ? extends BoundedWindow> windowingStrategy =
					(WindowingStrategy<OUT, ? extends BoundedWindow>)
					context.getOutput(transform).getWindowingStrategy();

			// Serialize elements with the full windowed-value coder so window/timestamp/pane survive.
			WindowedValue.WindowedValueCoder<OUT> outputStreamCoder = WindowedValue.getFullCoder(output.getCoder(), windowingStrategy.getWindowFn().windowCoder());
			CoderTypeInformation<WindowedValue<OUT>> outputWindowedValueCoder = new CoderTypeInformation<>(outputStreamCoder);

			FlinkParDoBoundWrapper<IN, OUT> doFnWrapper = new FlinkParDoBoundWrapper<>(context.getPipelineOptions(), windowingStrategy, transform.getFn());
			DataStream<WindowedValue<IN>> inputDataStream = context.getInputDataStream(context.getInput(transform));
			SingleOutputStreamOperator<WindowedValue<OUT>, ?> outDataStream = inputDataStream.flatMap(doFnWrapper).returns(outputWindowedValueCoder);

			context.setOutputDataStream(context.getOutput(transform), outDataStream);
		}
	}

	/**
	 * Translates {@code Window.Bound} by running a synthetic DoFn that re-assigns
	 * each element to the windows produced by the transform's WindowFn.
	 */
	public static class WindowBoundTranslator<T> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<Window.Bound<T>> {

		@Override
		public void translateNode(Window.Bound<T> transform, FlinkStreamingTranslationContext context) {
			PValue input = context.getInput(transform);
			DataStream<WindowedValue<T>> inputDataStream = context.getInputDataStream(input);

			final WindowingStrategy<T, ? extends BoundedWindow> windowingStrategy =
					(WindowingStrategy<T, ? extends BoundedWindow>)
					context.getOutput(transform).getWindowingStrategy();

			final WindowFn<T, ? extends BoundedWindow> windowFn = windowingStrategy.getWindowFn();

			WindowedValue.WindowedValueCoder<T> outputStreamCoder = WindowedValue.getFullCoder(
					context.getInput(transform).getCoder(), windowingStrategy.getWindowFn().windowCoder());
			CoderTypeInformation<WindowedValue<T>> outputWindowedValueCoder =
					new CoderTypeInformation<>(outputStreamCoder);

			final FlinkParDoBoundWrapper<T, T> windowDoFnAssigner = new FlinkParDoBoundWrapper<>(
					context.getPipelineOptions(), windowingStrategy, createWindowAssigner(windowFn));

			SingleOutputStreamOperator<WindowedValue<T>, ?> windowedStream =
					inputDataStream.flatMap(windowDoFnAssigner).returns(outputWindowedValueCoder);
			context.setOutputDataStream(context.getOutput(transform), windowedStream);
		}

		/**
		 * Builds an identity DoFn that re-emits each element unchanged but
		 * assigned to the windows chosen by {@code windowFn}.
		 */
		private static <T, W extends BoundedWindow> DoFn<T, T> createWindowAssigner(final WindowFn<T, W> windowFn) {
			return new DoFn<T, T>() {

				@Override
				public void processElement(final ProcessContext c) throws Exception {
					Collection<W> windows = windowFn.assignWindows(
							windowFn.new AssignContext() {
								@Override
								public T element() {
									return c.element();
								}

								@Override
								public Instant timestamp() {
									return c.timestamp();
								}

								@Override
								public Collection<? extends BoundedWindow> windows() {
									return c.windowingInternals().windows();
								}
							});

					// Same value, timestamp and pane -- only the window set changes.
					c.windowingInternals().outputWindowedValue(
							c.element(), c.timestamp(), windows, c.pane());
				}
			};
		}
	}

	/**
	 * Translates {@code GroupByKey} by keying the stream on the KvCoder's key and
	 * grouping per window via {@link FlinkGroupAlsoByWindowWrapper}.
	 */
	public static class GroupByKeyTranslator<K, V> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<GroupByKey<K, V>> {

		@Override
		public void translateNode(GroupByKey<K, V> transform, FlinkStreamingTranslationContext context) {
			PValue input = context.getInput(transform);

			DataStream<WindowedValue<KV<K, V>>> inputDataStream = context.getInputDataStream(input);
			// GroupByKey requires KV elements, so the input coder must be a KvCoder.
			KvCoder<K, V> inputKvCoder = (KvCoder<K, V>) context.getInput(transform).getCoder();

			KeyedStream<WindowedValue<KV<K, V>>, K> groupByKStream = FlinkGroupByKeyWrapper
					.groupStreamByKey(inputDataStream, inputKvCoder);

			DataStream<WindowedValue<KV<K, Iterable<V>>>> groupedByKNWstream =
					FlinkGroupAlsoByWindowWrapper.createForIterable(context.getPipelineOptions(),
							context.getInput(transform), groupByKStream);

			context.setOutputDataStream(context.getOutput(transform), groupedByKNWstream);
		}
	}

	/**
	 * Translates {@code Combine.PerKey} by keying the stream and applying the
	 * KeyedCombineFn per key and window via {@link FlinkGroupAlsoByWindowWrapper}.
	 */
	public static class CombinePerKeyTranslator<K, VIN, VACC, VOUT> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<Combine.PerKey<K, VIN, VOUT>> {

		@Override
		public void translateNode(Combine.PerKey<K, VIN, VOUT> transform, FlinkStreamingTranslationContext context) {
			PValue input = context.getInput(transform);

			DataStream<WindowedValue<KV<K, VIN>>> inputDataStream = context.getInputDataStream(input);
			KvCoder<K, VIN> inputKvCoder = (KvCoder<K, VIN>) context.getInput(transform).getCoder();
			KvCoder<K, VOUT> outputKvCoder = (KvCoder<K, VOUT>) context.getOutput(transform).getCoder();

			KeyedStream<WindowedValue<KV<K, VIN>>, K> groupByKStream = FlinkGroupByKeyWrapper
					.groupStreamByKey(inputDataStream, inputKvCoder);

			// NOTE(review): unchecked cast assumes transform.getFn() is the keyed form
			// of the combine function -- confirm against the Combine.PerKey contract.
			Combine.KeyedCombineFn<K, VIN, VACC, VOUT> combineFn = (Combine.KeyedCombineFn<K, VIN, VACC, VOUT>) transform.getFn();
			DataStream<WindowedValue<KV<K, VOUT>>> groupedByKNWstream =
					FlinkGroupAlsoByWindowWrapper.create(context.getPipelineOptions(),
							context.getInput(transform), groupByKStream, combineFn, outputKvCoder);

			context.setOutputDataStream(context.getOutput(transform), groupedByKNWstream);
		}
	}

	/**
	 * Translates {@code Flatten.FlattenPCollectionList} into a chain of stream
	 * unions over all input collections.
	 */
	public static class FlattenPCollectionTranslator<T> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<Flatten.FlattenPCollectionList<T>> {

		@Override
		public void translateNode(Flatten.FlattenPCollectionList<T> transform, FlinkStreamingTranslationContext context) {
			List<PCollection<T>> allInputs = context.getInput(transform).getAll();
			// Union all inputs pairwise; result stays null only if allInputs is empty.
			DataStream<T> result = null;
			for (PCollection<T> collection : allInputs) {
				DataStream<T> current = context.getInputDataStream(collection);
				result = (result == null) ? current : result.union(current);
			}
			context.setOutputDataStream(context.getOutput(transform), result);
		}
	}

	/**
	 * Translates {@code ParDo.BoundMulti} (multi-output ParDo): the DoFn emits
	 * tagged {@link RawUnionValue}s into one intermediate stream, which is then
	 * filtered per output tag and unwrapped into each output collection.
	 */
	public static class ParDoBoundMultiStreamingTranslator<IN, OUT> implements FlinkStreamingPipelineTranslator.StreamTransformTranslator<ParDo.BoundMulti<IN, OUT>> {

		// Union index reserved for the main output tag.
		private final int MAIN_TAG_INDEX = 0;

		@Override
		public void translateNode(ParDo.BoundMulti<IN, OUT> transform, FlinkStreamingTranslationContext context) {

			// we assume that the transformation does not change the windowing strategy.
			WindowingStrategy<?, ? extends BoundedWindow> windowingStrategy = context.getInput(transform).getWindowingStrategy();

			Map<TupleTag<?>, PCollection<?>> outputs = context.getOutput(transform).getAll();
			Map<TupleTag<?>, Integer> tagsToLabels = transformTupleTagsToLabels(
					transform.getMainOutputTag(), outputs.keySet());

			UnionCoder intermUnionCoder = getIntermUnionCoder(outputs.values());
			WindowedValue.WindowedValueCoder<RawUnionValue> outputStreamCoder = WindowedValue.getFullCoder(
					intermUnionCoder, windowingStrategy.getWindowFn().windowCoder());

			CoderTypeInformation<WindowedValue<RawUnionValue>> intermWindowedValueCoder =
					new CoderTypeInformation<>(outputStreamCoder);

			FlinkParDoBoundMultiWrapper<IN, OUT> doFnWrapper = new FlinkParDoBoundMultiWrapper<>(
					context.getPipelineOptions(), windowingStrategy, transform.getFn(),
					transform.getMainOutputTag(), tagsToLabels);

			DataStream<WindowedValue<IN>> inputDataStream = context.getInputDataStream(context.getInput(transform));
			SingleOutputStreamOperator<WindowedValue<RawUnionValue>, ?> intermDataStream =
					inputDataStream.flatMap(doFnWrapper).returns(intermWindowedValueCoder);

			// Fan the intermediate union stream out: one filter + unwrap per output tag.
			for (Map.Entry<TupleTag<?>, PCollection<?>> output : outputs.entrySet()) {
				final int outputTag = tagsToLabels.get(output.getKey());

				WindowedValue.WindowedValueCoder<?> coderForTag = WindowedValue.getFullCoder(
						output.getValue().getCoder(),
						windowingStrategy.getWindowFn().windowCoder());

				CoderTypeInformation<WindowedValue<?>> windowedValueCoder =
						new CoderTypeInformation(coderForTag);

				context.setOutputDataStream(output.getValue(),
						intermDataStream.filter(new FilterFunction<WindowedValue<RawUnionValue>>() {
							@Override
							public boolean filter(WindowedValue<RawUnionValue> value) throws Exception {
								return value.getValue().getUnionTag() == outputTag;
							}
						}).flatMap(new FlatMapFunction<WindowedValue<RawUnionValue>, WindowedValue<?>>() {
							@Override
							public void flatMap(WindowedValue<RawUnionValue> value, Collector<WindowedValue<?>> collector) throws Exception {
								// Unwrap the union value, keeping window/timestamp/pane intact.
								collector.collect(WindowedValue.of(
										value.getValue().getValue(),
										value.getTimestamp(),
										value.getWindows(),
										value.getPane()));
							}
						}).returns(windowedValueCoder));
			}
		}

		/**
		 * Assigns union indices: the main tag gets MAIN_TAG_INDEX, secondary tags
		 * get consecutive indices in the set's iteration order.
		 */
		private Map<TupleTag<?>, Integer> transformTupleTagsToLabels(TupleTag<?> mainTag, Set<TupleTag<?>> secondaryTags) {
			Map<TupleTag<?>, Integer> tagToLabelMap = Maps.newHashMap();
			tagToLabelMap.put(mainTag, MAIN_TAG_INDEX);
			int count = MAIN_TAG_INDEX + 1;
			for (TupleTag<?> tag : secondaryTags) {
				if (!tagToLabelMap.containsKey(tag)) {
					tagToLabelMap.put(tag, count++);
				}
			}
			return tagToLabelMap;
		}

		/** Builds the UnionCoder over all output coders, in the collections' iteration order. */
		private UnionCoder getIntermUnionCoder(Collection<PCollection<?>> taggedCollections) {
			List<Coder<?>> outputCoders = Lists.newArrayList();
			for (PCollection<?> coll : taggedCollections) {
				outputCoders.add(coll.getCoder());
			}
			return UnionCoder.of(outputCoders);
		}
	}
}
package com.dataartisans.flink.dataflow.translation;

import com.google.cloud.dataflow.sdk.options.PipelineOptions;
import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform;
import com.google.cloud.dataflow.sdk.transforms.PTransform;
import com.google.cloud.dataflow.sdk.values.*;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;

import java.util.HashMap;
import java.util.Map;

/**
 * Mutable state shared across the streaming translation pass: it maps each
 * translated Dataflow {@link PValue} to the Flink {@link DataStream} that
 * produces it and tracks the transform currently being translated.
 */
public class FlinkStreamingTranslationContext {

	private final StreamExecutionEnvironment env;
	private final PipelineOptions options;

	/**
	 * Keeps a mapping between the output value of the PTransform (in Dataflow) and the
	 * Flink Operator that produced it, after the translation of the corresponding PTransform
	 * to its Flink equivalent.
	 */
	private final Map<PValue, DataStream<?>> dataStreams;

	// The AppliedPTransform whose translation is currently in progress; it carries
	// the input/output values resolved by getInput()/getOutput().
	private AppliedPTransform<?, ?, ?> currentTransform;

	public FlinkStreamingTranslationContext(StreamExecutionEnvironment env, PipelineOptions options) {
		this.env = env;
		this.options = options;
		this.dataStreams = new HashMap<>();
	}

	public StreamExecutionEnvironment getExecutionEnvironment() {
		return env;
	}

	public PipelineOptions getPipelineOptions() {
		return options;
	}

	/** Returns the DataStream previously registered for {@code value}, cast to the caller's element type. */
	@SuppressWarnings("unchecked")
	public <T> DataStream<T> getInputDataStream(PValue value) {
		return (DataStream<T>) dataStreams.get(value);
	}

	/** Registers {@code set} as the producer of {@code value}; a later re-registration is ignored. */
	public void setOutputDataStream(PValue value, DataStream<?> set) {
		if (dataStreams.containsKey(value)) {
			return;
		}
		dataStreams.put(value, set);
	}

	/**
	 * Sets the AppliedPTransform which carries input/output.
	 * @param currentTransform the transform about to be translated
	 */
	public void setCurrentTransform(AppliedPTransform<?, ?, ?> currentTransform) {
		this.currentTransform = currentTransform;
	}

	/** Input of the transform currently being translated, cast to the translator's expected type. */
	@SuppressWarnings("unchecked")
	public <I extends PInput> I getInput(PTransform<I, ?> transform) {
		return (I) currentTransform.getInput();
	}

	/** Output of the transform currently being translated, cast to the translator's expected type. */
	@SuppressWarnings("unchecked")
	public <O extends POutput> O getOutput(PTransform<?, O> transform) {
		return (O) currentTransform.getOutput();
	}
}