http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java new file mode 100644 index 0000000..bd8a968 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputDoFnFunction.java @@ -0,0 +1,175 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import org.apache.beam.runners.flink.translation.wrappers.SerializableFnAggregatorWrapper; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.dataflow.sdk.options.PipelineOptions; +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.transforms.DoFn; +import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue; +import com.google.cloud.dataflow.sdk.transforms.windowing.BoundedWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindow; +import com.google.cloud.dataflow.sdk.transforms.windowing.PaneInfo; +import com.google.cloud.dataflow.sdk.util.WindowedValue; +import com.google.cloud.dataflow.sdk.util.WindowingInternals; +import com.google.cloud.dataflow.sdk.values.PCollectionView; +import com.google.cloud.dataflow.sdk.values.TupleTag; +import com.google.common.collect.ImmutableList; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.util.Collector; +import org.joda.time.Instant; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * Encapsulates a {@link com.google.cloud.dataflow.sdk.transforms.DoFn} that uses side outputs + * inside a Flink {@link org.apache.flink.api.common.functions.RichMapPartitionFunction}. + * + * We get a mapping from {@link com.google.cloud.dataflow.sdk.values.TupleTag} to output index + * and tag every emitted element with its output index. A downstream filter then prunes + * the elements that do not belong to a given output.
+ */ +public class FlinkMultiOutputDoFnFunction<IN, OUT> extends RichMapPartitionFunction<IN, RawUnionValue> { + + private final DoFn<IN, OUT> doFn; + private transient PipelineOptions options; + private final Map<TupleTag<?>, Integer> outputMap; + + public FlinkMultiOutputDoFnFunction(DoFn<IN, OUT> doFn, PipelineOptions options, Map<TupleTag<?>, Integer> outputMap) { + this.doFn = doFn; + this.options = options; + this.outputMap = outputMap; + } + + private void writeObject(ObjectOutputStream out) + throws IOException, ClassNotFoundException { + out.defaultWriteObject(); + ObjectMapper mapper = new ObjectMapper(); + mapper.writeValue(out, options); + } + + private void readObject(ObjectInputStream in) + throws IOException, ClassNotFoundException { + in.defaultReadObject(); + ObjectMapper mapper = new ObjectMapper(); + options = mapper.readValue(in, PipelineOptions.class); + + } + + @Override + public void mapPartition(Iterable<IN> values, Collector<RawUnionValue> out) throws Exception { + ProcessContext context = new ProcessContext(doFn, out); + this.doFn.startBundle(context); + for (IN value : values) { + context.inValue = value; + doFn.processElement(context); + } + this.doFn.finishBundle(context); + } + + private class ProcessContext extends DoFn<IN, OUT>.ProcessContext { + + IN inValue; + Collector<RawUnionValue> outCollector; + + public ProcessContext(DoFn<IN, OUT> fn, Collector<RawUnionValue> outCollector) { + fn.super(); + this.outCollector = outCollector; + } + + @Override + public IN element() { + return this.inValue; + } + + @Override + public Instant timestamp() { + return Instant.now(); + } + + @Override + public BoundedWindow window() { + return GlobalWindow.INSTANCE; + } + + @Override + public PaneInfo pane() { + return PaneInfo.NO_FIRING; + } + + @Override + public WindowingInternals<IN, OUT> windowingInternals() { + return null; + } + + @Override + public PipelineOptions getPipelineOptions() { + return options; + } + + @Override + public <T> T sideInput(PCollectionView<T> view) { + List<T> sideInput = getRuntimeContext().getBroadcastVariable(view.getTagInternal() + .getId()); + List<WindowedValue<?>> windowedValueList = new ArrayList<>(sideInput.size()); + for (T input : sideInput) { + windowedValueList.add(WindowedValue.of(input, Instant.now(), ImmutableList.of(GlobalWindow.INSTANCE), pane())); + } + return view.fromIterableInternal(windowedValueList); + } + + @Override + public void output(OUT value) { + // assume that index 0 is the default output + outCollector.collect(new RawUnionValue(0, value)); + } + + @Override + public void outputWithTimestamp(OUT output, Instant timestamp) { + // not Flink's way, just output normally + output(output); + } + + @Override + @SuppressWarnings("unchecked") + public <T> void sideOutput(TupleTag<T> tag, T value) { + Integer index = outputMap.get(tag); + if (index != null) { + outCollector.collect(new RawUnionValue(index, value)); + } + } + + @Override + public <T> void sideOutputWithTimestamp(TupleTag<T> tag, T output, Instant timestamp) { + sideOutput(tag, output); + } + + @Override + protected <AggInputT, AggOutputT> Aggregator<AggInputT, AggOutputT> createAggregatorInternal(String name, Combine.CombineFn<AggInputT, ?, AggOutputT> combiner) { + SerializableFnAggregatorWrapper<AggInputT, AggOutputT> wrapper = new SerializableFnAggregatorWrapper<>(combiner); + getRuntimeContext().addAccumulator(name, wrapper); + return null; + } + + } +}
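For illustration, this function composes with the FlinkMultiOutputPruningFunction from the next file as follows: the DoFn runs once and tags every element with its union index, then one pruning filter per output recovers the individual streams. A minimal wiring sketch (hypothetical: the input dataset, DoFn, tags, and the TypeInformation/.returns(...) plumbing that a real translator supplies are assumed here, not part of this commit):

    Map<TupleTag<?>, Integer> outputMap = ImmutableMap.of(mainTag, 0, sideTag, 1);

    // Run the DoFn once; every emitted element carries its union index.
    DataSet<RawUnionValue> taggedOutputs = input
        .mapPartition(new FlinkMultiOutputDoFnFunction<>(doFn, options, outputMap));

    // Index 0 is assumed to be the main output; each output gets its own pruning filter.
    DataSet<String> mainOutput = taggedOutputs
        .flatMap(new FlinkMultiOutputPruningFunction<String>(0));
    DataSet<Integer> sideOutput = taggedOutputs
        .flatMap(new FlinkMultiOutputPruningFunction<Integer>(1));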
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java new file mode 100644 index 0000000..3e1cb65 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkMultiOutputPruningFunction.java @@ -0,0 +1,41 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.util.Collector; + +/** + * A FlatMap function that filters out those elements that don't belong in this output. We need + * this to implement MultiOutput ParDo functions. + */ +public class FlinkMultiOutputPruningFunction<T> implements FlatMapFunction<RawUnionValue, T> { + + private final int outputTag; + + public FlinkMultiOutputPruningFunction(int outputTag) { + this.outputTag = outputTag; + } + + @Override + @SuppressWarnings("unchecked") + public void flatMap(RawUnionValue rawUnionValue, Collector<T> collector) throws Exception { + if (rawUnionValue.getUnionTag() == outputTag) { + collector.collect((T) rawUnionValue.getValue()); + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java new file mode 100644 index 0000000..1ff06ba --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkPartialReduceFunction.java @@ -0,0 +1,60 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.functions; + +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.values.KV; +import org.apache.flink.api.common.functions.GroupCombineFunction; +import org.apache.flink.util.Collector; + +import java.util.Iterator; + +/** + * Flink {@link org.apache.flink.api.common.functions.GroupCombineFunction} for executing a + * {@link com.google.cloud.dataflow.sdk.transforms.Combine.PerKey} operation. This reads the input + * {@link com.google.cloud.dataflow.sdk.values.KV} elements with values of type VI, extracts the key + * and emits accumulated values of the intermediate type VA. + */ +public class FlinkPartialReduceFunction<K, VI, VA> implements GroupCombineFunction<KV<K, VI>, KV<K, VA>> { + + private final Combine.KeyedCombineFn<K, VI, VA, ?> keyedCombineFn; + + public FlinkPartialReduceFunction(Combine.KeyedCombineFn<K, VI, VA, ?> + keyedCombineFn) { + this.keyedCombineFn = keyedCombineFn; + } + + @Override + public void combine(Iterable<KV<K, VI>> elements, Collector<KV<K, VA>> out) throws Exception { + + final Iterator<KV<K, VI>> iterator = elements.iterator(); + // create accumulator using the first element's key + KV<K, VI> first = iterator.next(); + K key = first.getKey(); + VI value = first.getValue(); + VA accumulator = keyedCombineFn.createAccumulator(key); + accumulator = keyedCombineFn.addInput(key, accumulator, value); + + while (iterator.hasNext()) { + value = iterator.next().getValue(); + accumulator = keyedCombineFn.addInput(key, accumulator, value); + } + + out.collect(KV.of(key, accumulator)); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java new file mode 100644 index 0000000..580ac01 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkReduceFunction.java @@ -0,0 +1,57 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.functions; + +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.cloud.dataflow.sdk.values.KV; +import com.google.common.collect.ImmutableList; +import org.apache.flink.api.common.functions.GroupReduceFunction; +import org.apache.flink.util.Collector; + +import java.util.Iterator; + +/** + * Flink {@link org.apache.flink.api.common.functions.GroupReduceFunction} for executing a + * {@link com.google.cloud.dataflow.sdk.transforms.Combine.PerKey} operation.
This reads the input + * {@link com.google.cloud.dataflow.sdk.values.KV} elements, extracts the key, merges the + * accumulators of type VA produced by the partial reduce, and emits the extracted output. + */ +public class FlinkReduceFunction<K, VA, VO> implements GroupReduceFunction<KV<K, VA>, KV<K, VO>> { + + private final Combine.KeyedCombineFn<K, ?, VA, VO> keyedCombineFn; + + public FlinkReduceFunction(Combine.KeyedCombineFn<K, ?, VA, VO> keyedCombineFn) { + this.keyedCombineFn = keyedCombineFn; + } + + @Override + public void reduce(Iterable<KV<K, VA>> values, Collector<KV<K, VO>> out) throws Exception { + Iterator<KV<K, VA>> it = values.iterator(); + + KV<K, VA> current = it.next(); + K k = current.getKey(); + VA accumulator = current.getValue(); + + while (it.hasNext()) { + current = it.next(); + accumulator = keyedCombineFn.mergeAccumulators(k, ImmutableList.of(accumulator, current.getValue())); + } + + out.collect(KV.of(k, keyedCombineFn.extractOutput(k, accumulator))); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/UnionCoder.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/UnionCoder.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/UnionCoder.java new file mode 100644 index 0000000..05f4415 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/UnionCoder.java @@ -0,0 +1,150 @@ +/* + * Copyright (C) 2015 Google Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ + +package org.apache.beam.runners.flink.translation.functions; + + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.StandardCoder; +import com.google.cloud.dataflow.sdk.transforms.join.RawUnionValue; +import com.google.cloud.dataflow.sdk.util.PropertyNames; +import com.google.cloud.dataflow.sdk.util.VarInt; +import com.google.cloud.dataflow.sdk.util.common.ElementByteSizeObserver; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.List; + +/** + * A UnionCoder encodes RawUnionValues. + * + * This file was copied from {@link com.google.cloud.dataflow.sdk.transforms.join.UnionCoder} + */ +@SuppressWarnings("serial") +public class UnionCoder extends StandardCoder<RawUnionValue> { + // TODO: Think about how to integrate this with a schema object (i.e. + // a tuple of tuple tags). + /** + * Builds a union coder with the given list of element coders. This list + * corresponds to a mapping of union tag to Coder. Union tags start at 0.
+ */ + public static UnionCoder of(List<Coder<?>> elementCoders) { + return new UnionCoder(elementCoders); + } + + @JsonCreator + public static UnionCoder jsonOf( + @JsonProperty(PropertyNames.COMPONENT_ENCODINGS) + List<Coder<?>> elements) { + return UnionCoder.of(elements); + } + + private int getIndexForEncoding(RawUnionValue union) { + if (union == null) { + throw new IllegalArgumentException("cannot encode a null tagged union"); + } + int index = union.getUnionTag(); + if (index < 0 || index >= elementCoders.size()) { + throw new IllegalArgumentException( + "union value index " + index + " not in range [0.." + + (elementCoders.size() - 1) + "]"); + } + return index; + } + + @SuppressWarnings("unchecked") + @Override + public void encode( + RawUnionValue union, + OutputStream outStream, + Context context) + throws IOException { + int index = getIndexForEncoding(union); + // Write out the union tag. + VarInt.encode(index, outStream); + + // Write out the actual value. + Coder<Object> coder = (Coder<Object>) elementCoders.get(index); + coder.encode( + union.getValue(), + outStream, + context); + } + + @Override + public RawUnionValue decode(InputStream inStream, Context context) + throws IOException { + int index = VarInt.decodeInt(inStream); + Object value = elementCoders.get(index).decode(inStream, context); + return new RawUnionValue(index, value); + } + + @Override + public List<? extends Coder<?>> getCoderArguments() { + return null; + } + + @Override + public List<? extends Coder<?>> getComponents() { + return elementCoders; + } + + /** + * Since this coder uses elementCoders.get(index) and coders that are known to run in constant + * time, we defer the return value to that coder. + */ + @Override + public boolean isRegisterByteSizeObserverCheap(RawUnionValue union, Context context) { + int index = getIndexForEncoding(union); + @SuppressWarnings("unchecked") + Coder<Object> coder = (Coder<Object>) elementCoders.get(index); + return coder.isRegisterByteSizeObserverCheap(union.getValue(), context); + } + + /** + * Notifies ElementByteSizeObserver about the byte size of the encoded value using this coder. + */ + @Override + public void registerByteSizeObserver( + RawUnionValue union, ElementByteSizeObserver observer, Context context) + throws Exception { + int index = getIndexForEncoding(union); + // Write out the union tag. + observer.update(VarInt.getLength(index)); + // Write out the actual value. 
+ @SuppressWarnings("unchecked") + Coder<Object> coder = (Coder<Object>) elementCoders.get(index); + coder.registerByteSizeObserver(union.getValue(), observer, context); + } + + ///////////////////////////////////////////////////////////////////////////// + + private final List<Coder<?>> elementCoders; + + private UnionCoder(List<Coder<?>> elementCoders) { + this.elementCoders = elementCoders; + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic( + "UnionCoder is only deterministic if all element coders are", + elementCoders); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderComparator.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderComparator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderComparator.java new file mode 100644 index 0000000..ecfb95d --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderComparator.java @@ -0,0 +1,216 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.types; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.MemorySegment; + +import java.io.IOException; +import java.io.ObjectInputStream; + +/** + * Flink {@link org.apache.flink.api.common.typeutils.TypeComparator} for + * {@link com.google.cloud.dataflow.sdk.coders.Coder}. 
+ */ +public class CoderComparator<T> extends TypeComparator<T> { + + private Coder<T> coder; + + // We use these for internal encoding/decoding for creating copies and comparing + // serialized forms using a Coder + private transient InspectableByteArrayOutputStream buffer1; + private transient InspectableByteArrayOutputStream buffer2; + + // For storing the Reference in encoded form + private transient InspectableByteArrayOutputStream referenceBuffer; + + public CoderComparator(Coder<T> coder) { + this.coder = coder; + buffer1 = new InspectableByteArrayOutputStream(); + buffer2 = new InspectableByteArrayOutputStream(); + referenceBuffer = new InspectableByteArrayOutputStream(); + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + buffer1 = new InspectableByteArrayOutputStream(); + buffer2 = new InspectableByteArrayOutputStream(); + referenceBuffer = new InspectableByteArrayOutputStream(); + } + + @Override + public int hash(T record) { + return record.hashCode(); + } + + @Override + public void setReference(T toCompare) { + referenceBuffer.reset(); + try { + coder.encode(toCompare, referenceBuffer, Coder.Context.OUTER); + } catch (IOException e) { + throw new RuntimeException("Could not set reference " + toCompare + ": " + e); + } + } + + @Override + public boolean equalToReference(T candidate) { + try { + buffer2.reset(); + coder.encode(candidate, buffer2, Coder.Context.OUTER); + byte[] arr = referenceBuffer.getBuffer(); + byte[] arrOther = buffer2.getBuffer(); + if (referenceBuffer.size() != buffer2.size()) { + return false; + } + int len = buffer2.size(); + for(int i = 0; i < len; i++ ) { + if (arr[i] != arrOther[i]) { + return false; + } + } + return true; + } catch (IOException e) { + throw new RuntimeException("Could not compare reference.", e); + } + } + + @Override + public int compareToReference(TypeComparator<T> other) { + InspectableByteArrayOutputStream otherReferenceBuffer = ((CoderComparator<T>) other).referenceBuffer; + + byte[] arr = referenceBuffer.getBuffer(); + byte[] arrOther = otherReferenceBuffer.getBuffer(); + if (referenceBuffer.size() != otherReferenceBuffer.size()) { + return referenceBuffer.size() - otherReferenceBuffer.size(); + } + int len = referenceBuffer.size(); + for (int i = 0; i < len; i++) { + if (arr[i] != arrOther[i]) { + return arr[i] - arrOther[i]; + } + } + return 0; + } + + @Override + public int compare(T first, T second) { + try { + buffer1.reset(); + buffer2.reset(); + coder.encode(first, buffer1, Coder.Context.OUTER); + coder.encode(second, buffer2, Coder.Context.OUTER); + byte[] arr = buffer1.getBuffer(); + byte[] arrOther = buffer2.getBuffer(); + if (buffer1.size() != buffer2.size()) { + return buffer1.size() - buffer2.size(); + } + int len = buffer1.size(); + for(int i = 0; i < len; i++ ) { + if (arr[i] != arrOther[i]) { + return arr[i] - arrOther[i]; + } + } + return 0; + } catch (IOException e) { + throw new RuntimeException("Could not compare: ", e); + } + } + + @Override + public int compareSerialized(DataInputView firstSource, DataInputView secondSource) throws IOException { + CoderTypeSerializer<T> serializer = new CoderTypeSerializer<>(coder); + T first = serializer.deserialize(firstSource); + T second = serializer.deserialize(secondSource); + return compare(first, second); + } + + @Override + public boolean supportsNormalizedKey() { + return true; + } + + @Override + public boolean supportsSerializationWithKeyNormalization() { + return false; + } + + 
@Override + public int getNormalizeKeyLen() { + return Integer.MAX_VALUE; + } + + @Override + public boolean isNormalizedKeyPrefixOnly(int keyBytes) { + return true; + } + + @Override + public void putNormalizedKey(T record, MemorySegment target, int offset, int numBytes) { + buffer1.reset(); + try { + coder.encode(record, buffer1, Coder.Context.OUTER); + } catch (IOException e) { + throw new RuntimeException("Could not serialize " + record + " using coder " + coder + ": " + e); + } + final byte[] data = buffer1.getBuffer(); + final int limit = offset + numBytes; + + target.put(offset, data, 0, Math.min(numBytes, buffer1.size())); + + offset += buffer1.size(); + + while (offset < limit) { + target.put(offset++, (byte) 0); + } + } + + @Override + public void writeWithKeyNormalization(T record, DataOutputView target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public T readWithKeyDenormalization(T reuse, DataInputView source) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean invertNormalizedKey() { + return false; + } + + @Override + public TypeComparator<T> duplicate() { + return new CoderComparator<>(coder); + } + + @Override + public int extractKeys(Object record, Object[] target, int index) { + target[index] = record; + return 1; + } + + @Override + public TypeComparator[] getFlatComparators() { + return new TypeComparator[] { this.duplicate() }; + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java new file mode 100644 index 0000000..8880b48 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeInformation.java @@ -0,0 +1,116 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.types; + +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.VoidCoder; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.AtomicType; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import com.google.common.base.Preconditions; + +/** + * Flink {@link org.apache.flink.api.common.typeinfo.TypeInformation} for + * Dataflow {@link com.google.cloud.dataflow.sdk.coders.Coder}s.
+ */ +public class CoderTypeInformation<T> extends TypeInformation<T> implements AtomicType<T> { + + private final Coder<T> coder; + + public CoderTypeInformation(Coder<T> coder) { + Preconditions.checkNotNull(coder); + this.coder = coder; + } + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 1; + } + + @Override + @SuppressWarnings("unchecked") + public Class<T> getTypeClass() { + // We don't have the actual Class of T, so the best we can do is return Object.class here. What a shame... + return (Class<T>) Object.class; + } + + @Override + public boolean isKeyType() { + return true; + } + + @Override + @SuppressWarnings("unchecked") + public TypeSerializer<T> createSerializer(ExecutionConfig config) { + if (coder instanceof VoidCoder) { + return (TypeSerializer<T>) new VoidCoderTypeSerializer(); + } + return new CoderTypeSerializer<>(coder); + } + + @Override + public int getTotalFields() { + return 2; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + CoderTypeInformation that = (CoderTypeInformation) o; + + return coder.equals(that.coder); + + } + + @Override + public int hashCode() { + return coder.hashCode(); + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof CoderTypeInformation; + } + + @Override + public String toString() { + return "CoderTypeInformation{" + + "coder=" + coder + + '}'; + } + + @Override + public TypeComparator<T> createComparator(boolean sortOrderAscending, ExecutionConfig + executionConfig) { + return new CoderComparator<>(coder); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java new file mode 100644 index 0000000..481ee31 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/CoderTypeSerializer.java @@ -0,0 +1,152 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.beam.runners.flink.translation.types; + +import org.apache.beam.runners.flink.translation.wrappers.DataInputViewWrapper; +import org.apache.beam.runners.flink.translation.wrappers.DataOutputViewWrapper; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.CoderException; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; + +import java.io.ByteArrayInputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.ObjectInputStream; + +/** + * Flink {@link org.apache.flink.api.common.typeutils.TypeSerializer} for + * Dataflow {@link com.google.cloud.dataflow.sdk.coders.Coder}s + */ +public class CoderTypeSerializer<T> extends TypeSerializer<T> { + + private Coder<T> coder; + private transient DataInputViewWrapper inputWrapper; + private transient DataOutputViewWrapper outputWrapper; + + // We use this for internal encoding/decoding for creating copies using the Coder. + private transient InspectableByteArrayOutputStream buffer; + + public CoderTypeSerializer(Coder<T> coder) { + this.coder = coder; + this.inputWrapper = new DataInputViewWrapper(null); + this.outputWrapper = new DataOutputViewWrapper(null); + + buffer = new InspectableByteArrayOutputStream(); + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + this.inputWrapper = new DataInputViewWrapper(null); + this.outputWrapper = new DataOutputViewWrapper(null); + + buffer = new InspectableByteArrayOutputStream(); + } + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public CoderTypeSerializer<T> duplicate() { + return new CoderTypeSerializer<>(coder); + } + + @Override + public T createInstance() { + return null; + } + + @Override + public T copy(T t) { + buffer.reset(); + try { + coder.encode(t, buffer, Coder.Context.OUTER); + } catch (IOException e) { + throw new RuntimeException("Could not copy.", e); + } + try { + return coder.decode(new ByteArrayInputStream(buffer.getBuffer(), 0, buffer + .size()), Coder.Context.OUTER); + } catch (IOException e) { + throw new RuntimeException("Could not copy.", e); + } + } + + @Override + public T copy(T t, T reuse) { + return copy(t); + } + + @Override + public int getLength() { + return 0; + } + + @Override + public void serialize(T t, DataOutputView dataOutputView) throws IOException { + outputWrapper.setOutputView(dataOutputView); + coder.encode(t, outputWrapper, Coder.Context.NESTED); + } + + @Override + public T deserialize(DataInputView dataInputView) throws IOException { + try { + inputWrapper.setInputView(dataInputView); + return coder.decode(inputWrapper, Coder.Context.NESTED); + } catch (CoderException e) { + Throwable cause = e.getCause(); + if (cause instanceof EOFException) { + throw (EOFException) cause; + } else { + throw e; + } + } + } + + @Override + public T deserialize(T t, DataInputView dataInputView) throws IOException { + return deserialize(dataInputView); + } + + @Override + public void copy(DataInputView dataInputView, DataOutputView dataOutputView) throws IOException { + serialize(deserialize(dataInputView), dataOutputView); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + CoderTypeSerializer that = (CoderTypeSerializer) o; + return coder.equals(that.coder); + } + 
+ @Override + public boolean canEqual(Object obj) { + return obj instanceof CoderTypeSerializer; + } + + @Override + public int hashCode() { + return coder.hashCode(); + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/InspectableByteArrayOutputStream.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/InspectableByteArrayOutputStream.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/InspectableByteArrayOutputStream.java new file mode 100644 index 0000000..619fa55 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/InspectableByteArrayOutputStream.java @@ -0,0 +1,34 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.types; + +import java.io.ByteArrayOutputStream; + +/** + * Version of {@link java.io.ByteArrayOutputStream} that allows retrieving the internal + * byte[] buffer without incurring an array copy. + */ +public class InspectableByteArrayOutputStream extends ByteArrayOutputStream { + + /** + * Get the underlying byte array. + */ + public byte[] getBuffer() { + return buf; + } +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderComperator.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderComperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderComperator.java new file mode 100644 index 0000000..4599c6a --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderComperator.java @@ -0,0 +1,264 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.apache.beam.runners.flink.translation.types; + +import org.apache.beam.runners.flink.translation.wrappers.DataInputViewWrapper; +import com.google.cloud.dataflow.sdk.coders.Coder; +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.values.KV; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.MemorySegment; + +import java.io.IOException; +import java.io.ObjectInputStream; + +/** + * Flink {@link org.apache.flink.api.common.typeutils.TypeComparator} for + * {@link com.google.cloud.dataflow.sdk.coders.KvCoder}. We have a special comparator + * for {@link KV} that always compares on the key only. + */ +public class KvCoderComperator <K, V> extends TypeComparator<KV<K, V>> { + + private KvCoder<K, V> coder; + private Coder<K> keyCoder; + + // We use these for internal encoding/decoding for creating copies and comparing + // serialized forms using a Coder + private transient InspectableByteArrayOutputStream buffer1; + private transient InspectableByteArrayOutputStream buffer2; + + // For storing the Reference in encoded form + private transient InspectableByteArrayOutputStream referenceBuffer; + + + // For deserializing the key + private transient DataInputViewWrapper inputWrapper; + + public KvCoderComperator(KvCoder<K, V> coder) { + this.coder = coder; + this.keyCoder = coder.getKeyCoder(); + + buffer1 = new InspectableByteArrayOutputStream(); + buffer2 = new InspectableByteArrayOutputStream(); + referenceBuffer = new InspectableByteArrayOutputStream(); + + inputWrapper = new DataInputViewWrapper(null); + } + + private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { + in.defaultReadObject(); + + buffer1 = new InspectableByteArrayOutputStream(); + buffer2 = new InspectableByteArrayOutputStream(); + referenceBuffer = new InspectableByteArrayOutputStream(); + + inputWrapper = new DataInputViewWrapper(null); + } + + @Override + public int hash(KV<K, V> record) { + K key = record.getKey(); + if (key != null) { + return key.hashCode(); + } else { + return 0; + } + } + + @Override + public void setReference(KV<K, V> toCompare) { + referenceBuffer.reset(); + try { + keyCoder.encode(toCompare.getKey(), referenceBuffer, Coder.Context.OUTER); + } catch (IOException e) { + throw new RuntimeException("Could not set reference " + toCompare + ": " + e); + } + } + + @Override + public boolean equalToReference(KV<K, V> candidate) { + try { + buffer2.reset(); + keyCoder.encode(candidate.getKey(), buffer2, Coder.Context.OUTER); + byte[] arr = referenceBuffer.getBuffer(); + byte[] arrOther = buffer2.getBuffer(); + if (referenceBuffer.size() != buffer2.size()) { + return false; + } + int len = buffer2.size(); + for(int i = 0; i < len; i++ ) { + if (arr[i] != arrOther[i]) { + return false; + } + } + return true; + } catch (IOException e) { + throw new RuntimeException("Could not compare reference.", e); + } + } + + @Override + public int compareToReference(TypeComparator<KV<K, V>> other) { + InspectableByteArrayOutputStream otherReferenceBuffer = ((KvCoderComperator<K, V>) other).referenceBuffer; + + byte[] arr = referenceBuffer.getBuffer(); + byte[] arrOther = otherReferenceBuffer.getBuffer(); + if (referenceBuffer.size() != otherReferenceBuffer.size()) { + return referenceBuffer.size() - otherReferenceBuffer.size(); + } + int len = referenceBuffer.size(); + for (int i = 
0; i < len; i++) { + if (arr[i] != arrOther[i]) { + return arr[i] - arrOther[i]; + } + } + return 0; + } + + + @Override + public int compare(KV<K, V> first, KV<K, V> second) { + try { + buffer1.reset(); + buffer2.reset(); + keyCoder.encode(first.getKey(), buffer1, Coder.Context.OUTER); + keyCoder.encode(second.getKey(), buffer2, Coder.Context.OUTER); + byte[] arr = buffer1.getBuffer(); + byte[] arrOther = buffer2.getBuffer(); + if (buffer1.size() != buffer2.size()) { + return buffer1.size() - buffer2.size(); + } + int len = buffer1.size(); + for(int i = 0; i < len; i++ ) { + if (arr[i] != arrOther[i]) { + return arr[i] - arrOther[i]; + } + } + return 0; + } catch (IOException e) { + throw new RuntimeException("Could not compare.", e); + } + } + + @Override + public int compareSerialized(DataInputView firstSource, DataInputView secondSource) throws IOException { + + inputWrapper.setInputView(firstSource); + K firstKey = keyCoder.decode(inputWrapper, Coder.Context.NESTED); + inputWrapper.setInputView(secondSource); + K secondKey = keyCoder.decode(inputWrapper, Coder.Context.NESTED); + + try { + buffer1.reset(); + buffer2.reset(); + keyCoder.encode(firstKey, buffer1, Coder.Context.OUTER); + keyCoder.encode(secondKey, buffer2, Coder.Context.OUTER); + byte[] arr = buffer1.getBuffer(); + byte[] arrOther = buffer2.getBuffer(); + if (buffer1.size() != buffer2.size()) { + return buffer1.size() - buffer2.size(); + } + int len = buffer1.size(); + for(int i = 0; i < len; i++ ) { + if (arr[i] != arrOther[i]) { + return arr[i] - arrOther[i]; + } + } + return 0; + } catch (IOException e) { + throw new RuntimeException("Could not compare.", e); + } + } + + @Override + public boolean supportsNormalizedKey() { + return true; + } + + @Override + public boolean supportsSerializationWithKeyNormalization() { + return false; + } + + @Override + public int getNormalizeKeyLen() { + return Integer.MAX_VALUE; + } + + @Override + public boolean isNormalizedKeyPrefixOnly(int keyBytes) { + return true; + } + + @Override + public void putNormalizedKey(KV<K, V> record, MemorySegment target, int offset, int numBytes) { + buffer1.reset(); + try { + keyCoder.encode(record.getKey(), buffer1, Coder.Context.NESTED); + } catch (IOException e) { + throw new RuntimeException("Could not serialize " + record + " using coder " + coder + ": " + e); + } + final byte[] data = buffer1.getBuffer(); + final int limit = offset + numBytes; + + int numBytesPut = Math.min(numBytes, buffer1.size()); + + target.put(offset, data, 0, numBytesPut); + + offset += numBytesPut; + + while (offset < limit) { + target.put(offset++, (byte) 0); + } + } + + @Override + public void writeWithKeyNormalization(KV<K, V> record, DataOutputView target) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public KV<K, V> readWithKeyDenormalization(KV<K, V> reuse, DataInputView source) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public boolean invertNormalizedKey() { + return false; + } + + @Override + public TypeComparator<KV<K, V>> duplicate() { + return new KvCoderComperator<>(coder); + } + + @Override + public int extractKeys(Object record, Object[] target, int index) { + KV<K, V> kv = (KV<K, V>) record; + K k = kv.getKey(); + target[index] = k; + return 1; + } + + @Override + public TypeComparator[] getFlatComparators() { + return new TypeComparator[] {new CoderComparator<>(keyCoder)}; + }
+}
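Since only the encoded key bytes participate in hashing and ordering, two KV elements with the same key compare as equal regardless of their values. A small sketch of these semantics (hypothetical usage, assuming the standard Dataflow coders; not part of this commit):

    KvCoder<String, Integer> kvCoder = KvCoder.of(StringUtf8Coder.of(), VarIntCoder.of());
    KvCoderComperator<String, Integer> comparator = new KvCoderComperator<>(kvCoder);

    comparator.setReference(KV.of("a", 1));
    boolean equal = comparator.equalToReference(KV.of("a", 42)); // true: values are ignored
    int order = comparator.compare(KV.of("a", 1), KV.of("b", 1)); // < 0: "a" encodes before "b"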
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderTypeInformation.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderTypeInformation.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderTypeInformation.java new file mode 100644 index 0000000..7a0d999 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/KvCoderTypeInformation.java @@ -0,0 +1,186 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.types; + +import com.google.cloud.dataflow.sdk.coders.KvCoder; +import com.google.cloud.dataflow.sdk.values.KV; +import org.apache.flink.api.common.ExecutionConfig; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeutils.CompositeType; +import org.apache.flink.api.common.typeutils.TypeComparator; +import org.apache.flink.api.common.typeutils.TypeSerializer; +import com.google.common.base.Preconditions; + +import java.util.List; + +/** + * Flink {@link org.apache.flink.api.common.typeinfo.TypeInformation} for + * Dataflow {@link com.google.cloud.dataflow.sdk.coders.KvCoder}. + */ +public class KvCoderTypeInformation<K, V> extends CompositeType<KV<K, V>> { + + private KvCoder<K, V> coder; + + // We don't have the actual Class of KV<K, V>, so we pass the class of a dummy Object instead. What a shame...
+ private static Object DUMMY = new Object(); + + @SuppressWarnings("unchecked") + public KvCoderTypeInformation(KvCoder<K, V> coder) { + super(((Class<KV<K,V>>) DUMMY.getClass())); + this.coder = coder; + Preconditions.checkNotNull(coder); + } + + @Override + @SuppressWarnings("unchecked") + public TypeComparator<KV<K, V>> createComparator(int[] logicalKeyFields, boolean[] orders, int logicalFieldOffset, ExecutionConfig config) { + return new KvCoderComperator((KvCoder) coder); + } + + @Override + public boolean isBasicType() { + return false; + } + + @Override + public boolean isTupleType() { + return false; + } + + @Override + public int getArity() { + return 2; + } + + @Override + @SuppressWarnings("unchecked") + public Class<KV<K, V>> getTypeClass() { + return privateGetTypeClass(); + } + + @SuppressWarnings("unchecked") + private static <X> Class<X> privateGetTypeClass() { + return (Class<X>) Object.class; + } + + @Override + public boolean isKeyType() { + return true; + } + + @Override + @SuppressWarnings("unchecked") + public TypeSerializer<KV<K, V>> createSerializer(ExecutionConfig config) { + return new CoderTypeSerializer<>(coder); + } + + @Override + public int getTotalFields() { + return 2; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + KvCoderTypeInformation that = (KvCoderTypeInformation) o; + + return coder.equals(that.coder); + + } + + @Override + public int hashCode() { + return coder.hashCode(); + } + + @Override + public String toString() { + return "CoderTypeInformation{" + + "coder=" + coder + + '}'; + } + + @Override + @SuppressWarnings("unchecked") + public <X> TypeInformation<X> getTypeAt(int pos) { + if (pos == 0) { + return (TypeInformation<X>) new CoderTypeInformation<>(coder.getKeyCoder()); + } else if (pos == 1) { + return (TypeInformation<X>) new CoderTypeInformation<>(coder.getValueCoder()); + } else { + throw new RuntimeException("Invalid field position " + pos); + } + } + + @Override + @SuppressWarnings("unchecked") + public <X> TypeInformation<X> getTypeAt(String fieldExpression) { + switch (fieldExpression) { + case "key": + return (TypeInformation<X>) new CoderTypeInformation<>(coder.getKeyCoder()); + case "value": + return (TypeInformation<X>) new CoderTypeInformation<>(coder.getValueCoder()); + default: + throw new UnsupportedOperationException("Only KvCoder has fields."); + } + } + + @Override + public String[] getFieldNames() { + return new String[]{"key", "value"}; + } + + @Override + public int getFieldIndex(String fieldName) { + switch (fieldName) { + case "key": + return 0; + case "value": + return 1; + default: + return -1; + } + } + + @Override + public void getFlatFields(String fieldExpression, int offset, List<FlatFieldDescriptor> result) { + CoderTypeInformation keyTypeInfo = new CoderTypeInformation<>(coder.getKeyCoder()); + result.add(new FlatFieldDescriptor(0, keyTypeInfo)); + } + + @Override + protected TypeComparatorBuilder<KV<K, V>> createTypeComparatorBuilder() { + return new KvCoderTypeComparatorBuilder(); + } + + private class KvCoderTypeComparatorBuilder implements TypeComparatorBuilder<KV<K, V>> { + + @Override + public void initializeTypeComparatorBuilder(int size) {} + + @Override + public void addComparatorField(int fieldId, TypeComparator<?> comparator) {} + + @Override + public TypeComparator<KV<K, V>> createTypeComparator(ExecutionConfig config) { + return new KvCoderComperator<>(coder); + } + } +} 
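Because the type exposes "key" and "value" as pseudo-fields, Flink's expression keys can address the key of a KV dataset directly. A hypothetical usage sketch (env, the sample data, and the surrounding job setup are assumed, not part of this commit):

    KvCoder<String, Long> kvCoder = KvCoder.of(StringUtf8Coder.of(), VarLongCoder.of());
    KvCoderTypeInformation<String, Long> typeInfo = new KvCoderTypeInformation<>(kvCoder);

    DataSet<KV<String, Long>> kvs = env.fromCollection(
        Arrays.asList(KV.of("a", 1L), KV.of("a", 2L), KV.of("b", 3L)), typeInfo);

    // "key" is resolved through getFieldIndex/getFlatFields and compared with KvCoderComperator.
    kvs.groupBy("key");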
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/VoidCoderTypeSerializer.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/VoidCoderTypeSerializer.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/VoidCoderTypeSerializer.java new file mode 100644 index 0000000..c7b6ea2 --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/types/VoidCoderTypeSerializer.java @@ -0,0 +1,112 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.flink.translation.types; + +import org.apache.flink.api.common.typeutils.TypeSerializer; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; + +import java.io.IOException; + +/** + * Special Flink {@link org.apache.flink.api.common.typeutils.TypeSerializer} for + * {@link com.google.cloud.dataflow.sdk.coders.VoidCoder}. We need this because Flink does not + * allow returning {@code null} from an input reader. We return a {@link VoidValue} instead + * that behaves like a {@code null}, hopefully. 
+ */ +public class VoidCoderTypeSerializer extends TypeSerializer<VoidCoderTypeSerializer.VoidValue> { + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public VoidCoderTypeSerializer duplicate() { + return this; + } + + @Override + public VoidValue createInstance() { + return VoidValue.INSTANCE; + } + + @Override + public VoidValue copy(VoidValue from) { + return from; + } + + @Override + public VoidValue copy(VoidValue from, VoidValue reuse) { + return from; + } + + @Override + public int getLength() { + return 0; + } + + @Override + public void serialize(VoidValue record, DataOutputView target) throws IOException { + target.writeByte(1); + } + + @Override + public VoidValue deserialize(DataInputView source) throws IOException { + source.readByte(); + return VoidValue.INSTANCE; + } + + @Override + public VoidValue deserialize(VoidValue reuse, DataInputView source) throws IOException { + return deserialize(source); + } + + @Override + public void copy(DataInputView source, DataOutputView target) throws IOException { + source.readByte(); + target.writeByte(1); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof VoidCoderTypeSerializer) { + VoidCoderTypeSerializer other = (VoidCoderTypeSerializer) obj; + return other.canEqual(this); + } else { + return false; + } + } + + @Override + public boolean canEqual(Object obj) { + return obj instanceof VoidCoderTypeSerializer; + } + + @Override + public int hashCode() { + return 0; + } + + public static class VoidValue { + private VoidValue() {} + + public static VoidValue INSTANCE = new VoidValue(); + } + +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/CombineFnAggregatorWrapper.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/CombineFnAggregatorWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/CombineFnAggregatorWrapper.java new file mode 100644 index 0000000..815765c --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/CombineFnAggregatorWrapper.java @@ -0,0 +1,92 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.flink.translation.wrappers; + +import com.google.cloud.dataflow.sdk.transforms.Aggregator; +import com.google.cloud.dataflow.sdk.transforms.Combine; +import com.google.common.collect.Lists; +import org.apache.flink.api.common.accumulators.Accumulator; + +import java.io.Serializable; + +/** + * Wrapper that wraps a {@link com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn} + * in a Flink {@link org.apache.flink.api.common.accumulators.Accumulator} for using + * the combine function as an aggregator in a {@link com.google.cloud.dataflow.sdk.transforms.ParDo} + * operation. + */ +public class CombineFnAggregatorWrapper<AI, AA, AR> implements Aggregator<AI, AR>, Accumulator<AI, Serializable> { + + private AA aa; + private Combine.CombineFn<? super AI, AA, AR> combiner; + + public CombineFnAggregatorWrapper() { + } + + public CombineFnAggregatorWrapper(Combine.CombineFn<? super AI, AA, AR> combiner) { + this.combiner = combiner; + this.aa = combiner.createAccumulator(); + } + + @Override + public void add(AI value) { + combiner.addInput(aa, value); + } + + @Override + public Serializable getLocalValue() { + return (Serializable) combiner.extractOutput(aa); + } + + @Override + public void resetLocal() { + aa = combiner.createAccumulator(); + } + + @Override + @SuppressWarnings("unchecked") + public void merge(Accumulator<AI, Serializable> other) { + aa = combiner.mergeAccumulators(Lists.newArrayList(aa, ((CombineFnAggregatorWrapper<AI, AA, AR>)other).aa)); + } + + @Override + public Accumulator<AI, Serializable> clone() { + // copy it by merging + AA aaCopy = combiner.mergeAccumulators(Lists.newArrayList(aa)); + CombineFnAggregatorWrapper<AI, AA, AR> result = new + CombineFnAggregatorWrapper<>(combiner); + result.aa = aaCopy; + return result; + } + + @Override + public void addValue(AI value) { + add(value); + } + + @Override + public String getName() { + return "CombineFn: " + combiner.toString(); + } + + @Override + public Combine.CombineFn getCombineFn() { + return combiner; + } + +} http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataInputViewWrapper.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataInputViewWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataInputViewWrapper.java new file mode 100644 index 0000000..b56a90e --- /dev/null +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataInputViewWrapper.java @@ -0,0 +1,59 @@ +/* + * Copyright 2015 Data Artisans GmbH + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataInputViewWrapper.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataInputViewWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataInputViewWrapper.java
new file mode 100644
index 0000000..b56a90e
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataInputViewWrapper.java
@@ -0,0 +1,59 @@
+/*
+ * Copyright 2015 Data Artisans GmbH
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.wrappers;
+
+import org.apache.flink.core.memory.DataInputView;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.InputStream;
+
+/**
+ * Wrapper for {@link DataInputView}. We need this because Flink reads data using a
+ * {@link org.apache.flink.core.memory.DataInputView} while
+ * Dataflow {@link com.google.cloud.dataflow.sdk.coders.Coder}s expect an
+ * {@link java.io.InputStream}.
+ */
+public class DataInputViewWrapper extends InputStream {
+
+  private DataInputView inputView;
+
+  public DataInputViewWrapper(DataInputView inputView) {
+    this.inputView = inputView;
+  }
+
+  public void setInputView(DataInputView inputView) {
+    this.inputView = inputView;
+  }
+
+  @Override
+  public int read() throws IOException {
+    try {
+      return inputView.readUnsignedByte();
+    } catch (EOFException e) {
+      // translate between the two end-of-input conventions: DataInput signals
+      // EOF with an exception, while InputStream signals it by returning -1
      return -1;
+    }
+  }
+
+  @Override
+  public int read(byte[] b, int off, int len) throws IOException {
+    return inputView.read(b, off, len);
+  }
+}
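A sketch of the bridge in action, decoding a String that was written with a Dataflow Coder from a Flink DataInputView. StringUtf8Coder is from the Dataflow SDK; the helper itself is hypothetical, and where the view comes from is left open:

    import com.google.cloud.dataflow.sdk.coders.Coder;
    import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
    import org.apache.flink.core.memory.DataInputView;

    import java.io.IOException;

    public class DecodeFromView {
      // adapt the view once, then let the Coder read from the InputStream
      public static String decode(DataInputView inputView) throws IOException {
        DataInputViewWrapper inputStream = new DataInputViewWrapper(inputView);
        return StringUtf8Coder.of().decode(inputStream, Coder.Context.NESTED);
      }
    }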
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataOutputViewWrapper.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataOutputViewWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataOutputViewWrapper.java
new file mode 100644
index 0000000..513d7f8
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/DataOutputViewWrapper.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2015 Data Artisans GmbH
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.wrappers;
+
+import org.apache.flink.core.memory.DataOutputView;
+
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Wrapper for {@link org.apache.flink.core.memory.DataOutputView}. We need this because
+ * Flink writes data using a {@link org.apache.flink.core.memory.DataOutputView} while
+ * Dataflow {@link com.google.cloud.dataflow.sdk.coders.Coder}s expect an
+ * {@link java.io.OutputStream}.
+ */
+public class DataOutputViewWrapper extends OutputStream {
+
+  private DataOutputView outputView;
+
+  public DataOutputViewWrapper(DataOutputView outputView) {
+    this.outputView = outputView;
+  }
+
+  public void setOutputView(DataOutputView outputView) {
+    this.outputView = outputView;
+  }
+
+  @Override
+  public void write(int b) throws IOException {
+    outputView.write(b);
+  }
+
+  @Override
+  public void write(byte[] b, int off, int len) throws IOException {
+    outputView.write(b, off, len);
+  }
+}
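And the mirror image for the write path, encoding through the OutputStream adapter (again a hypothetical helper):

    import com.google.cloud.dataflow.sdk.coders.Coder;
    import com.google.cloud.dataflow.sdk.coders.StringUtf8Coder;
    import org.apache.flink.core.memory.DataOutputView;

    import java.io.IOException;

    public class EncodeToView {
      public static void encode(String value, DataOutputView outputView) throws IOException {
        DataOutputViewWrapper outputStream = new DataOutputViewWrapper(outputView);
        StringUtf8Coder.of().encode(value, outputStream, Coder.Context.NESTED);
      }
    }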
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SerializableFnAggregatorWrapper.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SerializableFnAggregatorWrapper.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SerializableFnAggregatorWrapper.java
new file mode 100644
index 0000000..0d03f9f
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SerializableFnAggregatorWrapper.java
@@ -0,0 +1,91 @@
+/*
+ * Copyright 2015 Data Artisans GmbH
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.wrappers;
+
+import com.google.cloud.dataflow.sdk.transforms.Aggregator;
+import com.google.cloud.dataflow.sdk.transforms.Combine;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import org.apache.flink.api.common.accumulators.Accumulator;
+
+import java.io.Serializable;
+
+/**
+ * Wrapper that wraps a {@link com.google.cloud.dataflow.sdk.transforms.Combine.CombineFn}
+ * in a Flink {@link org.apache.flink.api.common.accumulators.Accumulator} for using
+ * the function as an aggregator in a {@link com.google.cloud.dataflow.sdk.transforms.ParDo}
+ * operation.
+ */
+public class SerializableFnAggregatorWrapper<AI, AO> implements Aggregator<AI, AO>, Accumulator<AI, Serializable> {
+
+  // the running result; new inputs are folded back in via CombineFn.apply,
+  // which assumes that the input and output types coincide
+  private AO aa;
+  private Combine.CombineFn<AI, ?, AO> combiner;
+
+  public SerializableFnAggregatorWrapper(Combine.CombineFn<AI, ?, AO> combiner) {
+    this.combiner = combiner;
+    resetLocal();
+  }
+
+  @Override
+  @SuppressWarnings("unchecked")
+  public void add(AI value) {
+    this.aa = combiner.apply(ImmutableList.of((AI) aa, value));
+  }
+
+  @Override
+  public Serializable getLocalValue() {
+    return (Serializable) aa;
+  }
+
+  @Override
+  public void resetLocal() {
+    this.aa = combiner.apply(ImmutableList.<AI>of());
+  }
+
+  @Override
+  @SuppressWarnings("unchecked")
+  public void merge(Accumulator<AI, Serializable> other) {
+    this.aa = combiner.apply(ImmutableList.of((AI) aa, (AI) other.getLocalValue()));
+  }
+
+  @Override
+  public void addValue(AI value) {
+    add(value);
+  }
+
+  @Override
+  public String getName() {
+    return "Aggregator: " + combiner.toString();
+  }
+
+  @Override
+  public Combine.CombineFn<AI, ?, AO> getCombineFn() {
+    return combiner;
+  }
+
+  @Override
+  @SuppressWarnings("unchecked")
+  public Accumulator<AI, Serializable> clone() {
+    // copy the current value by running it through the CombineFn once more
+    AO resultCopy = combiner.apply(Lists.newArrayList((AI) aa));
+    SerializableFnAggregatorWrapper<AI, AO> result = new
+        SerializableFnAggregatorWrapper<>(combiner);
+
+    result.aa = resultCopy;
+    return result;
+  }
+}
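Unlike CombineFnAggregatorWrapper above, this wrapper keeps the extracted output as its running value and folds new inputs back in with CombineFn.apply, so it is only sound when input and output types coincide. Reusing the hypothetical SumFn from the earlier sketch:

    public class SerializableFnAggregatorExample {
      public static void main(String[] args) {
        // Integer in, Integer out, so folding outputs back in works here
        SerializableFnAggregatorWrapper<Integer, Integer> agg =
            new SerializableFnAggregatorWrapper<>(new CombineFnAggregatorExample.SumFn());
        agg.addValue(4);
        agg.addValue(5);
        System.out.println(agg.getLocalValue()); // prints 9
      }
    }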
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SinkOutputFormat.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SinkOutputFormat.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SinkOutputFormat.java
new file mode 100644
index 0000000..d0423b9
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SinkOutputFormat.java
@@ -0,0 +1,121 @@
+/*
+ * Copyright 2015 Data Artisans GmbH
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.beam.runners.flink.translation.wrappers;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.cloud.dataflow.sdk.io.Sink;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.common.base.Preconditions;
+import com.google.cloud.dataflow.sdk.transforms.Write;
+import org.apache.flink.api.common.io.OutputFormat;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.util.AbstractID;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.lang.reflect.Field;
+
+/**
+ * Wrapper class to use generic Write.Bound transforms as sinks.
+ * @param <T> The type of the incoming records.
+ */
+public class SinkOutputFormat<T> implements OutputFormat<T> {
+
+  private final Sink<T> sink;
+
+  private transient PipelineOptions pipelineOptions;
+
+  private Sink.WriteOperation<T, ?> writeOperation;
+  private Sink.Writer<T, ?> writer;
+
+  private AbstractID uid = new AbstractID();
+
+  public SinkOutputFormat(Write.Bound<T> transform, PipelineOptions pipelineOptions) {
+    this.sink = extractSink(transform);
+    this.pipelineOptions = Preconditions.checkNotNull(pipelineOptions);
+  }
+
+  private Sink<T> extractSink(Write.Bound<T> transform) {
+    // TODO possibly add a getter in the upstream
+    try {
+      Field sinkField = transform.getClass().getDeclaredField("sink");
+      sinkField.setAccessible(true);
+      @SuppressWarnings("unchecked")
+      Sink<T> extractedSink = (Sink<T>) sinkField.get(transform);
+      return extractedSink;
+    } catch (NoSuchFieldException | IllegalAccessException e) {
+      throw new RuntimeException("Could not acquire custom sink field.", e);
+    }
+  }
+
+  @Override
+  public void configure(Configuration configuration) {
+    writeOperation = sink.createWriteOperation(pipelineOptions);
+    try {
+      writeOperation.initialize(pipelineOptions);
+    } catch (Exception e) {
+      throw new RuntimeException("Failed to initialize the write operation.", e);
+    }
+  }
+
+  @Override
+  public void open(int taskNumber, int numTasks) throws IOException {
+    try {
+      writer = writeOperation.createWriter(pipelineOptions);
+    } catch (Exception e) {
+      throw new IOException("Couldn't create writer.", e);
+    }
+    try {
+      writer.open(uid + "-" + taskNumber);
+    } catch (Exception e) {
+      throw new IOException("Couldn't open writer.", e);
+    }
+  }
+
+  @Override
+  public void writeRecord(T record) throws IOException {
+    try {
+      writer.write(record);
+    } catch (Exception e) {
+      throw new IOException("Couldn't write record.", e);
+    }
+  }
+
+  @Override
+  public void close() throws IOException {
+    if (writer == null) {
+      // open() failed; there is nothing to close
+      return;
+    }
+    try {
+      writer.close();
+    } catch (Exception e) {
+      throw new IOException("Couldn't close writer.", e);
+    }
+  }
+
+  private void writeObject(ObjectOutputStream out) throws IOException, ClassNotFoundException {
+    out.defaultWriteObject();
+    ObjectMapper mapper = new ObjectMapper();
+    mapper.writeValue(out, pipelineOptions);
+  }
+
+  private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
+    in.defaultReadObject();
+    ObjectMapper mapper = new ObjectMapper();
+    pipelineOptions = mapper.readValue(in, PipelineOptions.class);
+  }
+}
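On the Flink side the format attaches to a DataSet like any other OutputFormat; a sketch of the wiring, where the helper and its parameters are stand-ins rather than part of the patch:

    import com.google.cloud.dataflow.sdk.options.PipelineOptions;
    import com.google.cloud.dataflow.sdk.transforms.Write;
    import org.apache.flink.api.java.DataSet;

    public class SinkWiring {
      public static <T> void addSink(DataSet<T> dataSet, Write.Bound<T> transform,
                                     PipelineOptions options) {
        // Flink then calls configure() on each parallel instance, followed by
        // open(), writeRecord() per element, and finally close()
        dataSet.output(new SinkOutputFormat<>(transform, options));
      }
    }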
http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java
new file mode 100644
index 0000000..2d62416
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputFormat.java
@@ -0,0 +1,164 @@
+/*
+ * Copyright 2015 Data Artisans GmbH
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.wrappers;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.cloud.dataflow.sdk.io.BoundedSource;
+import com.google.cloud.dataflow.sdk.io.Source;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import org.apache.flink.api.common.io.InputFormat;
+import org.apache.flink.api.common.io.statistics.BaseStatistics;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.io.InputSplit;
+import org.apache.flink.core.io.InputSplitAssigner;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * A Flink {@link org.apache.flink.api.common.io.InputFormat} that wraps a
+ * Dataflow {@link com.google.cloud.dataflow.sdk.io.Source}.
+ */
+public class SourceInputFormat<T> implements InputFormat<T, SourceInputSplit<T>> {
+  private static final Logger LOG = LoggerFactory.getLogger(SourceInputFormat.class);
+
+  private final BoundedSource<T> initialSource;
+  private transient PipelineOptions options;
+
+  private BoundedSource.BoundedReader<T> reader = null;
+  private boolean reachedEnd = true;
+
+  public SourceInputFormat(BoundedSource<T> initialSource, PipelineOptions options) {
+    this.initialSource = initialSource;
+    this.options = options;
+  }
+
+  private void writeObject(ObjectOutputStream out)
+      throws IOException, ClassNotFoundException {
+    out.defaultWriteObject();
+    ObjectMapper mapper = new ObjectMapper();
+    mapper.writeValue(out, options);
+  }
+
+  private void readObject(ObjectInputStream in)
+      throws IOException, ClassNotFoundException {
+    in.defaultReadObject();
+    ObjectMapper mapper = new ObjectMapper();
+    options = mapper.readValue(in, PipelineOptions.class);
+  }
+
+  @Override
+  public void configure(Configuration configuration) {}
+
+  @Override
+  public void open(SourceInputSplit<T> sourceInputSplit) throws IOException {
+    reader = ((BoundedSource<T>) sourceInputSplit.getSource()).createReader(options);
+    // the reader contract requires start() before any call to advance();
+    // start() also positions the reader on the first element, if there is one
+    reachedEnd = !reader.start();
+  }
+
+  @Override
+  public BaseStatistics getStatistics(BaseStatistics baseStatistics) throws IOException {
+    try {
+      final long estimatedSize = initialSource.getEstimatedSizeBytes(options);
+
+      return new BaseStatistics() {
+        @Override
+        public long getTotalInputSize() {
+          return estimatedSize;
+        }
+
+        @Override
+        public long getNumberOfRecords() {
+          return BaseStatistics.NUM_RECORDS_UNKNOWN;
+        }
+
+        @Override
+        public float getAverageRecordWidth() {
+          return BaseStatistics.AVG_RECORD_BYTES_UNKNOWN;
+        }
+      };
+    } catch (Exception e) {
+      LOG.warn("Could not read Source statistics.", e);
+    }
+
+    return null;
+  }
+
+  @Override
+  @SuppressWarnings("unchecked")
+  public SourceInputSplit<T>[] createInputSplits(int numSplits) throws IOException {
+    long desiredSizeBytes;
+    try {
+      desiredSizeBytes = initialSource.getEstimatedSizeBytes(options) / numSplits;
+      List<? extends Source<T>> shards = initialSource.splitIntoBundles(desiredSizeBytes,
+          options);
+      List<SourceInputSplit<T>> splits = new ArrayList<>();
+      int splitCount = 0;
+      for (Source<T> shard : shards) {
+        splits.add(new SourceInputSplit<>(shard, splitCount++));
+      }
+      return splits.toArray(new SourceInputSplit[splits.size()]);
+    } catch (Exception e) {
+      throw new IOException("Could not create input splits from Source.", e);
+    }
+  }
+
+  @Override
+  public InputSplitAssigner getInputSplitAssigner(final SourceInputSplit[] sourceInputSplits) {
+    return new InputSplitAssigner() {
+      private int index = 0;
+      private final SourceInputSplit[] splits = sourceInputSplits;
+
+      @Override
+      public InputSplit getNextInputSplit(String host, int taskId) {
+        if (index < splits.length) {
+          return splits[index++];
+        } else {
+          return null;
+        }
+      }
+    };
+  }
+
+  @Override
+  public boolean reachedEnd() throws IOException {
+    return reachedEnd;
+  }
+
+  @Override
+  public T nextRecord(T t) throws IOException {
+    if (reachedEnd) {
+      return null;
+    }
+    // emit the element the reader is currently positioned on, then advance
+    // so that reachedEnd() reflects whether another element is available
+    T current = reader.getCurrent();
+    reachedEnd = !reader.advance();
+    return current;
+  }
+
+  @Override
+  public void close() throws IOException {
+    reader.close();
+  }
+}

http://git-wip-us.apache.org/repos/asf/incubator-beam/blob/51bec310/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputSplit.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputSplit.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputSplit.java
new file mode 100644
index 0000000..1b45ad7
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/SourceInputSplit.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2015 Data Artisans GmbH
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.flink.translation.wrappers;
+
+import com.google.cloud.dataflow.sdk.io.Source;
+import org.apache.flink.core.io.InputSplit;
+
+/**
+ * {@link org.apache.flink.core.io.InputSplit} for
+ * {@link org.apache.beam.runners.flink.translation.wrappers.SourceInputFormat}. We pass
+ * the sharded Source around in the input split because a Source is sharded by splitting
+ * it into several smaller Sources. This differs from Flink, where the InputFormat itself
+ * creates dedicated InputSplit objects.
+ */
+public class SourceInputSplit<T> implements InputSplit {
+
+  private Source<T> source;
+  private int splitNumber;
+
+  public SourceInputSplit() {
+  }
+
+  public SourceInputSplit(Source<T> source, int splitNumber) {
+    this.source = source;
+    this.splitNumber = splitNumber;
+  }
+
+  @Override
+  public int getSplitNumber() {
+    return splitNumber;
+  }
+
+  public Source<T> getSource() {
+    return source;
+  }
+
+}
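The read side wires up symmetrically through the ExecutionEnvironment. Since the element type cannot be inferred from the erased generic, it is passed explicitly; the helper names below are again stand-ins:

    import com.google.cloud.dataflow.sdk.io.BoundedSource;
    import com.google.cloud.dataflow.sdk.options.PipelineOptions;
    import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.ExecutionEnvironment;

    public class SourceWiring {
      public static DataSet<String> read(ExecutionEnvironment env,
                                         BoundedSource<String> source,
                                         PipelineOptions options) {
        // Flink asks the format for input splits, hands them out through the
        // InputSplitAssigner, and drives open()/reachedEnd()/nextRecord() per split
        return env.createInput(new SourceInputFormat<>(source, options),
            BasicTypeInfo.STRING_TYPE_INFO);
      }
    }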