http://git-wip-us.apache.org/repos/asf/beam/blob/7e04924e/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java index d015c38..31e931c 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkBroadcastStateInternals.java @@ -97,92 +97,74 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals<K> { StateTag<? super K, T> address, final StateContext<?> context) { - return address.bind(new StateTag.StateBinder<K>() { + return address.bind( + new StateTag.StateBinder<K>() { - @Override - public <T> ValueState<T> bindValue( - StateTag<? super K, ValueState<T>> address, - Coder<T> coder) { + @Override + public <T> ValueState<T> bindValue( + StateTag<? super K, ValueState<T>> address, Coder<T> coder) { - return new FlinkBroadcastValueState<>(stateBackend, address, namespace, coder); - } + return new FlinkBroadcastValueState<>(stateBackend, address, namespace, coder); + } - @Override - public <T> BagState<T> bindBag( - StateTag<? super K, BagState<T>> address, - Coder<T> elemCoder) { + @Override + public <T> BagState<T> bindBag( + StateTag<? super K, BagState<T>> address, Coder<T> elemCoder) { - return new FlinkBroadcastBagState<>(stateBackend, address, namespace, elemCoder); - } - - @Override - public <T> SetState<T> bindSet( - StateTag<? super K, SetState<T>> address, - Coder<T> elemCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", SetState.class.getSimpleName())); - } + return new FlinkBroadcastBagState<>(stateBackend, address, namespace, elemCoder); + } - @Override - public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( - StateTag<? super K, MapState<KeyT, ValueT>> spec, - Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", MapState.class.getSimpleName())); - } + @Override + public <T> SetState<T> bindSet( + StateTag<? super K, SetState<T>> address, Coder<T> elemCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", SetState.class.getSimpleName())); + } - @Override - public <InputT, AccumT, OutputT> - CombiningState<InputT, AccumT, OutputT> - bindCombiningValue( - StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, - Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { + @Override + public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + StateTag<? super K, MapState<KeyT, ValueT>> spec, + Coder<KeyT> mapKeyCoder, + Coder<ValueT> mapValueCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", MapState.class.getSimpleName())); + } - return new FlinkCombiningState<>( - stateBackend, address, combineFn, namespace, accumCoder); - } + @Override + public <InputT, AccumT, OutputT> + CombiningState<InputT, AccumT, OutputT> bindCombiningValue( + StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { - @Override - public <InputT, AccumT, OutputT> - CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue( - StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, - final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) { - return new FlinkKeyedCombiningState<>( - stateBackend, - address, - combineFn, - namespace, - accumCoder, - FlinkBroadcastStateInternals.this); - } + return new FlinkCombiningState<>( + stateBackend, address, combineFn, namespace, accumCoder); + } - @Override - public <InputT, AccumT, OutputT> - CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValueWithContext( - StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, - CombineWithContext.KeyedCombineFnWithContext< - ? super K, InputT, AccumT, OutputT> combineFn) { - return new FlinkCombiningStateWithContext<>( - stateBackend, - address, - combineFn, - namespace, - accumCoder, - FlinkBroadcastStateInternals.this, - CombineContextFactory.createFromStateContext(context)); - } + @Override + public <InputT, AccumT, OutputT> + CombiningState<InputT, AccumT, OutputT> bindCombiningValueWithContext( + StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn) { + return new FlinkCombiningStateWithContext<>( + stateBackend, + address, + combineFn, + namespace, + accumCoder, + FlinkBroadcastStateInternals.this, + CombineContextFactory.createFromStateContext(context)); + } - @Override - public <W extends BoundedWindow> WatermarkHoldState bindWatermark( - StateTag<? super K, WatermarkHoldState> address, - TimestampCombiner timestampCombiner) { - throw new UnsupportedOperationException( - String.format("%s is not supported", WatermarkHoldState.class.getSimpleName())); - } - }); + @Override + public <W extends BoundedWindow> WatermarkHoldState bindWatermark( + StateTag<? super K, WatermarkHoldState> address, + TimestampCombiner timestampCombiner) { + throw new UnsupportedOperationException( + String.format("%s is not supported", WatermarkHoldState.class.getSimpleName())); + } + }); } /** @@ -587,13 +569,13 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals<K> { private final StateNamespace namespace; private final StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address; - private final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn; + private final Combine.CombineFn<InputT, AccumT, OutputT> combineFn; private final FlinkBroadcastStateInternals<K> flinkStateInternals; FlinkKeyedCombiningState( DefaultOperatorStateBackend flinkStateBackend, StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn, + Combine.CombineFn<InputT, AccumT, OutputT> combineFn, StateNamespace namespace, Coder<AccumT> accumCoder, FlinkBroadcastStateInternals<K> flinkStateInternals) { @@ -616,9 +598,9 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals<K> { try { AccumT current = readInternal(); if (current == null) { - current = combineFn.createAccumulator(flinkStateInternals.getKey()); + current = combineFn.createAccumulator(); } - current = combineFn.addInput(flinkStateInternals.getKey(), current, value); + current = combineFn.addInput(current, value); writeInternal(current); } catch (Exception e) { throw new RuntimeException("Error adding to state." , e); @@ -632,9 +614,7 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals<K> { if (current == null) { writeInternal(accum); } else { - current = combineFn.mergeAccumulators( - flinkStateInternals.getKey(), - Arrays.asList(current, accum)); + current = combineFn.mergeAccumulators(Arrays.asList(current, accum)); writeInternal(current); } } catch (Exception e) { @@ -653,7 +633,7 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals<K> { @Override public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { - return combineFn.mergeAccumulators(flinkStateInternals.getKey(), accumulators); + return combineFn.mergeAccumulators(accumulators); } @Override @@ -661,11 +641,9 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals<K> { try { AccumT accum = readInternal(); if (accum != null) { - return combineFn.extractOutput(flinkStateInternals.getKey(), accum); + return combineFn.extractOutput(accum); } else { - return combineFn.extractOutput( - flinkStateInternals.getKey(), - combineFn.createAccumulator(flinkStateInternals.getKey())); + return combineFn.extractOutput(combineFn.createAccumulator()); } } catch (Exception e) { throw new RuntimeException("Error reading state.", e); @@ -727,16 +705,14 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals<K> { private final StateNamespace namespace; private final StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address; - private final CombineWithContext.KeyedCombineFnWithContext< - ? super K, InputT, AccumT, OutputT> combineFn; + private final CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn; private final FlinkBroadcastStateInternals<K> flinkStateInternals; private final CombineWithContext.Context context; FlinkCombiningStateWithContext( DefaultOperatorStateBackend flinkStateBackend, StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - CombineWithContext.KeyedCombineFnWithContext< - ? super K, InputT, AccumT, OutputT> combineFn, + CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn, StateNamespace namespace, Coder<AccumT> accumCoder, FlinkBroadcastStateInternals<K> flinkStateInternals, @@ -761,9 +737,9 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals<K> { try { AccumT current = readInternal(); if (current == null) { - current = combineFn.createAccumulator(flinkStateInternals.getKey(), context); + current = combineFn.createAccumulator(context); } - current = combineFn.addInput(flinkStateInternals.getKey(), current, value, context); + current = combineFn.addInput(current, value, context); writeInternal(current); } catch (Exception e) { throw new RuntimeException("Error adding to state." , e); @@ -778,10 +754,7 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals<K> { if (current == null) { writeInternal(accum); } else { - current = combineFn.mergeAccumulators( - flinkStateInternals.getKey(), - Arrays.asList(current, accum), - context); + current = combineFn.mergeAccumulators(Arrays.asList(current, accum), context); writeInternal(current); } } catch (Exception e) { @@ -800,14 +773,14 @@ public class FlinkBroadcastStateInternals<K> implements StateInternals<K> { @Override public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { - return combineFn.mergeAccumulators(flinkStateInternals.getKey(), accumulators, context); + return combineFn.mergeAccumulators(accumulators, context); } @Override public OutputT read() { try { AccumT accum = readInternal(); - return combineFn.extractOutput(flinkStateInternals.getKey(), accum, context); + return combineFn.extractOutput(accum, context); } catch (Exception e) { throw new RuntimeException("Error reading state.", e); }
http://git-wip-us.apache.org/repos/asf/beam/blob/7e04924e/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java index 2dd7c96..67d7966 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkKeyGroupStateInternals.java @@ -120,79 +120,66 @@ public class FlinkKeyGroupStateInternals<K> implements StateInternals<K> { StateTag<? super K, T> address, final StateContext<?> context) { - return address.bind(new StateTag.StateBinder<K>() { - - @Override - public <T> ValueState<T> bindValue( - StateTag<? super K, ValueState<T>> address, - Coder<T> coder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", ValueState.class.getSimpleName())); - } + return address.bind( + new StateTag.StateBinder<K>() { + + @Override + public <T> ValueState<T> bindValue( + StateTag<? super K, ValueState<T>> address, Coder<T> coder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", ValueState.class.getSimpleName())); + } - @Override - public <T> BagState<T> bindBag( - StateTag<? super K, BagState<T>> address, - Coder<T> elemCoder) { + @Override + public <T> BagState<T> bindBag( + StateTag<? super K, BagState<T>> address, Coder<T> elemCoder) { - return new FlinkKeyGroupBagState<>(address, namespace, elemCoder); - } - - @Override - public <T> SetState<T> bindSet( - StateTag<? super K, SetState<T>> address, - Coder<T> elemCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", SetState.class.getSimpleName())); - } - - @Override - public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( - StateTag<? super K, MapState<KeyT, ValueT>> spec, - Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", MapState.class.getSimpleName())); - } + return new FlinkKeyGroupBagState<>(address, namespace, elemCoder); + } - @Override - public <InputT, AccumT, OutputT> - CombiningState<InputT, AccumT, OutputT> - bindCombiningValue( - StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, - Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { - throw new UnsupportedOperationException("bindCombiningValue is not supported."); - } + @Override + public <T> SetState<T> bindSet( + StateTag<? super K, SetState<T>> address, Coder<T> elemCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", SetState.class.getSimpleName())); + } - @Override - public <InputT, AccumT, OutputT> - CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue( - StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, - final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) { - throw new UnsupportedOperationException("bindKeyedCombiningValue is not supported."); + @Override + public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + StateTag<? super K, MapState<KeyT, ValueT>> spec, + Coder<KeyT> mapKeyCoder, + Coder<ValueT> mapValueCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", MapState.class.getSimpleName())); + } - } + @Override + public <InputT, AccumT, OutputT> + CombiningState<InputT, AccumT, OutputT> bindCombiningValue( + StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { + throw new UnsupportedOperationException("bindCombiningValue is not supported."); + } - @Override - public <InputT, AccumT, OutputT> - CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValueWithContext( - StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, - CombineWithContext.KeyedCombineFnWithContext< - ? super K, InputT, AccumT, OutputT> combineFn) { - throw new UnsupportedOperationException( - "bindKeyedCombiningValueWithContext is not supported."); - } + @Override + public <InputT, AccumT, OutputT> + CombiningState<InputT, AccumT, OutputT> bindCombiningValueWithContext( + StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn) { + throw new UnsupportedOperationException( + "bindCombiningValueWithContext is not supported."); + } - @Override - public <W extends BoundedWindow> WatermarkHoldState bindWatermark( - StateTag<? super K, WatermarkHoldState> address, - TimestampCombiner timestampCombiner) { - throw new UnsupportedOperationException( - String.format("%s is not supported", CombiningState.class.getSimpleName())); - } - }); + @Override + public <W extends BoundedWindow> WatermarkHoldState bindWatermark( + StateTag<? super K, WatermarkHoldState> address, + TimestampCombiner timestampCombiner) { + throw new UnsupportedOperationException( + String.format("%s is not supported", CombiningState.class.getSimpleName())); + } + }); } /** http://git-wip-us.apache.org/repos/asf/beam/blob/7e04924e/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkSplitStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkSplitStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkSplitStateInternals.java index 17ea62a..ef6c3b2 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkSplitStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkSplitStateInternals.java @@ -80,79 +80,66 @@ public class FlinkSplitStateInternals<K> implements StateInternals<K> { StateTag<? super K, T> address, final StateContext<?> context) { - return address.bind(new StateTag.StateBinder<K>() { - - @Override - public <T> ValueState<T> bindValue( - StateTag<? super K, ValueState<T>> address, - Coder<T> coder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", ValueState.class.getSimpleName())); - } + return address.bind( + new StateTag.StateBinder<K>() { + + @Override + public <T> ValueState<T> bindValue( + StateTag<? super K, ValueState<T>> address, Coder<T> coder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", ValueState.class.getSimpleName())); + } - @Override - public <T> BagState<T> bindBag( - StateTag<? super K, BagState<T>> address, - Coder<T> elemCoder) { + @Override + public <T> BagState<T> bindBag( + StateTag<? super K, BagState<T>> address, Coder<T> elemCoder) { - return new FlinkSplitBagState<>(stateBackend, address, namespace, elemCoder); - } - - @Override - public <T> SetState<T> bindSet( - StateTag<? super K, SetState<T>> address, - Coder<T> elemCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", SetState.class.getSimpleName())); - } - - @Override - public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( - StateTag<? super K, MapState<KeyT, ValueT>> spec, - Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", MapState.class.getSimpleName())); - } + return new FlinkSplitBagState<>(stateBackend, address, namespace, elemCoder); + } - @Override - public <InputT, AccumT, OutputT> - CombiningState<InputT, AccumT, OutputT> - bindCombiningValue( - StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, - Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { - throw new UnsupportedOperationException("bindCombiningValue is not supported."); - } + @Override + public <T> SetState<T> bindSet( + StateTag<? super K, SetState<T>> address, Coder<T> elemCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", SetState.class.getSimpleName())); + } - @Override - public <InputT, AccumT, OutputT> - CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue( - StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, - final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) { - throw new UnsupportedOperationException("bindKeyedCombiningValue is not supported."); + @Override + public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + StateTag<? super K, MapState<KeyT, ValueT>> spec, + Coder<KeyT> mapKeyCoder, + Coder<ValueT> mapValueCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", MapState.class.getSimpleName())); + } - } + @Override + public <InputT, AccumT, OutputT> + CombiningState<InputT, AccumT, OutputT> bindCombiningValue( + StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { + throw new UnsupportedOperationException("bindCombiningValue is not supported."); + } - @Override - public <InputT, AccumT, OutputT> - CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValueWithContext( - StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, - CombineWithContext.KeyedCombineFnWithContext< - ? super K, InputT, AccumT, OutputT> combineFn) { - throw new UnsupportedOperationException( - "bindKeyedCombiningValueWithContext is not supported."); - } + @Override + public <InputT, AccumT, OutputT> + CombiningState<InputT, AccumT, OutputT> bindCombiningValueWithContext( + StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn) { + throw new UnsupportedOperationException( + "bindCombiningValueWithContext is not supported."); + } - @Override - public <W extends BoundedWindow> WatermarkHoldState bindWatermark( - StateTag<? super K, WatermarkHoldState> address, - TimestampCombiner timestampCombiner) { - throw new UnsupportedOperationException( - String.format("%s is not supported", CombiningState.class.getSimpleName())); - } - }); + @Override + public <W extends BoundedWindow> WatermarkHoldState bindWatermark( + StateTag<? super K, WatermarkHoldState> address, + TimestampCombiner timestampCombiner) { + throw new UnsupportedOperationException( + String.format("%s is not supported", CombiningState.class.getSimpleName())); + } + }); } private static class FlinkSplitBagState<K, T> implements BagState<T> { http://git-wip-us.apache.org/repos/asf/beam/blob/7e04924e/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index 878c914..c99d085 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -106,93 +106,75 @@ public class FlinkStateInternals<K> implements StateInternals<K> { StateTag<? super K, T> address, final StateContext<?> context) { - return address.bind(new StateTag.StateBinder<K>() { + return address.bind( + new StateTag.StateBinder<K>() { - @Override - public <T> ValueState<T> bindValue( - StateTag<? super K, ValueState<T>> address, - Coder<T> coder) { + @Override + public <T> ValueState<T> bindValue( + StateTag<? super K, ValueState<T>> address, Coder<T> coder) { - return new FlinkValueState<>(flinkStateBackend, address, namespace, coder); - } - - @Override - public <T> BagState<T> bindBag( - StateTag<? super K, BagState<T>> address, - Coder<T> elemCoder) { + return new FlinkValueState<>(flinkStateBackend, address, namespace, coder); + } - return new FlinkBagState<>(flinkStateBackend, address, namespace, elemCoder); - } + @Override + public <T> BagState<T> bindBag( + StateTag<? super K, BagState<T>> address, Coder<T> elemCoder) { - @Override - public <T> SetState<T> bindSet( - StateTag<? super K, SetState<T>> address, - Coder<T> elemCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", SetState.class.getSimpleName())); - } + return new FlinkBagState<>(flinkStateBackend, address, namespace, elemCoder); + } - @Override - public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( - StateTag<? super K, MapState<KeyT, ValueT>> spec, - Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", MapState.class.getSimpleName())); - } + @Override + public <T> SetState<T> bindSet( + StateTag<? super K, SetState<T>> address, Coder<T> elemCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", SetState.class.getSimpleName())); + } - @Override - public <InputT, AccumT, OutputT> - CombiningState<InputT, AccumT, OutputT> - bindCombiningValue( - StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, - Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { + @Override + public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + StateTag<? super K, MapState<KeyT, ValueT>> spec, + Coder<KeyT> mapKeyCoder, + Coder<ValueT> mapValueCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", MapState.class.getSimpleName())); + } - return new FlinkCombiningState<>( - flinkStateBackend, address, combineFn, namespace, accumCoder); - } + @Override + public <InputT, AccumT, OutputT> + CombiningState<InputT, AccumT, OutputT> bindCombiningValue( + StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + Combine.CombineFn<InputT, AccumT, OutputT> combineFn) { - @Override - public <InputT, AccumT, OutputT> - CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue( - StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, - final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) { - return new FlinkKeyedCombiningState<>( - flinkStateBackend, - address, - combineFn, - namespace, - accumCoder, - FlinkStateInternals.this); - } + return new FlinkCombiningState<>( + flinkStateBackend, address, combineFn, namespace, accumCoder); + } - @Override - public <InputT, AccumT, OutputT> - CombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValueWithContext( - StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, - CombineWithContext.KeyedCombineFnWithContext< - ? super K, InputT, AccumT, OutputT> combineFn) { - return new FlinkCombiningStateWithContext<>( - flinkStateBackend, - address, - combineFn, - namespace, - accumCoder, - FlinkStateInternals.this, - CombineContextFactory.createFromStateContext(context)); - } + @Override + public <InputT, AccumT, OutputT> + CombiningState<InputT, AccumT, OutputT> bindCombiningValueWithContext( + StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, + Coder<AccumT> accumCoder, + CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn) { + return new FlinkCombiningStateWithContext<>( + flinkStateBackend, + address, + combineFn, + namespace, + accumCoder, + FlinkStateInternals.this, + CombineContextFactory.createFromStateContext(context)); + } - @Override - public <W extends BoundedWindow> WatermarkHoldState bindWatermark( - StateTag<? super K, WatermarkHoldState> address, - TimestampCombiner timestampCombiner) { + @Override + public <W extends BoundedWindow> WatermarkHoldState bindWatermark( + StateTag<? super K, WatermarkHoldState> address, + TimestampCombiner timestampCombiner) { - return new FlinkWatermarkHoldState<>( - flinkStateBackend, FlinkStateInternals.this, address, namespace, timestampCombiner); - } - }); + return new FlinkWatermarkHoldState<>( + flinkStateBackend, FlinkStateInternals.this, address, namespace, timestampCombiner); + } + }); } private static class FlinkValueState<K, T> implements ValueState<T> { @@ -566,7 +548,7 @@ public class FlinkStateInternals<K> implements StateInternals<K> { private final StateNamespace namespace; private final StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address; - private final Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn; + private final Combine.CombineFn<InputT, AccumT, OutputT> combineFn; private final ValueStateDescriptor<AccumT> flinkStateDescriptor; private final KeyedStateBackend<ByteBuffer> flinkStateBackend; private final FlinkStateInternals<K> flinkStateInternals; @@ -574,7 +556,7 @@ public class FlinkStateInternals<K> implements StateInternals<K> { FlinkKeyedCombiningState( KeyedStateBackend<ByteBuffer> flinkStateBackend, StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - Combine.KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn, + Combine.CombineFn<InputT, AccumT, OutputT> combineFn, StateNamespace namespace, Coder<AccumT> accumCoder, FlinkStateInternals<K> flinkStateInternals) { @@ -606,9 +588,9 @@ public class FlinkStateInternals<K> implements StateInternals<K> { AccumT current = state.value(); if (current == null) { - current = combineFn.createAccumulator(flinkStateInternals.getKey()); + current = combineFn.createAccumulator(); } - current = combineFn.addInput(flinkStateInternals.getKey(), current, value); + current = combineFn.addInput(current, value); state.update(current); } catch (Exception e) { throw new RuntimeException("Error adding to state." , e); @@ -628,9 +610,7 @@ public class FlinkStateInternals<K> implements StateInternals<K> { if (current == null) { state.update(accum); } else { - current = combineFn.mergeAccumulators( - flinkStateInternals.getKey(), - Lists.newArrayList(current, accum)); + current = combineFn.mergeAccumulators(Lists.newArrayList(current, accum)); state.update(current); } } catch (Exception e) { @@ -652,7 +632,7 @@ public class FlinkStateInternals<K> implements StateInternals<K> { @Override public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { - return combineFn.mergeAccumulators(flinkStateInternals.getKey(), accumulators); + return combineFn.mergeAccumulators(accumulators); } @Override @@ -666,11 +646,9 @@ public class FlinkStateInternals<K> implements StateInternals<K> { AccumT accum = state.value(); if (accum != null) { - return combineFn.extractOutput(flinkStateInternals.getKey(), accum); + return combineFn.extractOutput(accum); } else { - return combineFn.extractOutput( - flinkStateInternals.getKey(), - combineFn.createAccumulator(flinkStateInternals.getKey())); + return combineFn.extractOutput(combineFn.createAccumulator()); } } catch (Exception e) { throw new RuntimeException("Error reading state.", e); @@ -741,8 +719,7 @@ public class FlinkStateInternals<K> implements StateInternals<K> { private final StateNamespace namespace; private final StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address; - private final CombineWithContext.KeyedCombineFnWithContext< - ? super K, InputT, AccumT, OutputT> combineFn; + private final CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn; private final ValueStateDescriptor<AccumT> flinkStateDescriptor; private final KeyedStateBackend<ByteBuffer> flinkStateBackend; private final FlinkStateInternals<K> flinkStateInternals; @@ -751,8 +728,7 @@ public class FlinkStateInternals<K> implements StateInternals<K> { FlinkCombiningStateWithContext( KeyedStateBackend<ByteBuffer> flinkStateBackend, StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - CombineWithContext.KeyedCombineFnWithContext< - ? super K, InputT, AccumT, OutputT> combineFn, + CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn, StateNamespace namespace, Coder<AccumT> accumCoder, FlinkStateInternals<K> flinkStateInternals, @@ -786,9 +762,9 @@ public class FlinkStateInternals<K> implements StateInternals<K> { AccumT current = state.value(); if (current == null) { - current = combineFn.createAccumulator(flinkStateInternals.getKey(), context); + current = combineFn.createAccumulator(context); } - current = combineFn.addInput(flinkStateInternals.getKey(), current, value, context); + current = combineFn.addInput(current, value, context); state.update(current); } catch (Exception e) { throw new RuntimeException("Error adding to state." , e); @@ -808,10 +784,7 @@ public class FlinkStateInternals<K> implements StateInternals<K> { if (current == null) { state.update(accum); } else { - current = combineFn.mergeAccumulators( - flinkStateInternals.getKey(), - Lists.newArrayList(current, accum), - context); + current = combineFn.mergeAccumulators(Lists.newArrayList(current, accum), context); state.update(current); } } catch (Exception e) { @@ -833,7 +806,7 @@ public class FlinkStateInternals<K> implements StateInternals<K> { @Override public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { - return combineFn.mergeAccumulators(flinkStateInternals.getKey(), accumulators, context); + return combineFn.mergeAccumulators(accumulators, context); } @Override @@ -846,7 +819,7 @@ public class FlinkStateInternals<K> implements StateInternals<K> { flinkStateDescriptor); AccumT accum = state.value(); - return combineFn.extractOutput(flinkStateInternals.getKey(), accum, context); + return combineFn.extractOutput(accum, context); } catch (Exception e) { throw new RuntimeException("Error reading state.", e); } http://git-wip-us.apache.org/repos/asf/beam/blob/7e04924e/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java index c967521..cdc23ff 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java @@ -31,8 +31,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.InstantCoder; import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.transforms.Combine.CombineFn; -import org.apache.beam.sdk.transforms.Combine.KeyedCombineFn; -import org.apache.beam.sdk.transforms.CombineWithContext.KeyedCombineFnWithContext; +import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.util.CombineFnUtil; @@ -142,27 +141,17 @@ class SparkStateInternals<K> implements StateInternals<K> { StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, Coder<AccumT> accumCoder, CombineFn<InputT, AccumT, OutputT> combineFn) { - return new SparkCombiningState<>(namespace, address, accumCoder, key, - combineFn.<K>asKeyedFn()); + return new SparkCombiningState<>(namespace, address, accumCoder, combineFn); } @Override public <InputT, AccumT, OutputT> CombiningState<InputT, AccumT, OutputT> - bindKeyedCombiningValue( + bindCombiningValueWithContext( StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, Coder<AccumT> accumCoder, - KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) { - return new SparkCombiningState<>(namespace, address, accumCoder, key, combineFn); - } - - @Override - public <InputT, AccumT, OutputT> CombiningState<InputT, AccumT, OutputT> - bindKeyedCombiningValueWithContext( - StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, - Coder<AccumT> accumCoder, - KeyedCombineFnWithContext<? super K, InputT, AccumT, OutputT> combineFn) { - return new SparkCombiningState<>(namespace, address, accumCoder, key, - CombineFnUtil.bindContext(combineFn, c)); + CombineFnWithContext<InputT, AccumT, OutputT> combineFn) { + return new SparkCombiningState<>( + namespace, address, accumCoder, CombineFnUtil.bindContext(combineFn, c)); } @Override @@ -307,17 +296,14 @@ class SparkStateInternals<K> implements StateInternals<K> { extends AbstractState<AccumT> implements CombiningState<InputT, AccumT, OutputT> { - private final K key; - private final KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn; + private final CombineFn<InputT, AccumT, OutputT> combineFn; private SparkCombiningState( StateNamespace namespace, StateTag<? super K, CombiningState<InputT, AccumT, OutputT>> address, Coder<AccumT> coder, - K key, - KeyedCombineFn<? super K, InputT, AccumT, OutputT> combineFn) { + CombineFn<InputT, AccumT, OutputT> combineFn) { super(namespace, address, coder); - this.key = key; this.combineFn = combineFn; } @@ -328,13 +314,13 @@ class SparkStateInternals<K> implements StateInternals<K> { @Override public OutputT read() { - return combineFn.extractOutput(key, getAccum()); + return combineFn.extractOutput(getAccum()); } @Override public void add(InputT input) { AccumT accum = getAccum(); - combineFn.addInput(key, accum, input); + combineFn.addInput(accum, input); writeValue(accum); } @@ -342,7 +328,7 @@ class SparkStateInternals<K> implements StateInternals<K> { public AccumT getAccum() { AccumT accum = readValue(); if (accum == null) { - accum = combineFn.createAccumulator(key); + accum = combineFn.createAccumulator(); } return accum; } @@ -363,13 +349,13 @@ class SparkStateInternals<K> implements StateInternals<K> { @Override public void addAccum(AccumT accum) { - accum = combineFn.mergeAccumulators(key, Arrays.asList(getAccum(), accum)); + accum = combineFn.mergeAccumulators(Arrays.asList(getAccum(), accum)); writeValue(accum); } @Override public AccumT mergeAccumulators(Iterable<AccumT> accumulators) { - return combineFn.mergeAccumulators(key, accumulators); + return combineFn.mergeAccumulators(accumulators); } } http://git-wip-us.apache.org/repos/asf/beam/blob/7e04924e/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkKeyedCombineFn.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkKeyedCombineFn.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkKeyedCombineFn.java index 66c03bc..58db8e4 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkKeyedCombineFn.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/SparkKeyedCombineFn.java @@ -41,14 +41,14 @@ import org.joda.time.Instant; /** - * A {@link org.apache.beam.sdk.transforms.CombineFnBase.PerKeyCombineFn} + * A {@link org.apache.beam.sdk.transforms.CombineFnBase.GlobalCombineFn} * with a {@link org.apache.beam.sdk.transforms.CombineWithContext.Context} for the SparkRunner. */ public class SparkKeyedCombineFn<K, InputT, AccumT, OutputT> extends SparkAbstractCombineFn { - private final CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> combineFn; + private final CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn; public SparkKeyedCombineFn( - CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> combineFn, + CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn, SparkRuntimeContext runtimeContext, Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs, WindowingStrategy<?, ?> windowingStrategy) { @@ -59,8 +59,7 @@ public class SparkKeyedCombineFn<K, InputT, AccumT, OutputT> extends SparkAbstra /** Applying the combine function directly on a key's grouped values - post grouping. */ public OutputT apply(WindowedValue<KV<K, Iterable<InputT>>> windowedKv) { // apply combine function on grouped values. - return combineFn.apply(windowedKv.getValue().getKey(), windowedKv.getValue().getValue(), - ctxtForInput(windowedKv)); + return combineFn.apply(windowedKv.getValue().getValue(), ctxtForInput(windowedKv)); } /** @@ -83,8 +82,8 @@ public class SparkKeyedCombineFn<K, InputT, AccumT, OutputT> extends SparkAbstra // first create the accumulator and accumulate first input. K key = currentInput.getValue().getKey(); - AccumT accumulator = combineFn.createAccumulator(key, ctxtForInput(currentInput)); - accumulator = combineFn.addInput(key, accumulator, currentInput.getValue().getValue(), + AccumT accumulator = combineFn.createAccumulator(ctxtForInput(currentInput)); + accumulator = combineFn.addInput(accumulator, currentInput.getValue().getValue(), ctxtForInput(currentInput)); // keep track of the timestamps assigned by the TimestampCombiner. @@ -114,7 +113,7 @@ public class SparkKeyedCombineFn<K, InputT, AccumT, OutputT> extends SparkAbstra currentWindow = merge((IntervalWindow) currentWindow, (IntervalWindow) nextWindow); } // keep accumulating and carry on ;-) - accumulator = combineFn.addInput(key, accumulator, nextValue.getValue().getValue(), + accumulator = combineFn.addInput(accumulator, nextValue.getValue().getValue(), ctxtForInput(nextValue)); windowTimestamp = timestampCombiner.combine( @@ -128,8 +127,8 @@ public class SparkKeyedCombineFn<K, InputT, AccumT, OutputT> extends SparkAbstra output.add(WindowedValue.of(KV.of(key, accumulator), windowTimestamp, currentWindow, PaneInfo.NO_FIRING)); // re-init accumulator, window and timestamp. - accumulator = combineFn.createAccumulator(key, ctxtForInput(nextValue)); - accumulator = combineFn.addInput(key, accumulator, nextValue.getValue().getValue(), + accumulator = combineFn.createAccumulator(ctxtForInput(nextValue)); + accumulator = combineFn.addInput(accumulator, nextValue.getValue().getValue(), ctxtForInput(nextValue)); currentWindow = nextWindow; windowTimestamp = @@ -233,7 +232,7 @@ public class SparkKeyedCombineFn<K, InputT, AccumT, OutputT> extends SparkAbstra WindowedValue<KV<K, Iterable<AccumT>>> preMergeWindowedValue = WindowedValue.of( KV.of(key, accumsToMerge), mergedTimestamp, currentWindow, PaneInfo.NO_FIRING); // applying the actual combiner onto the accumulators. - AccumT accumulated = combineFn.mergeAccumulators(key, accumsToMerge, + AccumT accumulated = combineFn.mergeAccumulators(accumsToMerge, ctxtForInput(preMergeWindowedValue)); WindowedValue<KV<K, AccumT>> postMergeWindowedValue = preMergeWindowedValue.withValue(KV.of(key, accumulated)); @@ -254,7 +253,7 @@ public class SparkKeyedCombineFn<K, InputT, AccumT, OutputT> extends SparkAbstra Iterable<AccumT> accumsToMerge = Iterables.unmodifiableIterable(currentWindowAccumulators); WindowedValue<KV<K, Iterable<AccumT>>> preMergeWindowedValue = WindowedValue.of( KV.of(key, accumsToMerge), mergedTimestamp, currentWindow, PaneInfo.NO_FIRING); - AccumT accumulated = combineFn.mergeAccumulators(key, accumsToMerge, + AccumT accumulated = combineFn.mergeAccumulators(accumsToMerge, ctxtForInput(preMergeWindowedValue)); WindowedValue<KV<K, AccumT>> postMergeWindowedValue = preMergeWindowedValue.withValue(KV.of(key, accumulated)); @@ -272,9 +271,8 @@ public class SparkKeyedCombineFn<K, InputT, AccumT, OutputT> extends SparkAbstra if (wkva == null) { return null; } - K key = wkva.getValue().getKey(); AccumT accumulator = wkva.getValue().getValue(); - return wkva.withValue(combineFn.extractOutput(key, accumulator, ctxtForInput(wkva))); + return wkva.withValue(combineFn.extractOutput(accumulator, ctxtForInput(wkva))); } }); } http://git-wip-us.apache.org/repos/asf/beam/blob/7e04924e/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index c2a8b06..d249e78 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -29,7 +29,6 @@ import com.google.common.collect.Maps; import java.util.Collection; import java.util.Collections; import java.util.Map; -import java.util.Map.Entry; import org.apache.beam.runners.core.SystemReduceFn; import org.apache.beam.runners.spark.aggregators.AggregatorsAccumulator; import org.apache.beam.runners.spark.aggregators.NamedAggregators; @@ -165,8 +164,8 @@ public final class TransformTranslator { Combine.GroupedValues<K, InputT, OutputT> transform, EvaluationContext context) { @SuppressWarnings("unchecked") - CombineWithContext.KeyedCombineFnWithContext<K, InputT, ?, OutputT> combineFn = - (CombineWithContext.KeyedCombineFnWithContext<K, InputT, ?, OutputT>) + CombineWithContext.CombineFnWithContext<InputT, ?, OutputT> combineFn = + (CombineWithContext.CombineFnWithContext<InputT, ?, OutputT>) CombineFnUtil.toFnWithContext(transform.getFn()); final SparkKeyedCombineFn<K, InputT, ?, OutputT> sparkCombineFn = new SparkKeyedCombineFn<>(combineFn, context.getRuntimeContext(), @@ -282,16 +281,15 @@ public final class TransformTranslator { return new TransformEvaluator<Combine.PerKey<K, InputT, OutputT>>() { @Override public void evaluate( - Combine.PerKey<K, InputT, OutputT> transform, - EvaluationContext context) { + Combine.PerKey<K, InputT, OutputT> transform, EvaluationContext context) { final PCollection<KV<K, InputT>> input = context.getInput(transform); // serializable arguments to pass. @SuppressWarnings("unchecked") final KvCoder<K, InputT> inputCoder = (KvCoder<K, InputT>) context.getInput(transform).getCoder(); @SuppressWarnings("unchecked") - final CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT> combineFn = - (CombineWithContext.KeyedCombineFnWithContext<K, InputT, AccumT, OutputT>) + final CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT> combineFn = + (CombineWithContext.CombineFnWithContext<InputT, AccumT, OutputT>) CombineFnUtil.toFnWithContext(transform.getFn()); final WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy(); final SparkRuntimeContext runtimeContext = context.getRuntimeContext(); @@ -301,8 +299,9 @@ public final class TransformTranslator { new SparkKeyedCombineFn<>(combineFn, runtimeContext, sideInputs, windowingStrategy); final Coder<AccumT> vaCoder; try { - vaCoder = combineFn.getAccumulatorCoder(runtimeContext.getCoderRegistry(), - inputCoder.getKeyCoder(), inputCoder.getValueCoder()); + vaCoder = + combineFn.getAccumulatorCoder( + runtimeContext.getCoderRegistry(), inputCoder.getValueCoder()); } catch (CannotProvideCoderException e) { throw new IllegalStateException("Could not determine coder for accumulator", e); } @@ -312,19 +311,28 @@ public final class TransformTranslator { ((BoundedDataset<KV<K, InputT>>) context.borrowDataset(transform)).getRDD(); JavaPairRDD<K, Iterable<WindowedValue<KV<K, AccumT>>>> accumulatePerKey = - GroupCombineFunctions.combinePerKey(inRdd, sparkCombineFn, inputCoder.getKeyCoder(), - inputCoder.getValueCoder(), vaCoder, windowingStrategy); + GroupCombineFunctions.combinePerKey( + inRdd, + sparkCombineFn, + inputCoder.getKeyCoder(), + inputCoder.getValueCoder(), + vaCoder, + windowingStrategy); JavaRDD<WindowedValue<KV<K, OutputT>>> outRdd = - accumulatePerKey.flatMapValues(new Function<Iterable<WindowedValue<KV<K, AccumT>>>, - Iterable<WindowedValue<OutputT>>>() { - @Override - public Iterable<WindowedValue<OutputT>> call( - Iterable<WindowedValue<KV<K, AccumT>>> iter) throws Exception { + accumulatePerKey + .flatMapValues( + new Function< + Iterable<WindowedValue<KV<K, AccumT>>>, + Iterable<WindowedValue<OutputT>>>() { + @Override + public Iterable<WindowedValue<OutputT>> call( + Iterable<WindowedValue<KV<K, AccumT>>> iter) throws Exception { return sparkCombineFn.extractOutput(iter); } - }).map(TranslationUtils.<K, WindowedValue<OutputT>>fromPairFunction()) - .map(TranslationUtils.<K, OutputT>toKVByWindowInValue()); + }) + .map(TranslationUtils.<K, WindowedValue<OutputT>>fromPairFunction()) + .map(TranslationUtils.<K, OutputT>toKVByWindowInValue()); context.putDataset(transform, new BoundedDataset<>(outRdd)); } http://git-wip-us.apache.org/repos/asf/beam/blob/7e04924e/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index 26f0ade..9af4af2 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -327,8 +327,8 @@ public final class StreamingTransformTranslator { context.getInput(transform); final WindowingStrategy<?, ?> windowingStrategy = input.getWindowingStrategy(); @SuppressWarnings("unchecked") - final CombineWithContext.KeyedCombineFnWithContext<K, InputT, ?, OutputT> fn = - (CombineWithContext.KeyedCombineFnWithContext<K, InputT, ?, OutputT>) + final CombineWithContext.CombineFnWithContext<InputT, ?, OutputT> fn = + (CombineWithContext.CombineFnWithContext<InputT, ?, OutputT>) CombineFnUtil.toFnWithContext(transform.getFn()); @SuppressWarnings("unchecked") http://git-wip-us.apache.org/repos/asf/beam/blob/7e04924e/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java ---------------------------------------------------------------------- diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java index ff43fa6..ce52b90 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/SparkRunnerDebuggerTest.java @@ -93,10 +93,9 @@ public class SparkRunnerDebuggerTest { final String expectedPipeline = "sparkContext.parallelize(Arrays.asList(...))\n" + "_.mapPartitions(new org.apache.beam.runners.spark.examples.WordCount$ExtractWordsFn())\n" + "_.mapPartitions(new org.apache.beam.sdk.transforms.Count$PerElement$1())\n" - + "_.combineByKey(..., new org.apache.beam.sdk.transforms" - + ".Combine$CombineFn$KeyIgnoringCombineFn(), ...)\n" + + "_.combineByKey(..., new org.apache.beam.sdk.transforms.Count$CountFn(), ...)\n" + "_.groupByKey()\n" - + "_.map(new org.apache.beam.sdk.transforms.Combine$CombineFn$KeyIgnoringCombineFn())\n" + + "_.map(new org.apache.beam.sdk.transforms.Sum$SumLongFn())\n" + "_.mapPartitions(new org.apache.beam.runners.spark" + ".SparkRunnerDebuggerTest$PlusOne())\n" + "sparkContext.union(...)\n" @@ -145,7 +144,7 @@ public class SparkRunnerDebuggerTest { + "SparkRunnerDebuggerTest$FormatKVFn())\n" + "_.mapPartitions(new org.apache.beam.sdk.transforms.Distinct$2())\n" + "_.groupByKey()\n" - + "_.map(new org.apache.beam.sdk.transforms.Combine$CombineFn$KeyIgnoringCombineFn())\n" + + "_.map(new org.apache.beam.sdk.transforms.Combine$IterableCombineFn())\n" + "_.mapPartitions(new org.apache.beam.sdk.transforms.Keys$1())\n" + "_.mapPartitions(new org.apache.beam.sdk.transforms.WithKeys$2())\n" + "_.<org.apache.beam.sdk.io.kafka.AutoValue_KafkaIO_Write>"; http://git-wip-us.apache.org/repos/asf/beam/blob/7e04924e/sdks/java/build-tools/src/main/resources/beam/findbugs-filter.xml ---------------------------------------------------------------------- diff --git a/sdks/java/build-tools/src/main/resources/beam/findbugs-filter.xml b/sdks/java/build-tools/src/main/resources/beam/findbugs-filter.xml index cf7d668..d03fbf3 100644 --- a/sdks/java/build-tools/src/main/resources/beam/findbugs-filter.xml +++ b/sdks/java/build-tools/src/main/resources/beam/findbugs-filter.xml @@ -361,7 +361,7 @@ <!--[BEAM-413] Test for floating point equality--> </Match> <Match> - <Class name="org.apache.beam.sdk.util.CombineFnUtil$NonSerializableBoundedKeyedCombineFn"/> + <Class name="org.apache.beam.sdk.util.CombineFnUtil$NonSerializableBoundedCombineFn"/> <Field name="context"/> <Bug pattern="SE_BAD_FIELD"/> <!--[BEAM-419] Non-transient non-serializable instance field in serializable class--> http://git-wip-us.apache.org/repos/asf/beam/blob/7e04924e/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ApproximateQuantiles.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ApproximateQuantiles.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ApproximateQuantiles.java index ed3a253..5432f09 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ApproximateQuantiles.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ApproximateQuantiles.java @@ -155,9 +155,7 @@ public class ApproximateQuantiles { public static <K, V, ComparatorT extends Comparator<V> & Serializable> PTransform<PCollection<KV<K, V>>, PCollection<KV<K, List<V>>>> perKey(int numQuantiles, ComparatorT compareFn) { - return Combine.perKey( - ApproximateQuantilesCombineFn.create(numQuantiles, compareFn) - .<K>asKeyedFn()); + return Combine.perKey(ApproximateQuantilesCombineFn.create(numQuantiles, compareFn)); } /** @@ -173,9 +171,7 @@ public class ApproximateQuantiles { public static <K, V extends Comparable<V>> PTransform<PCollection<KV<K, V>>, PCollection<KV<K, List<V>>>> perKey(int numQuantiles) { - return Combine.perKey( - ApproximateQuantilesCombineFn.<V>create(numQuantiles) - .<K>asKeyedFn()); + return Combine.perKey(ApproximateQuantilesCombineFn.<V>create(numQuantiles)); } http://git-wip-us.apache.org/repos/asf/beam/blob/7e04924e/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ApproximateUnique.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ApproximateUnique.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ApproximateUnique.java index 33820e0..5d38206 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ApproximateUnique.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ApproximateUnique.java @@ -281,8 +281,7 @@ public class ApproximateUnique { final Coder<V> coder = ((KvCoder<K, V>) inputCoder).getValueCoder(); return input.apply( - Combine.perKey(new ApproximateUniqueCombineFn<>( - sampleSize, coder).<K>asKeyedFn())); + Combine.<K, V, Long>perKey(new ApproximateUniqueCombineFn<>(sampleSize, coder))); } @Override