[BEAM-1483] Support SetState in Flink runner and fix MapState to be consistent with InMemoryStateInternals.
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/10b166b3 Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/10b166b3 Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/10b166b3 Branch: refs/heads/master Commit: 10b166b355a03daeae78dd1e71016fc72805939d Parents: 4c36508 Author: JingsongLi <lzljs3620...@aliyun.com> Authored: Wed Jun 7 14:40:30 2017 +0800 Committer: Aljoscha Krettek <aljoscha.kret...@gmail.com> Committed: Tue Jun 13 11:35:17 2017 +0200 ---------------------------------------------------------------------- runners/flink/pom.xml | 1 - .../streaming/state/FlinkStateInternals.java | 227 +++++++++++++++---- .../streaming/FlinkStateInternalsTest.java | 17 -- 3 files changed, 182 insertions(+), 63 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/10b166b3/runners/flink/pom.xml ---------------------------------------------------------------------- diff --git a/runners/flink/pom.xml b/runners/flink/pom.xml index a5b8203..339aa8e 100644 --- a/runners/flink/pom.xml +++ b/runners/flink/pom.xml @@ -91,7 +91,6 @@ <excludedGroups> org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders, org.apache.beam.sdk.testing.LargeKeys$Above100MB, - org.apache.beam.sdk.testing.UsesSetState, org.apache.beam.sdk.testing.UsesCommittedMetrics, org.apache.beam.sdk.testing.UsesTestStream, org.apache.beam.sdk.testing.UsesSplittableParDo http://git-wip-us.apache.org/repos/asf/beam/blob/10b166b3/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java index d8771de..a0b015b 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.flink.translation.wrappers.streaming.state; +import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import java.nio.ByteBuffer; import java.util.Collections; @@ -33,6 +34,7 @@ import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.CombiningState; import org.apache.beam.sdk.state.MapState; import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.ReadableStates; import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.state.StateContext; @@ -48,6 +50,7 @@ import org.apache.beam.sdk.util.CombineContextFactory; import org.apache.flink.api.common.state.ListStateDescriptor; import org.apache.flink.api.common.state.MapStateDescriptor; import org.apache.flink.api.common.state.ValueStateDescriptor; +import org.apache.flink.api.common.typeutils.base.BooleanSerializer; import org.apache.flink.api.common.typeutils.base.StringSerializer; import org.apache.flink.runtime.state.KeyedStateBackend; import org.joda.time.Instant; @@ -127,8 +130,8 @@ public class FlinkStateInternals<K> implements StateInternals { @Override public <T> SetState<T> bindSet( StateTag<SetState<T>> address, Coder<T> elemCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", SetState.class.getSimpleName())); + return new FlinkSetState<>( + flinkStateBackend, address, namespace, elemCoder); } @Override @@ -875,24 +878,15 @@ public class FlinkStateInternals<K> implements StateInternals { @Override public ReadableState<ValueT> get(final KeyT input) { - return new ReadableState<ValueT>() { - @Override - public ValueT read() { - try { - return flinkStateBackend.getPartitionedState( + try { + return ReadableStates.immediate( + flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, - flinkStateDescriptor).get(input); - } catch (Exception e) { - throw new RuntimeException("Error get from state.", e); - } - } - - @Override - public ReadableState<ValueT> readLater() { - return this; - } - }; + flinkStateDescriptor).get(input)); + } catch (Exception e) { + throw new RuntimeException("Error get from state.", e); + } } @Override @@ -909,32 +903,22 @@ public class FlinkStateInternals<K> implements StateInternals { @Override public ReadableState<ValueT> putIfAbsent(final KeyT key, final ValueT value) { - return new ReadableState<ValueT>() { - @Override - public ValueT read() { - try { - ValueT current = flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).get(key); - - if (current == null) { - flinkStateBackend.getPartitionedState( - namespace.stringKey(), - StringSerializer.INSTANCE, - flinkStateDescriptor).put(key, value); - } - return current; - } catch (Exception e) { - throw new RuntimeException("Error put kv to state.", e); - } - } + try { + ValueT current = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).get(key); - @Override - public ReadableState<ValueT> readLater() { - return this; + if (current == null) { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).put(key, value); } - }; + return ReadableStates.immediate(current); + } catch (Exception e) { + throw new RuntimeException("Error put kv to state.", e); + } } @Override @@ -955,10 +939,11 @@ public class FlinkStateInternals<K> implements StateInternals { @Override public Iterable<KeyT> read() { try { - return flinkStateBackend.getPartitionedState( + Iterable<KeyT> result = flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor).keys(); + return result != null ? result : Collections.<KeyT>emptyList(); } catch (Exception e) { throw new RuntimeException("Error get map state keys.", e); } @@ -977,10 +962,11 @@ public class FlinkStateInternals<K> implements StateInternals { @Override public Iterable<ValueT> read() { try { - return flinkStateBackend.getPartitionedState( + Iterable<ValueT> result = flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor).values(); + return result != null ? result : Collections.<ValueT>emptyList(); } catch (Exception e) { throw new RuntimeException("Error get map state values.", e); } @@ -999,10 +985,11 @@ public class FlinkStateInternals<K> implements StateInternals { @Override public Iterable<Map.Entry<KeyT, ValueT>> read() { try { - return flinkStateBackend.getPartitionedState( + Iterable<Map.Entry<KeyT, ValueT>> result = flinkStateBackend.getPartitionedState( namespace.stringKey(), StringSerializer.INSTANCE, flinkStateDescriptor).entries(); + return result != null ? result : Collections.<Map.Entry<KeyT, ValueT>>emptyList(); } catch (Exception e) { throw new RuntimeException("Error get map state entries.", e); } @@ -1050,4 +1037,154 @@ public class FlinkStateInternals<K> implements StateInternals { } } + private static class FlinkSetState<T> implements SetState<T> { + + private final StateNamespace namespace; + private final StateTag<SetState<T>> address; + private final MapStateDescriptor<T, Boolean> flinkStateDescriptor; + private final KeyedStateBackend<ByteBuffer> flinkStateBackend; + + FlinkSetState( + KeyedStateBackend<ByteBuffer> flinkStateBackend, + StateTag<SetState<T>> address, + StateNamespace namespace, + Coder<T> coder) { + this.namespace = namespace; + this.address = address; + this.flinkStateBackend = flinkStateBackend; + this.flinkStateDescriptor = new MapStateDescriptor<>(address.getId(), + new CoderTypeSerializer<>(coder), new BooleanSerializer()); + } + + @Override + public ReadableState<Boolean> contains(final T t) { + try { + Boolean result = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).get(t); + return ReadableStates.immediate(result != null ? result : false); + } catch (Exception e) { + throw new RuntimeException("Error contains value from state.", e); + } + } + + @Override + public ReadableState<Boolean> addIfAbsent(final T t) { + try { + org.apache.flink.api.common.state.MapState<T, Boolean> state = + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor); + boolean alreadyContained = state.contains(t); + if (!alreadyContained) { + state.put(t, true); + } + return ReadableStates.immediate(!alreadyContained); + } catch (Exception e) { + throw new RuntimeException("Error addIfAbsent value to state.", e); + } + } + + @Override + public void remove(T t) { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).remove(t); + } catch (Exception e) { + throw new RuntimeException("Error remove value to state.", e); + } + } + + @Override + public SetState<T> readLater() { + return this; + } + + @Override + public void add(T value) { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).put(value, true); + } catch (Exception e) { + throw new RuntimeException("Error add value to state.", e); + } + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + try { + Iterable<T> result = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).keys(); + return result == null || Iterables.isEmpty(result); + } catch (Exception e) { + throw new RuntimeException("Error isEmpty from state.", e); + } + } + + @Override + public ReadableState<Boolean> readLater() { + return this; + } + }; + } + + @Override + public Iterable<T> read() { + try { + Iterable<T> result = flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).keys(); + return result != null ? result : Collections.<T>emptyList(); + } catch (Exception e) { + throw new RuntimeException("Error read from state.", e); + } + } + + @Override + public void clear() { + try { + flinkStateBackend.getPartitionedState( + namespace.stringKey(), + StringSerializer.INSTANCE, + flinkStateDescriptor).clear(); + } catch (Exception e) { + throw new RuntimeException("Error clearing state.", e); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + FlinkSetState<?> that = (FlinkSetState<?>) o; + + return namespace.equals(that.namespace) && address.equals(that.address); + + } + + @Override + public int hashCode() { + int result = namespace.hashCode(); + result = 31 * result + address.hashCode(); + return result; + } + } + } http://git-wip-us.apache.org/repos/asf/beam/blob/10b166b3/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java ---------------------------------------------------------------------- diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java index e7564ec..b8d41de 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java @@ -63,21 +63,4 @@ public class FlinkStateInternalsTest extends StateInternalsTest { } } - ///////////////////////// Unsupported tests \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\ - - @Override - public void testSet() {} - - @Override - public void testSetIsEmpty() {} - - @Override - public void testMergeSetIntoSource() {} - - @Override - public void testMergeSetIntoNewNamespace() {} - - @Override - public void testMap() {} - }