http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/edff0785/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkTransformTranslators.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkTransformTranslators.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkTransformTranslators.java deleted file mode 100644 index c1d78c0..0000000 --- a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/FlinkTransformTranslators.java +++ /dev/null @@ -1,594 +0,0 @@ -/* - * Copyright 2015 Data Artisans GmbH - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.dataartisans.flink.dataflow.translation; - -import com.dataartisans.flink.dataflow.io.ConsoleIO; -import com.dataartisans.flink.dataflow.translation.functions.FlinkCoGroupKeyedListAggregator; -import com.dataartisans.flink.dataflow.translation.functions.FlinkCreateFunction; -import com.dataartisans.flink.dataflow.translation.functions.FlinkDoFnFunction; -import com.dataartisans.flink.dataflow.translation.functions.FlinkKeyedListAggregationFunction; -import com.dataartisans.flink.dataflow.translation.functions.FlinkMultiOutputDoFnFunction; -import com.dataartisans.flink.dataflow.translation.functions.FlinkMultiOutputPruningFunction; -import com.dataartisans.flink.dataflow.translation.functions.FlinkPartialReduceFunction; -import com.dataartisans.flink.dataflow.translation.functions.FlinkReduceFunction; -import com.dataartisans.flink.dataflow.translation.functions.UnionCoder; -import com.dataartisans.flink.dataflow.translation.types.CoderTypeInformation; -import com.dataartisans.flink.dataflow.translation.types.KvCoderTypeInformation; -import com.dataartisans.flink.dataflow.translation.wrappers.SinkOutputFormat; -import com.dataartisans.flink.dataflow.translation.wrappers.SourceInputFormat; -import com.google.api.client.util.Maps; -import com.google.cloud.dataflow.sdk.coders.CannotProvideCoderException; -import com.google.cloud.dataflow.sdk.coders.Coder; -import com.google.cloud.dataflow.sdk.coders.KvCoder; -import com.google.cloud.dataflow.sdk.io.AvroIO; -import com.google.cloud.dataflow.sdk.io.BoundedSource; -import com.google.cloud.dataflow.sdk.io.Read; -import com.google.cloud.dataflow.sdk.io.TextIO; -import com.google.cloud.dataflow.sdk.transforms.Combine; -import com.google.cloud.dataflow.sdk.transforms.Create; -import com.google.cloud.dataflow.sdk.transforms.DoFn; -import com.google.cloud.dataflow.sdk.transforms.Flatten; -import com.google.cloud.dataflow.sdk.transforms.GroupByKey; -import com.google.cloud.dataflow.sdk.transforms.PTransform; -import com.google.cloud.dataflow.sdk.transforms.ParDo; -import com.google.cloud.dataflow.sdk.transforms.View; -import com.google.cloud.dataflow.sdk.transforms.Write; -import com.google.cloud.dataflow.sdk.transforms.join.CoGbkResult; -import 
com.google.cloud.dataflow.sdk.transforms.join.CoGbkResultSchema; -import com.google.cloud.dataflow.sdk.transforms.join.CoGroupByKey; -import com.google.cloud.dataflow.sdk.transforms.join.KeyedPCollectionTuple; -import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue; -import com.google.cloud.dataflow.sdk.values.KV; -import com.google.cloud.dataflow.sdk.values.PCollection; -import com.google.cloud.dataflow.sdk.values.PCollectionView; -import com.google.cloud.dataflow.sdk.values.PValue; -import com.google.cloud.dataflow.sdk.values.TupleTag; -import com.google.common.collect.Lists; -import org.apache.flink.api.common.functions.GroupReduceFunction; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.io.AvroInputFormat; -import org.apache.flink.api.java.io.AvroOutputFormat; -import org.apache.flink.api.java.io.TextInputFormat; -import org.apache.flink.api.java.operators.CoGroupOperator; -import org.apache.flink.api.java.operators.DataSink; -import org.apache.flink.api.java.operators.DataSource; -import org.apache.flink.api.java.operators.FlatMapOperator; -import org.apache.flink.api.java.operators.GroupCombineOperator; -import org.apache.flink.api.java.operators.GroupReduceOperator; -import org.apache.flink.api.java.operators.Grouping; -import org.apache.flink.api.java.operators.Keys; -import org.apache.flink.api.java.operators.MapPartitionOperator; -import org.apache.flink.api.java.operators.UnsortedGrouping; -import org.apache.flink.core.fs.Path; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.ByteArrayOutputStream; -import java.io.IOException; -import java.lang.reflect.Field; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -/** - * Translators for transforming - * Dataflow {@link com.google.cloud.dataflow.sdk.transforms.PTransform}s to - * Flink {@link org.apache.flink.api.java.DataSet}s - */ -public class FlinkTransformTranslators { - - // -------------------------------------------------------------------------------------------- - // Transform Translator Registry - // -------------------------------------------------------------------------------------------- - - @SuppressWarnings("rawtypes") - private static final Map<Class<? 
extends PTransform>, FlinkPipelineTranslator.TransformTranslator> TRANSLATORS = new HashMap<>(); - - // register the known translators - static { - TRANSLATORS.put(View.CreatePCollectionView.class, new CreatePCollectionViewTranslator()); - - TRANSLATORS.put(Combine.PerKey.class, new CombinePerKeyTranslator()); - // we don't need this because we translate the Combine.PerKey directly - //TRANSLATORS.put(Combine.GroupedValues.class, new CombineGroupedValuesTranslator()); - - TRANSLATORS.put(Create.Values.class, new CreateTranslator()); - - TRANSLATORS.put(Flatten.FlattenPCollectionList.class, new FlattenPCollectionTranslator()); - - TRANSLATORS.put(GroupByKey.GroupByKeyOnly.class, new GroupByKeyOnlyTranslator()); - // TODO we're currently ignoring windows here but that has to change in the future - TRANSLATORS.put(GroupByKey.class, new GroupByKeyTranslator()); - - TRANSLATORS.put(ParDo.BoundMulti.class, new ParDoBoundMultiTranslator()); - TRANSLATORS.put(ParDo.Bound.class, new ParDoBoundTranslator()); - - TRANSLATORS.put(CoGroupByKey.class, new CoGroupByKeyTranslator()); - - TRANSLATORS.put(AvroIO.Read.Bound.class, new AvroIOReadTranslator()); - TRANSLATORS.put(AvroIO.Write.Bound.class, new AvroIOWriteTranslator()); - - TRANSLATORS.put(Read.Bounded.class, new ReadSourceTranslator()); - TRANSLATORS.put(Write.Bound.class, new WriteSinkTranslator()); - - TRANSLATORS.put(TextIO.Read.Bound.class, new TextIOReadTranslator()); - TRANSLATORS.put(TextIO.Write.Bound.class, new TextIOWriteTranslator()); - - // Flink-specific - TRANSLATORS.put(ConsoleIO.Write.Bound.class, new ConsoleIOWriteTranslator()); - - } - - - public static FlinkPipelineTranslator.TransformTranslator<?> getTranslator(PTransform<?, ?> transform) { - return TRANSLATORS.get(transform.getClass()); - } - - private static class ReadSourceTranslator<T> implements FlinkPipelineTranslator.TransformTranslator<Read.Bounded<T>> { - - @Override - public void translateNode(Read.Bounded<T> transform, TranslationContext context) { - String name = transform.getName(); - BoundedSource<T> source = transform.getSource(); - PCollection<T> output = context.getOutput(transform); - Coder<T> coder = output.getCoder(); - - TypeInformation<T> typeInformation = context.getTypeInfo(output); - - DataSource<T> dataSource = new DataSource<>(context.getExecutionEnvironment(), new SourceInputFormat<>(source, context.getPipelineOptions(), coder), typeInformation, name); - - context.setOutputDataSet(output, dataSource); - } - } - - private static class AvroIOReadTranslator<T> implements FlinkPipelineTranslator.TransformTranslator<AvroIO.Read.Bound<T>> { - private static final Logger LOG = LoggerFactory.getLogger(AvroIOReadTranslator.class); - - @Override - public void translateNode(AvroIO.Read.Bound<T> transform, TranslationContext context) { - String path = transform.getFilepattern(); - String name = transform.getName(); -// Schema schema = transform.getSchema(); - PValue output = context.getOutput(transform); - - TypeInformation<T> typeInformation = context.getTypeInfo(output); - - // This is super hacky, but unfortunately we cannot get the type otherwise - Class<T> extractedAvroType; - try { - Field typeField = transform.getClass().getDeclaredField("type"); - typeField.setAccessible(true); - @SuppressWarnings("unchecked") - Class<T> avroType = (Class<T>) typeField.get(transform); - extractedAvroType = avroType; - } catch (NoSuchFieldException | IllegalAccessException e) { - // we know that the field is there and it is accessible - throw new RuntimeException("Could 
not access type from AvroIO.Bound", e); - } - - DataSource<T> source = new DataSource<>(context.getExecutionEnvironment(), - new AvroInputFormat<>(new Path(path), extractedAvroType), - typeInformation, name); - - context.setOutputDataSet(output, source); - } - } - - private static class AvroIOWriteTranslator<T> implements FlinkPipelineTranslator.TransformTranslator<AvroIO.Write.Bound<T>> { - private static final Logger LOG = LoggerFactory.getLogger(AvroIOWriteTranslator.class); - - @Override - public void translateNode(AvroIO.Write.Bound<T> transform, TranslationContext context) { - DataSet<T> inputDataSet = context.getInputDataSet(context.getInput(transform)); - String filenamePrefix = transform.getFilenamePrefix(); - String filenameSuffix = transform.getFilenameSuffix(); - int numShards = transform.getNumShards(); - String shardNameTemplate = transform.getShardNameTemplate(); - - // TODO: Implement these. We need Flink support for this. - LOG.warn("Translation of TextIO.Write.filenameSuffix not yet supported. Is: {}.", - filenameSuffix); - LOG.warn("Translation of TextIO.Write.shardNameTemplate not yet supported. Is: {}.", shardNameTemplate); - - // This is super hacky, but unfortunately we cannot get the type otherwise - Class<T> extractedAvroType; - try { - Field typeField = transform.getClass().getDeclaredField("type"); - typeField.setAccessible(true); - @SuppressWarnings("unchecked") - Class<T> avroType = (Class<T>) typeField.get(transform); - extractedAvroType = avroType; - } catch (NoSuchFieldException | IllegalAccessException e) { - // we know that the field is there and it is accessible - throw new RuntimeException("Could not access type from AvroIO.Bound", e); - } - - DataSink<T> dataSink = inputDataSet.output(new AvroOutputFormat<>(new Path - (filenamePrefix), extractedAvroType)); - - if (numShards > 0) { - dataSink.setParallelism(numShards); - } - } - } - - private static class TextIOReadTranslator implements FlinkPipelineTranslator.TransformTranslator<TextIO.Read.Bound<String>> { - private static final Logger LOG = LoggerFactory.getLogger(TextIOReadTranslator.class); - - @Override - public void translateNode(TextIO.Read.Bound<String> transform, TranslationContext context) { - String path = transform.getFilepattern(); - String name = transform.getName(); - - TextIO.CompressionType compressionType = transform.getCompressionType(); - boolean needsValidation = transform.needsValidation(); - - // TODO: Implement these. We need Flink support for this. - LOG.warn("Translation of TextIO.CompressionType not yet supported. Is: {}.", compressionType); - LOG.warn("Translation of TextIO.Read.needsValidation not yet supported. 
Is: {}.", needsValidation); - - PValue output = context.getOutput(transform); - - TypeInformation<String> typeInformation = context.getTypeInfo(output); - - DataSource<String> source = new DataSource<>(context.getExecutionEnvironment(), new TextInputFormat(new Path(path)), typeInformation, name); - - context.setOutputDataSet(output, source); - } - } - - private static class TextIOWriteTranslator<T> implements FlinkPipelineTranslator.TransformTranslator<TextIO.Write.Bound<T>> { - private static final Logger LOG = LoggerFactory.getLogger(TextIOWriteTranslator.class); - - @Override - public void translateNode(TextIO.Write.Bound<T> transform, TranslationContext context) { - PValue input = context.getInput(transform); - DataSet<T> inputDataSet = context.getInputDataSet(input); - - String filenamePrefix = transform.getFilenamePrefix(); - String filenameSuffix = transform.getFilenameSuffix(); - boolean needsValidation = transform.needsValidation(); - int numShards = transform.getNumShards(); - String shardNameTemplate = transform.getShardNameTemplate(); - - // TODO: Implement these. We need Flink support for this. - LOG.warn("Translation of TextIO.Write.needsValidation not yet supported. Is: {}.", needsValidation); - LOG.warn("Translation of TextIO.Write.filenameSuffix not yet supported. Is: {}.", filenameSuffix); - LOG.warn("Translation of TextIO.Write.shardNameTemplate not yet supported. Is: {}.", shardNameTemplate); - - //inputDataSet.print(); - DataSink<T> dataSink = inputDataSet.writeAsText(filenamePrefix); - - if (numShards > 0) { - dataSink.setParallelism(numShards); - } - } - } - - private static class ConsoleIOWriteTranslator implements FlinkPipelineTranslator.TransformTranslator<ConsoleIO.Write.Bound> { - @Override - public void translateNode(ConsoleIO.Write.Bound transform, TranslationContext context) { - PValue input = context.getInput(transform); - DataSet<?> inputDataSet = context.getInputDataSet(input); - inputDataSet.printOnTaskManager(transform.getName()); - } - } - - private static class WriteSinkTranslator<T> implements FlinkPipelineTranslator.TransformTranslator<Write.Bound<T>> { - - @Override - public void translateNode(Write.Bound<T> transform, TranslationContext context) { - String name = transform.getName(); - PValue input = context.getInput(transform); - DataSet<T> inputDataSet = context.getInputDataSet(input); - - inputDataSet.output(new SinkOutputFormat<>(transform, context.getPipelineOptions())).name(name); - } - } - - private static class GroupByKeyOnlyTranslator<K, V> implements FlinkPipelineTranslator.TransformTranslator<GroupByKey.GroupByKeyOnly<K, V>> { - - @Override - public void translateNode(GroupByKey.GroupByKeyOnly<K, V> transform, TranslationContext context) { - DataSet<KV<K, V>> inputDataSet = context.getInputDataSet(context.getInput(transform)); - GroupReduceFunction<KV<K, V>, KV<K, Iterable<V>>> groupReduceFunction = new FlinkKeyedListAggregationFunction<>(); - - TypeInformation<KV<K, Iterable<V>>> typeInformation = context.getTypeInfo(context.getOutput(transform)); - - Grouping<KV<K, V>> grouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet.getType())); - - GroupReduceOperator<KV<K, V>, KV<K, Iterable<V>>> outputDataSet = - new GroupReduceOperator<>(grouping, typeInformation, groupReduceFunction, transform.getName()); - context.setOutputDataSet(context.getOutput(transform), outputDataSet); - } - } - - /** - * Translates a GroupByKey while ignoring window assignments. 
This is identical to the {@link GroupByKeyOnlyTranslator} - */ - private static class GroupByKeyTranslator<K, V> implements FlinkPipelineTranslator.TransformTranslator<GroupByKey<K, V>> { - - @Override - public void translateNode(GroupByKey<K, V> transform, TranslationContext context) { - DataSet<KV<K, V>> inputDataSet = context.getInputDataSet(context.getInput(transform)); - GroupReduceFunction<KV<K, V>, KV<K, Iterable<V>>> groupReduceFunction = new FlinkKeyedListAggregationFunction<>(); - - TypeInformation<KV<K, Iterable<V>>> typeInformation = context.getTypeInfo(context.getOutput(transform)); - - Grouping<KV<K, V>> grouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet.getType())); - - GroupReduceOperator<KV<K, V>, KV<K, Iterable<V>>> outputDataSet = - new GroupReduceOperator<>(grouping, typeInformation, groupReduceFunction, transform.getName()); - - context.setOutputDataSet(context.getOutput(transform), outputDataSet); - } - } - - private static class CombinePerKeyTranslator<K, VI, VA, VO> implements FlinkPipelineTranslator.TransformTranslator<Combine.PerKey<K, VI, VO>> { - - @Override - public void translateNode(Combine.PerKey<K, VI, VO> transform, TranslationContext context) { - DataSet<KV<K, VI>> inputDataSet = context.getInputDataSet(context.getInput(transform)); - - @SuppressWarnings("unchecked") - Combine.KeyedCombineFn<K, VI, VA, VO> keyedCombineFn = (Combine.KeyedCombineFn<K, VI, VA, VO>) transform.getFn(); - - KvCoder<K, VI> inputCoder = (KvCoder<K, VI>) context.getInput(transform).getCoder(); - - Coder<VA> accumulatorCoder = - null; - try { - accumulatorCoder = keyedCombineFn.getAccumulatorCoder(context.getInput(transform).getPipeline().getCoderRegistry(), inputCoder.getKeyCoder(), inputCoder.getValueCoder()); - } catch (CannotProvideCoderException e) { - e.printStackTrace(); - // TODO - } - - TypeInformation<KV<K, VI>> kvCoderTypeInformation = new KvCoderTypeInformation<>(inputCoder); - TypeInformation<KV<K, VA>> partialReduceTypeInfo = new KvCoderTypeInformation<>(KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder)); - - Grouping<KV<K, VI>> inputGrouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{"key"}, kvCoderTypeInformation)); - - FlinkPartialReduceFunction<K, VI, VA> partialReduceFunction = new FlinkPartialReduceFunction<>(keyedCombineFn); - - // Partially GroupReduce the values into the intermediate format VA (combine) - GroupCombineOperator<KV<K, VI>, KV<K, VA>> groupCombine = - new GroupCombineOperator<>(inputGrouping, partialReduceTypeInfo, partialReduceFunction, - "GroupCombine: " + transform.getName()); - - // Reduce fully to VO - GroupReduceFunction<KV<K, VA>, KV<K, VO>> reduceFunction = new FlinkReduceFunction<>(keyedCombineFn); - - TypeInformation<KV<K, VO>> reduceTypeInfo = context.getTypeInfo(context.getOutput(transform)); - - Grouping<KV<K, VA>> intermediateGrouping = new UnsortedGrouping<>(groupCombine, new Keys.ExpressionKeys<>(new String[]{"key"}, groupCombine.getType())); - - // Fully reduce the values and create output format VO - GroupReduceOperator<KV<K, VA>, KV<K, VO>> outputDataSet = - new GroupReduceOperator<>(intermediateGrouping, reduceTypeInfo, reduceFunction, transform.getName()); - - context.setOutputDataSet(context.getOutput(transform), outputDataSet); - } - } - -// private static class CombineGroupedValuesTranslator<K, VI, VO> implements FlinkPipelineTranslator.TransformTranslator<Combine.GroupedValues<K, VI, VO>> { -// -// @Override -// public void 
translateNode(Combine.GroupedValues<K, VI, VO> transform, TranslationContext context) { -// DataSet<KV<K, VI>> inputDataSet = context.getInputDataSet(transform.getInput()); -// -// Combine.KeyedCombineFn<? super K, ? super VI, ?, VO> keyedCombineFn = transform.getFn(); -// -// GroupReduceFunction<KV<K, VI>, KV<K, VO>> groupReduceFunction = new FlinkCombineFunction<>(keyedCombineFn); -// -// TypeInformation<KV<K, VO>> typeInformation = context.getTypeInfo(transform.getOutput()); -// -// Grouping<KV<K, VI>> grouping = new UnsortedGrouping<>(inputDataSet, new Keys.ExpressionKeys<>(new String[]{""}, inputDataSet.getType())); -// -// GroupReduceOperator<KV<K, VI>, KV<K, VO>> outputDataSet = -// new GroupReduceOperator<>(grouping, typeInformation, groupReduceFunction, transform.getName()); -// context.setOutputDataSet(transform.getOutput(), outputDataSet); -// } -// } - - private static class ParDoBoundTranslator<IN, OUT> implements FlinkPipelineTranslator.TransformTranslator<ParDo.Bound<IN, OUT>> { - private static final Logger LOG = LoggerFactory.getLogger(ParDoBoundTranslator.class); - - @Override - public void translateNode(ParDo.Bound<IN, OUT> transform, TranslationContext context) { - DataSet<IN> inputDataSet = context.getInputDataSet(context.getInput(transform)); - - final DoFn<IN, OUT> doFn = transform.getFn(); - - TypeInformation<OUT> typeInformation = context.getTypeInfo(context.getOutput(transform)); - - FlinkDoFnFunction<IN, OUT> doFnWrapper = new FlinkDoFnFunction<>(doFn, context.getPipelineOptions()); - MapPartitionOperator<IN, OUT> outputDataSet = new MapPartitionOperator<>(inputDataSet, typeInformation, doFnWrapper, transform.getName()); - - transformSideInputs(transform.getSideInputs(), outputDataSet, context); - - context.setOutputDataSet(context.getOutput(transform), outputDataSet); - } - } - - private static class ParDoBoundMultiTranslator<IN, OUT> implements FlinkPipelineTranslator.TransformTranslator<ParDo.BoundMulti<IN, OUT>> { - private static final Logger LOG = LoggerFactory.getLogger(ParDoBoundMultiTranslator.class); - - @Override - public void translateNode(ParDo.BoundMulti<IN, OUT> transform, TranslationContext context) { - DataSet<IN> inputDataSet = context.getInputDataSet(context.getInput(transform)); - - final DoFn<IN, OUT> doFn = transform.getFn(); - - Map<TupleTag<?>, PCollection<?>> outputs = context.getOutput(transform).getAll(); - - Map<TupleTag<?>, Integer> outputMap = Maps.newHashMap(); - // put the main output at index 0, FlinkMultiOutputDoFnFunction also expects this - outputMap.put(transform.getMainOutputTag(), 0); - int count = 1; - for (TupleTag<?> tag: outputs.keySet()) { - if (!outputMap.containsKey(tag)) { - outputMap.put(tag, count++); - } - } - - // collect all output Coders and create a UnionCoder for our tagged outputs - List<Coder<?>> outputCoders = Lists.newArrayList(); - for (PCollection<?> coll: outputs.values()) { - outputCoders.add(coll.getCoder()); - } - - UnionCoder unionCoder = UnionCoder.of(outputCoders); - - @SuppressWarnings("unchecked") - TypeInformation<RawUnionValue> typeInformation = new CoderTypeInformation<>(unionCoder); - - @SuppressWarnings("unchecked") - FlinkMultiOutputDoFnFunction<IN, OUT> doFnWrapper = new FlinkMultiOutputDoFnFunction(doFn, context.getPipelineOptions(), outputMap); - MapPartitionOperator<IN, RawUnionValue> outputDataSet = new MapPartitionOperator<>(inputDataSet, typeInformation, doFnWrapper, transform.getName()); - - transformSideInputs(transform.getSideInputs(), outputDataSet, context); - - for 
(Map.Entry<TupleTag<?>, PCollection<?>> output: outputs.entrySet()) { - TypeInformation<Object> outputType = context.getTypeInfo(output.getValue()); - int outputTag = outputMap.get(output.getKey()); - FlinkMultiOutputPruningFunction<Object> pruningFunction = new FlinkMultiOutputPruningFunction<>(outputTag); - FlatMapOperator<RawUnionValue, Object> pruningOperator = new - FlatMapOperator<>(outputDataSet, outputType, - pruningFunction, output.getValue().getName()); - context.setOutputDataSet(output.getValue(), pruningOperator); - - } - } - } - - private static class FlattenPCollectionTranslator<T> implements FlinkPipelineTranslator.TransformTranslator<Flatten.FlattenPCollectionList<T>> { - - @Override - public void translateNode(Flatten.FlattenPCollectionList<T> transform, TranslationContext context) { - List<PCollection<T>> allInputs = context.getInput(transform).getAll(); - DataSet<T> result = null; - for(PCollection<T> collection : allInputs) { - DataSet<T> current = context.getInputDataSet(collection); - if (result == null) { - result = current; - } else { - result = result.union(current); - } - } - context.setOutputDataSet(context.getOutput(transform), result); - } - } - - private static class CreatePCollectionViewTranslator<R, T> implements FlinkPipelineTranslator.TransformTranslator<View.CreatePCollectionView<R, T>> { - @Override - public void translateNode(View.CreatePCollectionView<R, T> transform, TranslationContext context) { - DataSet<T> inputDataSet = context.getInputDataSet(context.getInput(transform)); - PCollectionView<T> input = transform.apply(null); - context.setSideInputDataSet(input, inputDataSet); - } - } - - private static class CreateTranslator<OUT> implements FlinkPipelineTranslator.TransformTranslator<Create.Values<OUT>> { - - @Override - public void translateNode(Create.Values<OUT> transform, TranslationContext context) { - TypeInformation<OUT> typeInformation = context.getOutputTypeInfo(); - Iterable<OUT> elements = transform.getElements(); - - // we need to serialize the elements to byte arrays, since they might contain - // elements that are not serializable by Java serialization. We deserialize them - // in the FlatMap function using the Coder. 
- - List<byte[]> serializedElements = Lists.newArrayList(); - Coder<OUT> coder = context.getOutput(transform).getCoder(); - for (OUT element: elements) { - ByteArrayOutputStream bao = new ByteArrayOutputStream(); - try { - coder.encode(element, bao, Coder.Context.OUTER); - serializedElements.add(bao.toByteArray()); - } catch (IOException e) { - throw new RuntimeException("Could not serialize Create elements using Coder: " + e); - } - } - - DataSet<Integer> initDataSet = context.getExecutionEnvironment().fromElements(1); - FlinkCreateFunction<Integer, OUT> flatMapFunction = new FlinkCreateFunction<>(serializedElements, coder); - FlatMapOperator<Integer, OUT> outputDataSet = new FlatMapOperator<>(initDataSet, typeInformation, flatMapFunction, transform.getName()); - - context.setOutputDataSet(context.getOutput(transform), outputDataSet); - } - } - - private static void transformSideInputs(List<PCollectionView<?>> sideInputs, - MapPartitionOperator<?, ?> outputDataSet, - TranslationContext context) { - // get corresponding Flink broadcast DataSets - for(PCollectionView<?> input : sideInputs) { - DataSet<?> broadcastSet = context.getSideInputDataSet(input); - outputDataSet.withBroadcastSet(broadcastSet, input.getTagInternal().getId()); - } - } - -// Disabled because it depends on a pending pull request to the DataFlowSDK - /** - * Special composite transform translator. Only called if the CoGroup is two dimensional. - * @param <K> - */ - private static class CoGroupByKeyTranslator<K, V1, V2> implements FlinkPipelineTranslator.TransformTranslator<CoGroupByKey<K>> { - - @Override - public void translateNode(CoGroupByKey<K> transform, TranslationContext context) { - KeyedPCollectionTuple<K> input = context.getInput(transform); - - CoGbkResultSchema schema = input.getCoGbkResultSchema(); - List<KeyedPCollectionTuple.TaggedKeyedPCollection<K, ?>> keyedCollections = input.getKeyedCollections(); - - KeyedPCollectionTuple.TaggedKeyedPCollection<K, ?> taggedCollection1 = keyedCollections.get(0); - KeyedPCollectionTuple.TaggedKeyedPCollection<K, ?> taggedCollection2 = keyedCollections.get(1); - - TupleTag<?> tupleTag1 = taggedCollection1.getTupleTag(); - TupleTag<?> tupleTag2 = taggedCollection2.getTupleTag(); - - PCollection<? extends KV<K, ?>> collection1 = taggedCollection1.getCollection(); - PCollection<? extends KV<K, ?>> collection2 = taggedCollection2.getCollection(); - - DataSet<KV<K,V1>> inputDataSet1 = context.getInputDataSet(collection1); - DataSet<KV<K,V2>> inputDataSet2 = context.getInputDataSet(collection2); - - TypeInformation<KV<K,CoGbkResult>> typeInfo = context.getOutputTypeInfo(); - - FlinkCoGroupKeyedListAggregator<K,V1,V2> aggregator = new FlinkCoGroupKeyedListAggregator<>(schema, tupleTag1, tupleTag2); - - Keys.ExpressionKeys<KV<K,V1>> keySelector1 = new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet1.getType()); - Keys.ExpressionKeys<KV<K,V2>> keySelector2 = new Keys.ExpressionKeys<>(new String[]{"key"}, inputDataSet2.getType()); - - DataSet<KV<K, CoGbkResult>> out = new CoGroupOperator<>(inputDataSet1, inputDataSet2, - keySelector1, keySelector2, - aggregator, typeInfo, null, transform.getName()); - context.setOutputDataSet(context.getOutput(transform), out); - } - } - - // -------------------------------------------------------------------------------------------- - // Miscellaneous - // -------------------------------------------------------------------------------------------- - - private FlinkTransformTranslators() {} -}
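[For context: the deleted FlinkTransformTranslators above is organized around a class-keyed registry — each PTransform subclass maps to one TransformTranslator, and getTranslator() looks the entry up by the transform's concrete class, with null meaning "unsupported". A minimal, self-contained sketch of that pattern; Transform and Translator are illustrative stand-ins, not the SDK types:]

    import java.util.HashMap;
    import java.util.Map;

    // Sketch of the class-keyed translator registry used above; names are
    // hypothetical, not the real PTransform/TransformTranslator API.
    final class TranslatorRegistry {
      interface Transform {}

      interface Translator<T extends Transform> {
        void translate(T transform);
      }

      private static final Map<Class<? extends Transform>, Translator<?>> TRANSLATORS = new HashMap<>();

      static <T extends Transform> void register(Class<T> clazz, Translator<T> translator) {
        TRANSLATORS.put(clazz, translator);
      }

      // Mirrors FlinkTransformTranslators.getTranslator(): key on the concrete
      // class of the transform instance; a null result means "unsupported".
      @SuppressWarnings("unchecked")
      static <T extends Transform> Translator<T> getTranslator(T transform) {
        return (Translator<T>) TRANSLATORS.get(transform.getClass());
      }

      private TranslatorRegistry() {}
    }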
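[The CreateTranslator above relies on a coder round-trip: elements are encoded to byte arrays at translation time — they need not be Java-serializable — and decoded again inside FlinkCreateFunction. A small sketch of that round-trip, assuming the decode counterpart of the encode call shown above and using StringUtf8Coder purely for illustration:]

    import com.google.cloud.dataflow.sdk.coders.Coder;
    import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;

    public class CoderRoundTrip {
      public static void main(String[] args) throws Exception {
        Coder<String> coder = StringUtf8Coder.of();

        // Translation time: encode each element, as CreateTranslator does.
        ByteArrayOutputStream bao = new ByteArrayOutputStream();
        coder.encode("hello", bao, Coder.Context.OUTER);
        byte[] serialized = bao.toByteArray();

        // Run time: decode inside the operator, as FlinkCreateFunction does.
        String restored = coder.decode(new ByteArrayInputStream(serialized), Coder.Context.OUTER);
        System.out.println(restored); // prints "hello"
      }
    }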
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/edff0785/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/TranslationContext.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/TranslationContext.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/TranslationContext.java deleted file mode 100644 index af46109..0000000 --- a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/TranslationContext.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Copyright 2015 Data Artisans GmbH - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.dataartisans.flink.dataflow.translation; - -import com.dataartisans.flink.dataflow.translation.types.CoderTypeInformation; -import com.dataartisans.flink.dataflow.translation.types.KvCoderTypeInformation; -import com.google.cloud.dataflow.sdk.coders.Coder; -import com.google.cloud.dataflow.sdk.coders.KvCoder; -import com.google.cloud.dataflow.sdk.options.PipelineOptions; -import com.google.cloud.dataflow.sdk.transforms.AppliedPTransform; -import com.google.cloud.dataflow.sdk.transforms.PTransform; -import com.google.cloud.dataflow.sdk.values.PCollectionView; -import com.google.cloud.dataflow.sdk.values.PInput; -import com.google.cloud.dataflow.sdk.values.POutput; -import com.google.cloud.dataflow.sdk.values.PValue; -import com.google.cloud.dataflow.sdk.values.TypedPValue; -import org.apache.flink.api.common.typeinfo.TypeInformation; -import org.apache.flink.api.java.DataSet; -import org.apache.flink.api.java.ExecutionEnvironment; -import org.apache.flink.api.java.typeutils.GenericTypeInfo; - -import java.util.HashMap; -import java.util.Map; - -public class TranslationContext { - - private final Map<PValue, DataSet<?>> dataSets; - private final Map<PCollectionView<?>, DataSet<?>> broadcastDataSets; - - private final ExecutionEnvironment env; - private final PipelineOptions options; - - private AppliedPTransform<?, ?, ?> currentTransform; - - // ------------------------------------------------------------------------ - - public TranslationContext(ExecutionEnvironment env, PipelineOptions options) { - this.env = env; - this.options = options; - this.dataSets = new HashMap<>(); - this.broadcastDataSets = new HashMap<>(); - } - - // ------------------------------------------------------------------------ - - public ExecutionEnvironment getExecutionEnvironment() { - return env; - } - - public PipelineOptions getPipelineOptions() { - return options; - } - - @SuppressWarnings("unchecked") - public <T> DataSet<T> getInputDataSet(PValue value) { - return (DataSet<T>) dataSets.get(value); - } - - public void setOutputDataSet(PValue value, DataSet<?> set) { - if (!dataSets.containsKey(value)) { - dataSets.put(value, set); - } - } - - /** - * Sets the AppliedPTransform which carries input/output. 
- * @param currentTransform - */ - public void setCurrentTransform(AppliedPTransform<?, ?, ?> currentTransform) { - this.currentTransform = currentTransform; - } - - @SuppressWarnings("unchecked") - public <T> DataSet<T> getSideInputDataSet(PCollectionView<?> value) { - return (DataSet<T>) broadcastDataSets.get(value); - } - - public void setSideInputDataSet(PCollectionView<?> value, DataSet<?> set) { - if (!broadcastDataSets.containsKey(value)) { - broadcastDataSets.put(value, set); - } - } - - @SuppressWarnings("unchecked") - public <T> TypeInformation<T> getTypeInfo(PInput output) { - if (output instanceof TypedPValue) { - Coder<?> outputCoder = ((TypedPValue) output).getCoder(); - if (outputCoder instanceof KvCoder) { - return new KvCoderTypeInformation((KvCoder) outputCoder); - } else { - return new CoderTypeInformation(outputCoder); - } - } - return new GenericTypeInfo<>((Class<T>)Object.class); - } - - public <T> TypeInformation<T> getInputTypeInfo() { - return getTypeInfo(currentTransform.getInput()); - } - - public <T> TypeInformation<T> getOutputTypeInfo() { - return getTypeInfo((PValue) currentTransform.getOutput()); - } - - @SuppressWarnings("unchecked") - <I extends PInput> I getInput(PTransform<I, ?> transform) { - return (I) currentTransform.getInput(); - } - - @SuppressWarnings("unchecked") - <O extends POutput> O getOutput(PTransform<?, O> transform) { - return (O) currentTransform.getOutput(); - } -} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/edff0785/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderComparator.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderComparator.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderComparator.java new file mode 100644 index 0000000..e433589 --- /dev/null +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderComparator.java @@ -0,0 +1,216 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.dataartisans.flink.dataflow.translation.types; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.MemorySegment; + +import java.io.IOException; +import java.io.ObjectInputStream; + +/** + * Flink {@link org.apache.flink.api.common.typeutils.TypeComparator} for + * {@link com.google.cloud.dataflow.sdk.coders.Coder}. 
+ */ +public class CoderComparator<T> extends TypeComparator<T> { + + private Coder<T> coder; + + // We use these for internal encoding/decoding for creating copies and comparing + // serialized forms using a Coder + private transient InspectableByteArrayOutputStream buffer1; + private transient InspectableByteArrayOutputStream buffer2; + + // For storing the Reference in encoded form + private transient InspectableByteArrayOutputStream referenceBuffer; + + public CoderComparator(Coder<T> coder) { + this.coder = coder; + buffer1 = new InspectableByteArrayOutputStream(); + buffer2 = new InspectableByteArrayOutputStream(); + referenceBuffer = new InspectableByteArrayOutputStream(); + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + buffer1 = new InspectableByteArrayOutputStream(); + buffer2 = new InspectableByteArrayOutputStream(); + referenceBuffer = new InspectableByteArrayOutputStream(); + } + + @Override + public int hash(T record) { + return record.hashCode(); + } + + @Override + public void setReference(T toCompare) { + referenceBuffer.reset(); + try { + coder.encode(toCompare, referenceBuffer, Coder.Context.OUTER); + } catch (IOException e) { + throw new RuntimeException("Could not set reference " + toCompare + ": " + e); + } + } + + @Override + public boolean equalToReference(T candidate) { + try { + buffer2.reset(); + coder.encode(candidate, buffer2, Coder.Context.OUTER); + byte[] arr = referenceBuffer.getBuffer(); + byte[] arrOther = buffer2.getBuffer(); + if (referenceBuffer.size() != buffer2.size()) { + return false; + } + int len = buffer2.size(); + for(int i = 0; i < len; i++ ) { + if (arr[i] != arrOther[i]) { + return false; + } + } + return true; + } catch (IOException e) { + throw new RuntimeException("Could not compare reference.", e); + } + } + + @Override + public int compareToReference(TypeComparator<T> other) { + InspectableByteArrayOutputStream otherReferenceBuffer = ((CoderComparator<T>) other).referenceBuffer; + + byte[] arr = referenceBuffer.getBuffer(); + byte[] arrOther = otherReferenceBuffer.getBuffer(); + if (referenceBuffer.size() != otherReferenceBuffer.size()) { + return referenceBuffer.size() - otherReferenceBuffer.size(); + } + int len = referenceBuffer.size(); + for (int i = 0; i < len; i++) { + if (arr[i] != arrOther[i]) { + return arr[i] - arrOther[i]; + } + } + return 0; + } + + @Override + public int compare(T first, T second) { + try { + buffer1.reset(); + buffer2.reset(); + coder.encode(first, buffer1, Coder.Context.OUTER); + coder.encode(second, buffer2, Coder.Context.OUTER); + byte[] arr = buffer1.getBuffer(); + byte[] arrOther = buffer2.getBuffer(); + if (buffer1.size() != buffer2.size()) { + return buffer1.size() - buffer2.size(); + } + int len = buffer1.size(); + for(int i = 0; i < len; i++ ) { + if (arr[i] != arrOther[i]) { + return arr[i] - arrOther[i]; + } + } + return 0; + } catch (IOException e) { + throw new RuntimeException("Could not compare: ", e); + } + } + + @Override + public int compareSerialized(DataInputView firstSource, DataInputView secondSource) throws IOException { + CoderTypeSerializer<T> serializer = new CoderTypeSerializer<>(coder); + T first = serializer.deserialize(firstSource); + T second = serializer.deserialize(secondSource); + return compare(first, second); + } + + @Override + public boolean supportsNormalizedKey() { + return true; + } + + @Override + public boolean supportsSerializationWithKeyNormalization() { + return false; + } + + 
@Override + public int getNormalizeKeyLen() { + return Integer.MAX_VALUE; + } + + @Override + public boolean isNormalizedKeyPrefixOnly(int keyBytes) { + return true; + } + + @Override + public void putNormalizedKey(T record, MemorySegment target, int offset, int numBytes) { + buffer1.reset(); + try { + coder.encode(record, buffer1, Coder.Context.OUTER); + } catch (IOException e) { + throw new RuntimeException("Could not serialize " + record + " using coder " + coder + ": " + e); + } + final byte[] data = buffer1.getBuffer(); + final int limit = offset + numBytes; + + target.put(offset, data, 0, Math.min(numBytes, buffer1.size())); + + offset += buffer1.size(); + + while (offset < limit) { + target.put(offset++, (byte) 0); + } + } + + @Override + public void writeWithKeyNormalization(T record, DataOutputView target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public T readWithKeyDenormalization(T reuse, DataInputView source) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean invertNormalizedKey() { + return false; + } + + @Override + public TypeComparator<T> duplicate() { + return new CoderComparator<>(coder); + } + + @Override + public int extractKeys(Object record, Object[] target, int index) { + target[index] = record; + return 1; + } + + @Override + public TypeComparator[] getFlatComparators() { + return new TypeComparator[] { this.duplicate() }; + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/edff0785/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderComperator.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderComperator.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderComperator.java deleted file mode 100644 index ade826d..0000000 --- a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderComperator.java +++ /dev/null @@ -1,218 +0,0 @@ -/* - * Copyright 2015 Data Artisans GmbH - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.dataartisans.flink.dataflow.translation.types; - -import com.google.cloud.dataflow.sdk.coders.Coder; -import org.apache.flink.api.common.typeutils.TypeComparator; -import org.apache.flink.core.memory.DataInputView; -import org.apache.flink.core.memory.DataOutputView; -import org.apache.flink.core.memory.MemorySegment; - -import java.io.IOException; -import java.io.ObjectInputStream; - -/** - * Flink {@link org.apache.flink.api.common.typeutils.TypeComparator} for - * {@link com.google.cloud.dataflow.sdk.coders.Coder}. 
- */ -public class CoderComperator<T> extends TypeComparator<T> { - - private Coder<T> coder; - - // We use these for internal encoding/decoding for creating copies and comparing - // serialized forms using a Coder - private transient InspectableByteArrayOutputStream buffer1; - private transient InspectableByteArrayOutputStream buffer2; - - // For storing the Reference in encoded form - private transient InspectableByteArrayOutputStream referenceBuffer; - - public CoderComperator(Coder<T> coder) { - this.coder = coder; - buffer1 = new InspectableByteArrayOutputStream(); - buffer2 = new InspectableByteArrayOutputStream(); - referenceBuffer = new InspectableByteArrayOutputStream(); - } - - private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { - in.defaultReadObject(); - - buffer1 = new InspectableByteArrayOutputStream(); - buffer2 = new InspectableByteArrayOutputStream(); - referenceBuffer = new InspectableByteArrayOutputStream(); - - } - - @Override - public int hash(T record) { - return record.hashCode(); - } - - @Override - public void setReference(T toCompare) { - referenceBuffer.reset(); - try { - coder.encode(toCompare, referenceBuffer, Coder.Context.OUTER); - } catch (IOException e) { - throw new RuntimeException("Could not set reference " + toCompare + ": " + e); - } - } - - @Override - public boolean equalToReference(T candidate) { - try { - buffer2.reset(); - coder.encode(candidate, buffer2, Coder.Context.OUTER); - byte[] arr = referenceBuffer.getBuffer(); - byte[] arrOther = buffer2.getBuffer(); - if (referenceBuffer.size() != buffer2.size()) { - return false; - } - int len = buffer2.size(); - for(int i = 0; i < len; i++ ) { - if (arr[i] != arrOther[i]) { - return false; - } - } - return true; - } catch (IOException e) { - throw new RuntimeException("Could not compare reference.", e); - } - } - - @Override - public int compareToReference(TypeComparator<T> other) { - InspectableByteArrayOutputStream otherReferenceBuffer = ((CoderComperator<T>) other).referenceBuffer; - - byte[] arr = referenceBuffer.getBuffer(); - byte[] arrOther = otherReferenceBuffer.getBuffer(); - if (referenceBuffer.size() != otherReferenceBuffer.size()) { - return referenceBuffer.size() - otherReferenceBuffer.size(); - } - int len = referenceBuffer.size(); - for (int i = 0; i < len; i++) { - if (arr[i] != arrOther[i]) { - return arr[i] - arrOther[i]; - } - } - return 0; - } - - @Override - public int compare(T first, T second) { - try { - buffer1.reset(); - buffer2.reset(); - coder.encode(first, buffer1, Coder.Context.OUTER); - coder.encode(second, buffer2, Coder.Context.OUTER); - byte[] arr = buffer1.getBuffer(); - byte[] arrOther = buffer2.getBuffer(); - if (buffer1.size() != buffer2.size()) { - return buffer1.size() - buffer2.size(); - } - int len = buffer1.size(); - for(int i = 0; i < len; i++ ) { - if (arr[i] != arrOther[i]) { - return arr[i] - arrOther[i]; - } - } - return 0; - } catch (IOException e) { - throw new RuntimeException("Could not compare: ", e); - } - } - - @Override - public int compareSerialized(DataInputView firstSource, DataInputView secondSource) throws IOException { - CoderTypeSerializer<T> serializer = new CoderTypeSerializer<>(coder); - T first = serializer.deserialize(firstSource); - T second = serializer.deserialize(secondSource); - return compare(first, second); - } - - @Override - public boolean supportsNormalizedKey() { - return true; - } - - @Override - public boolean supportsSerializationWithKeyNormalization() { - return false; - } - - 
@Override - public int getNormalizeKeyLen() { - return Integer.MAX_VALUE; - } - - @Override - public boolean isNormalizedKeyPrefixOnly(int keyBytes) { - return true; - } - - @Override - public void putNormalizedKey(T record, MemorySegment target, int offset, int numBytes) { - buffer1.reset(); - try { - coder.encode(record, buffer1, Coder.Context.OUTER); - } catch (IOException e) { - throw new RuntimeException("Could not serializer " + record + " using coder " + coder + ": " + e); - } - final byte[] data = buffer1.getBuffer(); - final int limit = offset + numBytes; - - target.put(offset, data, 0, Math.min(numBytes, buffer1.size())); - - offset += buffer1.size(); - - while (offset < limit) { - target.put(offset++, (byte) 0); - } - } - - @Override - public void writeWithKeyNormalization(T record, DataOutputView target) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public T readWithKeyDenormalization(T reuse, DataInputView source) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public boolean invertNormalizedKey() { - return false; - } - - @Override - public TypeComparator<T> duplicate() { - return new CoderComperator<>(coder); - } - - @Override - public int extractKeys(Object record, Object[] target, int index) { - target[index] = record; - return 1; - } - - @Override - public TypeComparator[] getFlatComparators() { - return new TypeComparator[] { this.duplicate() }; - } -} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/edff0785/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderTypeInformation.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderTypeInformation.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderTypeInformation.java index 56192cd..80e451a 100644 --- a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderTypeInformation.java +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderTypeInformation.java @@ -32,12 +32,12 @@ import org.apache.flink.shaded.com.google.common.base.Preconditions; */ public class CoderTypeInformation<T> extends TypeInformation<T> implements AtomicType<T> { - private Coder<T> coder; + private final Coder<T> coder; @SuppressWarnings("unchecked") public CoderTypeInformation(Coder<T> coder) { - this.coder = coder; Preconditions.checkNotNull(coder); + this.coder = coder; } @Override @@ -112,6 +112,6 @@ public class CoderTypeInformation<T> extends TypeInformation<T> implements Atomi @Override public TypeComparator<T> createComparator(boolean sortOrderAscending, ExecutionConfig executionConfig) { - return new CoderComperator<>(coder); + return new CoderComparator<>(coder); } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/edff0785/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderTypeSerializer.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderTypeSerializer.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderTypeSerializer.java index 9715477..f739397 100644 --- a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderTypeSerializer.java +++ 
b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/CoderTypeSerializer.java @@ -137,9 +137,7 @@ public class CoderTypeSerializer<T> extends TypeSerializer<T> { if (o == null || getClass() != o.getClass()) return false; CoderTypeSerializer that = (CoderTypeSerializer) o; - return coder.equals(that.coder); - } @Override http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/edff0785/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/KvCoderComperator.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/KvCoderComperator.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/KvCoderComperator.java index 940dba6..815569d 100644 --- a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/KvCoderComperator.java +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/KvCoderComperator.java @@ -259,6 +259,6 @@ public class KvCoderComperator <K, V> extends TypeComparator<KV<K, V>> { @Override public TypeComparator[] getFlatComparators() { - return new TypeComparator[] {new CoderComperator<>(keyCoder)}; + return new TypeComparator[] {new CoderComparator<>(keyCoder)}; } } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/edff0785/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/VoidCoderTypeSerializer.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/VoidCoderTypeSerializer.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/VoidCoderTypeSerializer.java index 2096e27..7ce484a 100644 --- a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/VoidCoderTypeSerializer.java +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/types/VoidCoderTypeSerializer.java @@ -109,5 +109,4 @@ public class VoidCoderTypeSerializer extends TypeSerializer<VoidCoderTypeSeriali public static VoidValue INSTANCE = new VoidValue(); } - } http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/edff0785/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/SourceInputFormat.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/SourceInputFormat.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/SourceInputFormat.java index 8c9c59c..afb15da 100644 --- a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/SourceInputFormat.java +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/SourceInputFormat.java @@ -18,7 +18,6 @@ package com.dataartisans.flink.dataflow.translation.wrappers; import com.fasterxml.jackson.databind.ObjectMapper; -import com.google.api.client.util.Lists; import com.google.cloud.dataflow.sdk.coders.Coder; import com.google.cloud.dataflow.sdk.io.BoundedSource; import com.google.cloud.dataflow.sdk.io.Source; @@ -34,6 +33,7 @@ import org.slf4j.LoggerFactory; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.util.ArrayList; import java.util.List; /** @@ -116,7 +116,7 @@ public class SourceInputFormat<T> implements InputFormat<T, 
SourceInputSplit<T>> desiredSizeBytes = initialSource.getEstimatedSizeBytes(options) / numSplits; List<? extends Source<T>> shards = initialSource.splitIntoBundles(desiredSizeBytes, options); - List<SourceInputSplit<T>> splits = Lists.newArrayList(); + List<SourceInputSplit<T>> splits = new ArrayList<SourceInputSplit<T>>(); int splitCount = 0; for (Source<T> shard: shards) { splits.add(new SourceInputSplit<>(shard, splitCount++)); http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/edff0785/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java new file mode 100644 index 0000000..53bb177 --- /dev/null +++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/FlinkAbstractParDoWrapper.java @@ -0,0 +1,274 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.dataartisans.flink.dataflow.translation.wrappers.streaming; + +import com.dataartisans.flink.dataflow.translation.wrappers.SerializableFnAggregatorWrapper; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.repackaged.com.google.common.base.Preconditions; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.transforms.windowing.WindowFn; +import com.google.cloud.dataflow.sdk.util.*; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.base.Throwables; +import org.apache.flink.api.common.accumulators.Accumulator; +import org.apache.flink.api.common.accumulators.AccumulatorHelper; +import org.apache.flink.api.common.functions.RichFlatMapFunction; +import org.apache.flink.util.Collector; +import org.joda.time.Instant; +import org.joda.time.format.PeriodFormat; + +import java.util.Collection; + +public abstract class FlinkAbstractParDoWrapper<IN, OUTDF, OUTFL> extends RichFlatMapFunction<WindowedValue<IN>, WindowedValue<OUTFL>> { + + private final DoFn<IN, OUTDF> doFn; + private final WindowingStrategy<?, ?> windowingStrategy; + private transient PipelineOptions options; + + private DoFnProcessContext context; + + public FlinkAbstractParDoWrapper(PipelineOptions options, WindowingStrategy<?, ?> windowingStrategy, DoFn<IN, OUTDF> doFn) { + Preconditions.checkNotNull(options); + 
Preconditions.checkNotNull(windowingStrategy); + Preconditions.checkNotNull(doFn); + + this.doFn = doFn; + this.options = options; + this.windowingStrategy = windowingStrategy; + } + +// protected void writeObject(ObjectOutputStream out) +// throws IOException, ClassNotFoundException { +// out.defaultWriteObject(); +// ObjectMapper mapper = new ObjectMapper(); +// mapper.writeValue(out, options); +// } +// +// protected void readObject(ObjectInputStream in) +// throws IOException, ClassNotFoundException { +// in.defaultReadObject(); +// ObjectMapper mapper = new ObjectMapper(); +// options = mapper.readValue(in, PipelineOptions.class); +// } + + private void initContext(DoFn<IN, OUTDF> function, Collector<WindowedValue<OUTFL>> outCollector) { + if (this.context == null) { + this.context = new DoFnProcessContext(function, outCollector); + } + } + + @Override + public void flatMap(WindowedValue<IN> value, Collector<WindowedValue<OUTFL>> out) throws Exception { + this.initContext(doFn, out); + + // for each window the element belongs to, create a new copy here. + Collection<? extends BoundedWindow> windows = value.getWindows(); + if (windows.size() <= 1) { + processElement(value); + } else { + for (BoundedWindow window : windows) { + processElement(WindowedValue.of( + value.getValue(), value.getTimestamp(), window, value.getPane())); + } + } + } + + private void processElement(WindowedValue<IN> value) throws Exception { + this.context.setElement(value); + this.doFn.startBundle(context); + doFn.processElement(context); + this.doFn.finishBundle(context); + } + + private class DoFnProcessContext extends DoFn<IN, OUTDF>.ProcessContext { + + private final DoFn<IN, OUTDF> fn; + + protected final Collector<WindowedValue<OUTFL>> collector; + + private WindowedValue<IN> element; + + private DoFnProcessContext(DoFn<IN, OUTDF> function, Collector<WindowedValue<OUTFL>> outCollector) { + function.super(); + super.setupDelegateAggregators(); + + this.fn = function; + this.collector = outCollector; + } + + public void setElement(WindowedValue<IN> value) { + this.element = value; + } + + @Override + public IN element() { + return this.element.getValue(); + } + + @Override + public Instant timestamp() { + return this.element.getTimestamp(); + } + + @Override + public BoundedWindow window() { +// if (!(fn instanceof DoFn.RequiresWindowAccess)) { +// throw new UnsupportedOperationException( +// "window() is only available in the context of a DoFn marked as RequiresWindow."); +// } + + Collection<? extends BoundedWindow> windows = this.element.getWindows(); + if (windows.size() != 1) { + throw new IllegalArgumentException("Each element is expected to belong to 1 window. 
" + + "This belongs to " + windows.size() + "."); + } + return windows.iterator().next(); + } + + @Override + public PaneInfo pane() { + return this.element.getPane(); + } + + @Override + public WindowingInternals<IN, OUTDF> windowingInternals() { + return windowingInternalsHelper(element, collector); + } + + @Override + public PipelineOptions getPipelineOptions() { + return options; + } + + @Override + public <T> T sideInput(PCollectionView<T> view) { + throw new RuntimeException("sideInput() is not supported in Streaming mode."); + } + + @Override + public void output(OUTDF output) { + outputWithTimestamp(output, this.element.getTimestamp()); + } + + @Override + public void outputWithTimestamp(OUTDF output, Instant timestamp) { + outputWithTimestampHelper(element, output, timestamp, collector); + } + + @Override + public <T> void sideOutput(TupleTag<T> tag, T output) { + sideOutputWithTimestamp(tag, output, this.element.getTimestamp()); + } + + @Override + public <T> void sideOutputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) { + sideOutputWithTimestampHelper(element, output, timestamp, collector, tag); + } + + @Override + protected <AggInputT, AggOutputT> Aggregator<AggInputT, AggOutputT> createAggregatorInternal(String name, Combine.CombineFn<AggInputT, ?, AggOutputT> combiner) { + Accumulator acc = getRuntimeContext().getAccumulator(name); + if (acc != null) { + AccumulatorHelper.compareAccumulatorTypes(name, + SerializableFnAggregatorWrapper.class, acc.getClass()); + return (Aggregator<AggInputT, AggOutputT>) acc; + } + + SerializableFnAggregatorWrapper<AggInputT, AggOutputT> accumulator = + new SerializableFnAggregatorWrapper<>(combiner); + getRuntimeContext().addAccumulator(name, accumulator); + return accumulator; + } + } + + protected void checkTimestamp(WindowedValue<IN> ref, Instant timestamp) { + if (timestamp.isBefore(ref.getTimestamp().minus(doFn.getAllowedTimestampSkew()))) { + throw new IllegalArgumentException(String.format( + "Cannot output with timestamp %s. Output timestamps must be no earlier than the " + + "timestamp of the current input (%s) minus the allowed skew (%s). See the " + + "DoFn#getAllowedTimestmapSkew() Javadoc for details on changing the allowed skew.", + timestamp, ref.getTimestamp(), + PeriodFormat.getDefault().print(doFn.getAllowedTimestampSkew().toPeriod()))); + } + } + + protected <T> WindowedValue<T> makeWindowedValue( + T output, Instant timestamp, Collection<? extends BoundedWindow> windows, PaneInfo pane) { + final Instant inputTimestamp = timestamp; + final WindowFn windowFn = windowingStrategy.getWindowFn(); + + if (timestamp == null) { + timestamp = BoundedWindow.TIMESTAMP_MIN_VALUE; + } + + if (windows == null) { + try { + windows = windowFn.assignWindows(windowFn.new AssignContext() { + @Override + public Object element() { + throw new UnsupportedOperationException( + "WindowFn attempted to access input element when none was available"); // TODO: 12/16/15 aljoscha's comment in slack + } + + @Override + public Instant timestamp() { + if (inputTimestamp == null) { + throw new UnsupportedOperationException( + "WindowFn attempted to access input timestamp when none was available"); + } + return inputTimestamp; + } + + @Override + public Collection<? 
+          public Collection<? extends BoundedWindow> windows() {
+            throw new UnsupportedOperationException(
+                "WindowFn attempted to access input windows when none were available");
+          }
+        });
+      } catch (Exception e) {
+        Throwables.propagateIfInstanceOf(e, UserCodeException.class);
+        throw new UserCodeException(e);
+      }
+    }
+
+    return WindowedValue.of(output, timestamp, windows, pane);
+  }
+
+  /////////// ABSTRACT METHODS TO BE IMPLEMENTED BY SUBCLASSES /////////////////
+
+  public abstract void outputWithTimestampHelper(
+      WindowedValue<IN> inElement,
+      OUTDF output,
+      Instant timestamp,
+      Collector<WindowedValue<OUTFL>> outCollector);
+
+  public abstract <T> void sideOutputWithTimestampHelper(
+      WindowedValue<IN> inElement,
+      T output,
+      Instant timestamp,
+      Collector<WindowedValue<OUTFL>> outCollector,
+      TupleTag<T> tag);
+
+  public abstract WindowingInternals<IN, OUTDF> windowingInternalsHelper(
+      WindowedValue<IN> inElement,
+      Collector<WindowedValue<OUTFL>> outCollector);
+
+}
\ No newline at end of file
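[Editor's aside] The flatMap() in the wrapper above fans an element out into one copy per window before invoking the DoFn. A minimal free-standing sketch of that fan-out, assuming only the WindowedValue API used in this commit; explodeWindows is a hypothetical helper name, not part of the commit:

    import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow;
    import com.google.cloud.dataflow.sdk.util.WindowedValue;
    import java.util.ArrayList;
    import java.util.Collection;
    import java.util.List;

    // One copy of the element per window it belongs to; single-window
    // elements pass through untouched (the common, cheap case).
    static <T> List<WindowedValue<T>> explodeWindows(WindowedValue<T> value) {
      List<WindowedValue<T>> result = new ArrayList<>();
      Collection<? extends BoundedWindow> windows = value.getWindows();
      if (windows.size() <= 1) {
        result.add(value);
      } else {
        for (BoundedWindow window : windows) {
          result.add(WindowedValue.of(
              value.getValue(), value.getTimestamp(), window, value.getPane()));
        }
      }
      return result;
    }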
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/edff0785/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java
new file mode 100644
index 0000000..c52fabe
--- /dev/null
+++ b/runners/flink/src/main/java/com/dataartisans/flink/dataflow/translation/wrappers/streaming/FlinkGroupAlsoByWindowWrapper.java
@@ -0,0 +1,601 @@
+/*
+ * Copyright 2015 Data Artisans GmbH
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.dataartisans.flink.dataflow.translation.wrappers.streaming;
+
+import com.dataartisans.flink.dataflow.translation.types.CoderTypeInformation;
+import com.dataartisans.flink.dataflow.translation.wrappers.SerializableFnAggregatorWrapper;
+import com.dataartisans.flink.dataflow.translation.wrappers.streaming.state.*;
+import com.google.cloud.dataflow.sdk.coders.*;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.repackaged.com.google.common.base.Preconditions;
+import com.google.cloud.dataflow.sdk.transforms.Aggregator;
+import com.google.cloud.dataflow.sdk.transforms.Combine;
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow;
+import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo;
+import com.google.cloud.dataflow.sdk.util.*;
+import com.google.cloud.dataflow.sdk.values.*;
+import org.apache.flink.api.common.accumulators.Accumulator;
+import org.apache.flink.api.common.accumulators.AccumulatorHelper;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.runtime.state.StateBackend;
+import org.apache.flink.runtime.state.StateHandle;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.KeyedStream;
+import org.apache.flink.streaming.api.operators.*;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.StreamTaskState;
+import org.apache.flink.util.Collector;
+import org.joda.time.Instant;
+
+import java.io.IOException;
+import java.util.*;
+
+/**
+ * This class is the key class implementing all the windowing/triggering logic of Google Dataflow.
+ * To provide full compatibility and support all the windowing/triggering combinations offered by
+ * Dataflow, we opted for a strategy that uses the SDK's code for doing these operations
+ * ({@link com.google.cloud.dataflow.sdk.util.StreamingGroupAlsoByWindowsDoFn}).
+ * <p>
+ * In a nutshell, when the execution arrives at this operator, we expect to have a stream <b>already
+ * grouped by key</b>. Each element that enters here registers a timer
+ * (see {@link TimerInternals#setTimer(TimerInternals.TimerData)}) in
+ * {@link FlinkGroupAlsoByWindowWrapper#activeTimers}.
+ * This is essentially a timestamp indicating when to trigger the computation over the window this
+ * element belongs to.
+ * <p>
+ * When a watermark arrives, all the registered timers are checked to see which ones are ready to
+ * fire (see {@link FlinkGroupAlsoByWindowWrapper#processWatermark(Watermark)}). These are deregistered from
+ * the {@link FlinkGroupAlsoByWindowWrapper#activeTimers}
+ * list, and are fed into the {@link com.google.cloud.dataflow.sdk.util.StreamingGroupAlsoByWindowsDoFn}
+ * for further processing.
+ */
+public class FlinkGroupAlsoByWindowWrapper<K, VIN, VACC, VOUT>
+    extends AbstractStreamOperator<WindowedValue<KV<K, VOUT>>>
+    implements OneInputStreamOperator<WindowedValue<KV<K, VIN>>, WindowedValue<KV<K, VOUT>>> {
+
+  private static final long serialVersionUID = 1L;
+
+  private transient PipelineOptions options;
+
+  private transient CoderRegistry coderRegistry;
+
+  private StreamingGroupAlsoByWindowsDoFn operator;
+
+  private ProcessContext context;
+
+  private final WindowingStrategy<?, ?> windowingStrategy;
+
+  private final Combine.KeyedCombineFn<K, VIN, VACC, VOUT> combineFn;
+
+  private final KvCoder<K, VIN> inputKvCoder;
+
+  /**
+   * State is kept <b>per-key</b>. This data structure keeps the mapping between an active key, i.e. a
+   * key whose elements are currently waiting to be processed, and its associated state.
+   */
+  private Map<K, FlinkStateInternals<K>> perKeyStateInternals = new HashMap<>();
+
+  /**
+   * Timers waiting to be processed.
+   */
+  private Map<K, Set<TimerInternals.TimerData>> activeTimers = new HashMap<>();
+
+  private FlinkTimerInternals timerInternals = new FlinkTimerInternals();
+
+  /**
+   * Creates a DataStream where elements are grouped in windows based on the specified windowing strategy.
+   * This method assumes that <b>elements are already grouped by key</b>.
+   * <p>
+   * The difference with {@link #createForIterable(PipelineOptions, PCollection, KeyedStream)}
+   * is that this method assumes that a combiner function is provided
+   * (see {@link com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn}).
+   * A combiner helps increase speed and, in most cases, reduces the per-window state.
+   *
+   * @param options the general job configuration options.
+   * @param input the input Dataflow {@link com.google.cloud.dataflow.sdk.values.PCollection}.
+   * @param groupedStreamByKey the input stream; it is assumed to already be grouped by key.
+   * @param combiner the combiner to be used.
+   * @param outputKvCoder the coder for the output key/value pairs.
+   */
+  public static <K, VIN, VACC, VOUT> DataStream<WindowedValue<KV<K, VOUT>>> create(
+      PipelineOptions options,
+      PCollection input,
+      KeyedStream<WindowedValue<KV<K, VIN>>, K> groupedStreamByKey,
+      Combine.KeyedCombineFn<K, VIN, VACC, VOUT> combiner,
+      KvCoder<K, VOUT> outputKvCoder) {
+
+    KvCoder<K, VIN> inputKvCoder = (KvCoder<K, VIN>) input.getCoder();
+    FlinkGroupAlsoByWindowWrapper windower = new FlinkGroupAlsoByWindowWrapper<>(options,
+        input.getPipeline().getCoderRegistry(), input.getWindowingStrategy(), inputKvCoder, combiner);
+
+    Coder<WindowedValue<KV<K, VOUT>>> windowedOutputElemCoder = WindowedValue.FullWindowedValueCoder.of(
+        outputKvCoder,
+        input.getWindowingStrategy().getWindowFn().windowCoder());
+
+    CoderTypeInformation<WindowedValue<KV<K, VOUT>>> outputTypeInfo =
+        new CoderTypeInformation<>(windowedOutputElemCoder);
+
+    DataStream<WindowedValue<KV<K, VOUT>>> groupedByKeyAndWindow = groupedStreamByKey
+        .transform("GroupByWindowWithCombiner",
+            new CoderTypeInformation<>(outputKvCoder),
+            windower)
+        .returns(outputTypeInfo);
+
+    return groupedByKeyAndWindow;
+  }
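[Editor's aside] A usage sketch for the two factories (create(...) above and createForIterable(...) just below). Illustrative only: `options`, `input`, and `keyedByWord` are hypothetical names that would normally be wired up by the streaming translator, which is not part of this hunk:

    // Group an already-keyed stream by window, without a combiner:
    DataStream<WindowedValue<KV<String, Iterable<Long>>>> grouped =
        FlinkGroupAlsoByWindowWrapper.createForIterable(options, input, keyedByWord);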
+  /**
+   * Creates a DataStream where elements are grouped in windows based on the specified windowing strategy.
+   * This method assumes that <b>elements are already grouped by key</b>.
+   * <p>
+   * The difference with {@link #create(PipelineOptions, PCollection, KeyedStream, Combine.KeyedCombineFn, KvCoder)}
+   * is that this method assumes no combiner function
+   * (see {@link com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn}).
+   *
+   * @param options the general job configuration options.
+   * @param input the input Dataflow {@link com.google.cloud.dataflow.sdk.values.PCollection}.
+   * @param groupedStreamByKey the input stream; it is assumed to already be grouped by key.
+   */
+  public static <K, VIN> DataStream<WindowedValue<KV<K, Iterable<VIN>>>> createForIterable(
+      PipelineOptions options,
+      PCollection input,
+      KeyedStream<WindowedValue<KV<K, VIN>>, K> groupedStreamByKey) {
+
+    KvCoder<K, VIN> inputKvCoder = (KvCoder<K, VIN>) input.getCoder();
+    Coder<K> keyCoder = inputKvCoder.getKeyCoder();
+    Coder<VIN> inputValueCoder = inputKvCoder.getValueCoder();
+
+    FlinkGroupAlsoByWindowWrapper windower = new FlinkGroupAlsoByWindowWrapper(options,
+        input.getPipeline().getCoderRegistry(), input.getWindowingStrategy(), inputKvCoder, null);
+
+    Coder<Iterable<VIN>> valueIterCoder = IterableCoder.of(inputValueCoder);
+    KvCoder<K, Iterable<VIN>> outputElemCoder = KvCoder.of(keyCoder, valueIterCoder);
+
+    Coder<WindowedValue<KV<K, Iterable<VIN>>>> windowedOutputElemCoder = WindowedValue.FullWindowedValueCoder.of(
+        outputElemCoder,
+        input.getWindowingStrategy().getWindowFn().windowCoder());
+
+    CoderTypeInformation<WindowedValue<KV<K, Iterable<VIN>>>> outputTypeInfo =
+        new CoderTypeInformation<>(windowedOutputElemCoder);
+
+    DataStream<WindowedValue<KV<K, Iterable<VIN>>>> groupedByKeyAndWindow = groupedStreamByKey
+        .transform("GroupByWindow",
+            new CoderTypeInformation<>(windowedOutputElemCoder),
+            windower)
+        .returns(outputTypeInfo);
+
+    return groupedByKeyAndWindow;
+  }
+
+  public static <K, VIN, VACC, VOUT> FlinkGroupAlsoByWindowWrapper createForTesting(PipelineOptions options,
+      CoderRegistry registry,
+      WindowingStrategy<?, ?> windowingStrategy,
+      KvCoder<K, VIN> inputCoder,
+      Combine.KeyedCombineFn<K, VIN, VACC, VOUT> combiner) {
+    return new FlinkGroupAlsoByWindowWrapper(options, registry, windowingStrategy, inputCoder, combiner);
+  }
+
+  private FlinkGroupAlsoByWindowWrapper(PipelineOptions options,
+      CoderRegistry registry,
+      WindowingStrategy<?, ?> windowingStrategy,
+      KvCoder<K, VIN> inputCoder,
+      Combine.KeyedCombineFn<K, VIN, VACC, VOUT> combiner) {
+
+    this.options = Preconditions.checkNotNull(options);
+    this.coderRegistry = Preconditions.checkNotNull(registry);
+    this.inputKvCoder = Preconditions.checkNotNull(inputCoder); //(KvCoder<K, VIN>) input.getCoder();
+    this.combineFn = combiner;
+    this.windowingStrategy = Preconditions.checkNotNull(windowingStrategy); //input.getWindowingStrategy();
+    this.operator = createGroupAlsoByWindowOperator();
+    this.chainingStrategy = ChainingStrategy.ALWAYS;
+  }
+
+  @Override
+  public void open() throws Exception {
+    super.open();
+    this.context = new ProcessContext(operator, new TimestampedCollector<>(output), this.timerInternals);
+
+    // this covers the case where open() is called after a recovery:
+    // restoreState() has already initialized the timerInternals to a certain value.
+    TimerOrElement<WindowedValue<KV<K, VIN>>> element = this.timerInternals.getElement();
+    if (element != null) {
+      if (element.isTimer()) {
+        throw new RuntimeException("The recovered element cannot be a Timer.");
+      }
+      K key = element.element().getValue().getKey();
+      FlinkStateInternals<K> stateForKey = getStateInternalsForKey(key);
+      this.context.setElement(element, stateForKey);
+    }
+  }
+
+  /**
+   * Creates the appropriate {@link com.google.cloud.dataflow.sdk.util.StreamingGroupAlsoByWindowsDoFn},
+   * <b>if not already created</b>.
+   * If a {@link com.google.cloud.dataflow.sdk.transforms.Combine.KeyedCombineFn} was provided, then
+   * a function with that combiner is created, so that elements are combined as they arrive. This is
+   * done for speed and (in most cases) to reduce the per-window state.
+   */
+  private StreamingGroupAlsoByWindowsDoFn createGroupAlsoByWindowOperator() {
+    if (this.operator == null) {
+      if (this.combineFn == null) {
+        Coder<VIN> inputValueCoder = inputKvCoder.getValueCoder();
+
+        this.operator = StreamingGroupAlsoByWindowsDoFn.createForIterable(
+            this.windowingStrategy, inputValueCoder);
+      } else {
+
+        Coder<K> inputKeyCoder = inputKvCoder.getKeyCoder();
+        //CoderRegistry dataflowRegistry = input.getPipeline().getCoderRegistry();
+
+        AppliedCombineFn<K, VIN, VACC, VOUT> appliedCombineFn = AppliedCombineFn
+            .withInputCoder(combineFn, coderRegistry, inputKvCoder);
+
+        this.operator = StreamingGroupAlsoByWindowsDoFn.create(
+            this.windowingStrategy, appliedCombineFn, inputKeyCoder);
+      }
+    }
+    return this.operator;
+  }
+
+
+  @Override
+  public void processElement(StreamRecord<WindowedValue<KV<K, VIN>>> element) throws Exception {
+    WindowedValue<KV<K, VIN>> value = element.getValue();
+    TimerOrElement<WindowedValue<KV<K, VIN>>> elem = TimerOrElement.element(value);
+    processElementOrTimer(elem);
+  }
+
+  @Override
+  public void processWatermark(Watermark mark) throws Exception {
+
+    context.setCurrentWatermark(new Instant(mark.getTimestamp()));
+
+    Set<TimerOrElement> toFire = getTimersReadyToProcess(mark.getTimestamp());
+    if (!toFire.isEmpty()) {
+      for (TimerOrElement timer : toFire) {
+        processElementOrTimer(timer);
+      }
+    }
+
+    /*
+     * This is to take into account the different semantics of the Watermark in Flink and
+     * in Dataflow. To understand the reasoning behind the Dataflow semantics and its
+     * watermark holding logic, see the documentation of
+     * {@link WatermarkHold#addHold(ReduceFn.ProcessValueContext, boolean)}
+     */
+    long millis = Long.MAX_VALUE;
+    for (FlinkStateInternals state : perKeyStateInternals.values()) {
+      Instant watermarkHold = state.getWatermarkHold();
+      if (watermarkHold != null && watermarkHold.getMillis() < millis) {
+        millis = watermarkHold.getMillis();
+      }
+    }
+
+    if (mark.getTimestamp() < millis) {
+      millis = mark.getTimestamp();
+    }
+
+    // Don't forget to re-emit the watermark for further operators down the line.
+    // This is critical for jobs with multiple aggregation steps.
+    // Imagine a job with a groupByKey() on key K1, followed by a map() that changes
+    // the key K1 to K2, and another groupByKey() on K2. In this case, if the watermark
+    // is not re-emitted, the second aggregation would never be triggered, and no result
+    // will be produced.
+    output.emitWatermark(new Watermark(millis));
+  }
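[Editor's aside] The watermark that processWatermark() forwards is the minimum over all per-key watermark holds, capped by the incoming Flink watermark. A compact, free-standing restatement of that rule (illustrative only; `holds` stands in for the per-key FlinkStateInternals watermark holds):

    import org.joda.time.Instant;
    import java.util.Collection;

    static long watermarkToEmit(Collection<Instant> holds, long incomingWatermark) {
      long min = Long.MAX_VALUE;
      for (Instant hold : holds) {
        if (hold != null && hold.getMillis() < min) {
          min = hold.getMillis(); // earliest outstanding hold wins
        }
      }
      return Math.min(min, incomingWatermark); // never ahead of the incoming watermark
    }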
+  private void processElementOrTimer(TimerOrElement<WindowedValue<KV<K, VIN>>> timerOrElement) throws Exception {
+    K key = timerOrElement.isTimer() ?
+        (K) timerOrElement.key() :
+        timerOrElement.element().getValue().getKey();
+
+    context.setElement(timerOrElement, getStateInternalsForKey(key));
+
+    operator.startBundle(context);
+    operator.processElement(context);
+    operator.finishBundle(context);
+  }
+
+  private void registerActiveTimer(K key, TimerInternals.TimerData timer) {
+    Set<TimerInternals.TimerData> timersForKey = activeTimers.get(key);
+    if (timersForKey == null) {
+      timersForKey = new HashSet<>();
+    }
+    timersForKey.add(timer);
+    activeTimers.put(key, timersForKey);
+  }
+
+  private void unregisterActiveTimer(K key, TimerInternals.TimerData timer) {
+    Set<TimerInternals.TimerData> timersForKey = activeTimers.get(key);
+    if (timersForKey != null) {
+      timersForKey.remove(timer);
+      if (timersForKey.isEmpty()) {
+        activeTimers.remove(key);
+      } else {
+        activeTimers.put(key, timersForKey);
+      }
+    }
+  }
+
+  /**
+   * Returns the set of timers that are ready to fire. These are the timers
+   * that are registered to be triggered at a time before the current watermark.
+   * We keep these timers in a Set, so that they are deduplicated, as the same
+   * timer can be registered multiple times.
+   */
+  private Set<TimerOrElement> getTimersReadyToProcess(long currentWatermark) {
+
+    // we keep the timers to return in a different collection and launch them later,
+    // because we cannot prevent a trigger from registering another trigger,
+    // which would lead to a ConcurrentModificationException.
+    Set<TimerOrElement> toFire = new HashSet<>();
+
+    Iterator<Map.Entry<K, Set<TimerInternals.TimerData>>> it = activeTimers.entrySet().iterator();
+    while (it.hasNext()) {
+      Map.Entry<K, Set<TimerInternals.TimerData>> keyWithTimers = it.next();
+
+      Iterator<TimerInternals.TimerData> timerIt = keyWithTimers.getValue().iterator();
+      while (timerIt.hasNext()) {
+        TimerInternals.TimerData timerData = timerIt.next();
+        if (timerData.getTimestamp().isBefore(currentWatermark)) {
+          TimerOrElement timer = TimerOrElement.timer(keyWithTimers.getKey(), timerData);
+          toFire.add(timer);
+          timerIt.remove();
+        }
+      }
+
+      if (keyWithTimers.getValue().isEmpty()) {
+        it.remove();
+      }
+    }
+    return toFire;
+  }
+
+  /**
+   * Gets the state associated with the specified key.
+   *
+   * @param key the key whose state we want.
+   * @return The {@link FlinkStateInternals}
+   * associated with that key.
+   */
+  private FlinkStateInternals<K> getStateInternalsForKey(K key) {
+    FlinkStateInternals<K> stateInternals = perKeyStateInternals.get(key);
+    if (stateInternals == null) {
+      Coder<? extends BoundedWindow> windowCoder = this.windowingStrategy.getWindowFn().windowCoder();
+      stateInternals = new FlinkStateInternals<>(key, inputKvCoder.getKeyCoder(), windowCoder, combineFn);
+      perKeyStateInternals.put(key, stateInternals);
+    }
+    return stateInternals;
+  }
+
+  private class FlinkTimerInternals extends AbstractFlinkTimerInternals<K, VIN> {
+
+    @Override
+    protected void registerTimer(K key, TimerData timerKey) {
+      registerActiveTimer(key, timerKey);
+    }
+
+    @Override
+    protected void unregisterTimer(K key, TimerData timerKey) {
+      unregisterActiveTimer(key, timerKey);
+    }
+  }
+
+  private class ProcessContext extends DoFn<TimerOrElement<WindowedValue<KV<K, VIN>>>, KV<K, VOUT>>.ProcessContext {
+
+    private final FlinkTimerInternals timerInternals;
+
+    private final DoFn<TimerOrElement<WindowedValue<KV<K, VIN>>>, KV<K, VOUT>> fn;
+
+    private final Collector<WindowedValue<KV<K, VOUT>>> collector;
+
+    private FlinkStateInternals<K> stateInternals;
+
+    private TimerOrElement<WindowedValue<KV<K, VIN>>> element;
+
+    public ProcessContext(DoFn<TimerOrElement<WindowedValue<KV<K, VIN>>>, KV<K, VOUT>> function,
+        Collector<WindowedValue<KV<K, VOUT>>> outCollector,
+        FlinkTimerInternals timerInternals) {
+      function.super();
+      super.setupDelegateAggregators();
+
+      this.fn = Preconditions.checkNotNull(function);
+      this.collector = Preconditions.checkNotNull(outCollector);
+      this.timerInternals = Preconditions.checkNotNull(timerInternals);
+    }
+
+    public void setElement(TimerOrElement<WindowedValue<KV<K, VIN>>> value,
+        FlinkStateInternals<K> stateForKey) {
+      this.element = value;
+      this.stateInternals = stateForKey;
+      this.timerInternals.setElement(value);
+    }
+
+    public void setCurrentWatermark(Instant watermark) {
+      this.timerInternals.setCurrentWatermark(watermark);
+    }
+
+    @Override
+    public TimerOrElement element() {
+      if (element != null && !this.element.isTimer()) {
+        return TimerOrElement.element(this.element.element().getValue());
+      }
+      return this.element;
+    }
+
+    @Override
+    public Instant timestamp() {
+      return this.element.isTimer() ?
+          this.element.getTimer().getTimestamp() :
+          this.element.element().getTimestamp();
+    }
+
+    @Override
+    public PipelineOptions getPipelineOptions() {
+      return options;
+    }
+
+    @Override
+    public void output(KV<K, VOUT> output) {
+      throw new UnsupportedOperationException(
+          "output() is not available when grouping by window.");
+    }
+
+    @Override
+    public void outputWithTimestamp(KV<K, VOUT> output, Instant timestamp) {
+      throw new UnsupportedOperationException(
+          "outputWithTimestamp() is not available when grouping by window.");
+    }
+
+    @Override
+    public PaneInfo pane() {
+      return this.element.element().getPane();
+    }
+
+    @Override
+    public BoundedWindow window() {
+      if (!(fn instanceof DoFn.RequiresWindowAccess)) {
+        throw new UnsupportedOperationException(
+            "window() is only available in the context of a DoFn marked as RequiresWindowAccess.");
+      }
+
+      Collection<? extends BoundedWindow> windows = this.element.element().getWindows();
+      if (windows.size() != 1) {
+        throw new IllegalArgumentException("Each element is expected to belong to 1 window. "
" + + "This belongs to " + windows.size() + "."); + } + return windows.iterator().next(); + } + + @Override + public WindowingInternals<TimerOrElement<WindowedValue<KV<K, VIN>>>, KV<K, VOUT>> windowingInternals() { + return new WindowingInternals<TimerOrElement<WindowedValue<KV<K, VIN>>>, KV<K, VOUT>>() { + + @Override + public com.google.cloud.dataflow.sdk.util.state.StateInternals stateInternals() { + return stateInternals; + } + + @Override + public void outputWindowedValue(KV<K, VOUT> output, Instant timestamp, Collection<? extends BoundedWindow> windows, PaneInfo pane) { + collector.collect(WindowedValue.of(output, timestamp, windows, pane)); + } + + @Override + public TimerInternals timerInternals() { + return timerInternals; + } + + @Override + public Collection<? extends BoundedWindow> windows() { + return element.element().getWindows(); + } + + @Override + public PaneInfo pane() { + return element.element().getPane(); + } + + @Override + public <T> void writePCollectionViewData(TupleTag<?> tag, Iterable<WindowedValue<T>> data, Coder<T> elemCoder) throws IOException { + throw new RuntimeException("writePCollectionViewData() not supported in Streaming mode."); + } + }; + } + + @Override + public <T> T sideInput(PCollectionView<T> view) { + throw new RuntimeException("sideInput() is not supported in Streaming mode."); + } + + @Override + public <T> void sideOutput(TupleTag<T> tag, T output) { + // ignore the side output, this can happen when a user does not register + // side outputs but then outputs using a freshly created TupleTag. + throw new RuntimeException("sideOutput() is not available when grouping by window."); + } + + @Override + public <T> void sideOutputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) { + sideOutput(tag, output); + } + + @Override + protected <AggInputT, AggOutputT> Aggregator<AggInputT, AggOutputT> createAggregatorInternal(String name, Combine.CombineFn<AggInputT, ?, AggOutputT> combiner) { + Accumulator acc = getRuntimeContext().getAccumulator(name); + if (acc != null) { + AccumulatorHelper.compareAccumulatorTypes(name, + SerializableFnAggregatorWrapper.class, acc.getClass()); + return (Aggregator<AggInputT, AggOutputT>) acc; + } + + SerializableFnAggregatorWrapper<AggInputT, AggOutputT> accumulator = + new SerializableFnAggregatorWrapper<>(combiner); + getRuntimeContext().addAccumulator(name, accumulator); + return accumulator; + } + } + + ////////////// Checkpointing implementation //////////////// + + @Override + public StreamTaskState snapshotOperatorState(long checkpointId, long timestamp) throws Exception { + StreamTaskState taskState = super.snapshotOperatorState(checkpointId, timestamp); + StateBackend.CheckpointStateOutputView out = getStateBackend().createCheckpointStateOutputView(checkpointId, timestamp); + StateCheckpointWriter writer = StateCheckpointWriter.create(out); + Coder<K> keyCoder = inputKvCoder.getKeyCoder(); + + // checkpoint the timers + StateCheckpointUtils.encodeTimers(activeTimers, writer, keyCoder); + + // checkpoint the state + StateCheckpointUtils.encodeState(perKeyStateInternals, writer, keyCoder); + + // checkpoint the timerInternals + context.timerInternals.encodeTimerInternals(context, writer, + inputKvCoder, windowingStrategy.getWindowFn().windowCoder()); + + taskState.setOperatorState(out.closeAndGetHandle()); + return taskState; + } + + @Override + public void restoreState(StreamTaskState taskState) throws Exception { + super.restoreState(taskState); + + final ClassLoader userClassloader = 
+    final ClassLoader userClassloader = getUserCodeClassloader();
+
+    Coder<? extends BoundedWindow> windowCoder = this.windowingStrategy.getWindowFn().windowCoder();
+    Coder<K> keyCoder = inputKvCoder.getKeyCoder();
+
+    @SuppressWarnings("unchecked")
+    StateHandle<DataInputView> inputState = (StateHandle<DataInputView>) taskState.getOperatorState();
+    DataInputView in = inputState.getState(userClassloader);
+    StateCheckpointReader reader = new StateCheckpointReader(in);
+
+    // restore the timers
+    this.activeTimers = StateCheckpointUtils.decodeTimers(reader, windowCoder, keyCoder);
+
+    // restore the state
+    this.perKeyStateInternals = StateCheckpointUtils.decodeState(
+        reader, combineFn, keyCoder, windowCoder, userClassloader);
+
+    // restore the timerInternals.
+    this.timerInternals.restoreTimerInternals(reader, inputKvCoder, windowCoder);
+  }
+}
\ No newline at end of file
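[Editor's aside] The `activeTimers` bookkeeping used throughout the operator above boils down to one pattern: a per-key Set of timers (so duplicate registrations collapse), drained of everything strictly before the watermark whenever a watermark arrives. A self-contained sketch of just that pattern, with String keys and long timestamps standing in for K and TimerInternals.TimerData (the real code also carries the key alongside each fired timer via TimerOrElement.timer(key, data)):

    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.Iterator;
    import java.util.Map;
    import java.util.Set;

    class ActiveTimers {
      private final Map<String, Set<Long>> timers = new HashMap<>();

      void register(String key, long timestamp) {
        Set<Long> forKey = timers.get(key);
        if (forKey == null) {
          forKey = new HashSet<>();
          timers.put(key, forKey);
        }
        forKey.add(timestamp); // Set semantics deduplicate re-registrations
      }

      // Mirrors getTimersReadyToProcess(): collect first, fire later, so a
      // firing timer that registers a new timer cannot cause a
      // ConcurrentModificationException.
      Set<Long> drainReady(long watermark) {
        Set<Long> toFire = new HashSet<>();
        Iterator<Map.Entry<String, Set<Long>>> it = timers.entrySet().iterator();
        while (it.hasNext()) {
          Map.Entry<String, Set<Long>> entry = it.next();
          Iterator<Long> ts = entry.getValue().iterator();
          while (ts.hasNext()) {
            long t = ts.next();
            if (t < watermark) {
              toFire.add(t);
              ts.remove();
            }
          }
          if (entry.getValue().isEmpty()) {
            it.remove(); // no pending timers left for this key
          }
        }
        return toFire;
      }
    }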