[BEAM-1483] Support SetState in Flink runner and fix MapState to be consistent with InMemoryStateInternals.


Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/10b166b3
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/10b166b3
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/10b166b3

Branch: refs/heads/master
Commit: 10b166b355a03daeae78dd1e71016fc72805939d
Parents: 4c36508
Author: JingsongLi <lzljs3620...@aliyun.com>
Authored: Wed Jun 7 14:40:30 2017 +0800
Committer: Aljoscha Krettek <aljoscha.kret...@gmail.com>
Committed: Tue Jun 13 11:35:17 2017 +0200

----------------------------------------------------------------------
 runners/flink/pom.xml                           |   1 -
 .../streaming/state/FlinkStateInternals.java    | 227 +++++++++++++++----
 .../streaming/FlinkStateInternalsTest.java      |  17 --
 3 files changed, 182 insertions(+), 63 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/10b166b3/runners/flink/pom.xml
----------------------------------------------------------------------
diff --git a/runners/flink/pom.xml b/runners/flink/pom.xml
index a5b8203..339aa8e 100644
--- a/runners/flink/pom.xml
+++ b/runners/flink/pom.xml
@@ -91,7 +91,6 @@
                   <excludedGroups>
                     org.apache.beam.sdk.testing.FlattenWithHeterogeneousCoders,
                     org.apache.beam.sdk.testing.LargeKeys$Above100MB,
-                    org.apache.beam.sdk.testing.UsesSetState,
                     org.apache.beam.sdk.testing.UsesCommittedMetrics,
                     org.apache.beam.sdk.testing.UsesTestStream,
                     org.apache.beam.sdk.testing.UsesSplittableParDo

http://git-wip-us.apache.org/repos/asf/beam/blob/10b166b3/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java
index d8771de..a0b015b 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/state/FlinkStateInternals.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.runners.flink.translation.wrappers.streaming.state;
 
+import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
 import java.nio.ByteBuffer;
 import java.util.Collections;
@@ -33,6 +34,7 @@ import org.apache.beam.sdk.state.BagState;
 import org.apache.beam.sdk.state.CombiningState;
 import org.apache.beam.sdk.state.MapState;
 import org.apache.beam.sdk.state.ReadableState;
+import org.apache.beam.sdk.state.ReadableStates;
 import org.apache.beam.sdk.state.SetState;
 import org.apache.beam.sdk.state.State;
 import org.apache.beam.sdk.state.StateContext;
@@ -48,6 +50,7 @@ import org.apache.beam.sdk.util.CombineContextFactory;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.state.MapStateDescriptor;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.typeutils.base.BooleanSerializer;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
 import org.apache.flink.runtime.state.KeyedStateBackend;
 import org.joda.time.Instant;
@@ -127,8 +130,8 @@ public class FlinkStateInternals<K> implements StateInternals {
           @Override
           public <T> SetState<T> bindSet(
               StateTag<SetState<T>> address, Coder<T> elemCoder) {
-            throw new UnsupportedOperationException(
-                String.format("%s is not supported", SetState.class.getSimpleName()));
+            return new FlinkSetState<>(
+                flinkStateBackend, address, namespace, elemCoder);
           }
 
           @Override
@@ -875,24 +878,15 @@ public class FlinkStateInternals<K> implements StateInternals {
 
     @Override
     public ReadableState<ValueT> get(final KeyT input) {
-      return new ReadableState<ValueT>() {
-        @Override
-        public ValueT read() {
-          try {
-            return flinkStateBackend.getPartitionedState(
+      try {
+        return ReadableStates.immediate(
+            flinkStateBackend.getPartitionedState(
                 namespace.stringKey(),
                 StringSerializer.INSTANCE,
-                flinkStateDescriptor).get(input);
-          } catch (Exception e) {
-            throw new RuntimeException("Error get from state.", e);
-          }
-        }
-
-        @Override
-        public ReadableState<ValueT> readLater() {
-          return this;
-        }
-      };
+                flinkStateDescriptor).get(input));
+      } catch (Exception e) {
+        throw new RuntimeException("Error get from state.", e);
+      }
     }
 
     @Override
@@ -909,32 +903,22 @@ public class FlinkStateInternals<K> implements StateInternals {
 
     @Override
     public ReadableState<ValueT> putIfAbsent(final KeyT key, final ValueT value) {
-      return new ReadableState<ValueT>() {
-        @Override
-        public ValueT read() {
-          try {
-            ValueT current = flinkStateBackend.getPartitionedState(
-                namespace.stringKey(),
-                StringSerializer.INSTANCE,
-                flinkStateDescriptor).get(key);
-
-            if (current == null) {
-              flinkStateBackend.getPartitionedState(
-                  namespace.stringKey(),
-                  StringSerializer.INSTANCE,
-                  flinkStateDescriptor).put(key, value);
-            }
-            return current;
-          } catch (Exception e) {
-            throw new RuntimeException("Error put kv to state.", e);
-          }
-        }
+      try {
+        ValueT current = flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).get(key);
 
-        @Override
-        public ReadableState<ValueT> readLater() {
-          return this;
+        if (current == null) {
+          flinkStateBackend.getPartitionedState(
+              namespace.stringKey(),
+              StringSerializer.INSTANCE,
+              flinkStateDescriptor).put(key, value);
         }
-      };
+        return ReadableStates.immediate(current);
+      } catch (Exception e) {
+        throw new RuntimeException("Error put kv to state.", e);
+      }
     }
 
     @Override
@@ -955,10 +939,11 @@ public class FlinkStateInternals<K> implements StateInternals {
         @Override
         public Iterable<KeyT> read() {
           try {
-            return flinkStateBackend.getPartitionedState(
+            Iterable<KeyT> result = flinkStateBackend.getPartitionedState(
                 namespace.stringKey(),
                 StringSerializer.INSTANCE,
                 flinkStateDescriptor).keys();
+            return result != null ? result : Collections.<KeyT>emptyList();
           } catch (Exception e) {
             throw new RuntimeException("Error get map state keys.", e);
           }
@@ -977,10 +962,11 @@ public class FlinkStateInternals<K> implements StateInternals {
         @Override
         public Iterable<ValueT> read() {
           try {
-            return flinkStateBackend.getPartitionedState(
+            Iterable<ValueT> result = flinkStateBackend.getPartitionedState(
                 namespace.stringKey(),
                 StringSerializer.INSTANCE,
                 flinkStateDescriptor).values();
+            return result != null ? result : Collections.<ValueT>emptyList();
           } catch (Exception e) {
             throw new RuntimeException("Error get map state values.", e);
           }
@@ -999,10 +985,11 @@ public class FlinkStateInternals<K> implements StateInternals {
         @Override
         public Iterable<Map.Entry<KeyT, ValueT>> read() {
           try {
-            return flinkStateBackend.getPartitionedState(
+            Iterable<Map.Entry<KeyT, ValueT>> result = flinkStateBackend.getPartitionedState(
                 namespace.stringKey(),
                 StringSerializer.INSTANCE,
                 flinkStateDescriptor).entries();
+            return result != null ? result : Collections.<Map.Entry<KeyT, ValueT>>emptyList();
           } catch (Exception e) {
             throw new RuntimeException("Error get map state entries.", e);
           }
@@ -1050,4 +1037,154 @@ public class FlinkStateInternals<K> implements StateInternals {
     }
   }
 
+  private static class FlinkSetState<T> implements SetState<T> {
+
+    private final StateNamespace namespace;
+    private final StateTag<SetState<T>> address;
+    private final MapStateDescriptor<T, Boolean> flinkStateDescriptor;
+    private final KeyedStateBackend<ByteBuffer> flinkStateBackend;
+
+    FlinkSetState(
+        KeyedStateBackend<ByteBuffer> flinkStateBackend,
+        StateTag<SetState<T>> address,
+        StateNamespace namespace,
+        Coder<T> coder) {
+      this.namespace = namespace;
+      this.address = address;
+      this.flinkStateBackend = flinkStateBackend;
+      this.flinkStateDescriptor = new MapStateDescriptor<>(address.getId(),
+          new CoderTypeSerializer<>(coder), new BooleanSerializer());
+    }
+
+    @Override
+    public ReadableState<Boolean> contains(final T t) {
+      try {
+        Boolean result = flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).get(t);
+        return ReadableStates.immediate(result != null ? result : false);
+      } catch (Exception e) {
+        throw new RuntimeException("Error contains value from state.", e);
+      }
+    }
+
+    @Override
+    public ReadableState<Boolean> addIfAbsent(final T t) {
+      try {
+        org.apache.flink.api.common.state.MapState<T, Boolean> state =
+            flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor);
+        boolean alreadyContained = state.contains(t);
+        if (!alreadyContained) {
+          state.put(t, true);
+        }
+        return ReadableStates.immediate(!alreadyContained);
+      } catch (Exception e) {
+        throw new RuntimeException("Error addIfAbsent value to state.", e);
+      }
+    }
+
+    @Override
+    public void remove(T t) {
+      try {
+        flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).remove(t);
+      } catch (Exception e) {
+        throw new RuntimeException("Error remove value to state.", e);
+      }
+    }
+
+    @Override
+    public SetState<T> readLater() {
+      return this;
+    }
+
+    @Override
+    public void add(T value) {
+      try {
+        flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).put(value, true);
+      } catch (Exception e) {
+        throw new RuntimeException("Error add value to state.", e);
+      }
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return new ReadableState<Boolean>() {
+        @Override
+        public Boolean read() {
+          try {
+            Iterable<T> result = flinkStateBackend.getPartitionedState(
+                namespace.stringKey(),
+                StringSerializer.INSTANCE,
+                flinkStateDescriptor).keys();
+            return result == null || Iterables.isEmpty(result);
+          } catch (Exception e) {
+            throw new RuntimeException("Error isEmpty from state.", e);
+          }
+        }
+
+        @Override
+        public ReadableState<Boolean> readLater() {
+          return this;
+        }
+      };
+    }
+
+    @Override
+    public Iterable<T> read() {
+      try {
+        Iterable<T> result = flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).keys();
+        return result != null ? result : Collections.<T>emptyList();
+      } catch (Exception e) {
+        throw new RuntimeException("Error read from state.", e);
+      }
+    }
+
+    @Override
+    public void clear() {
+      try {
+        flinkStateBackend.getPartitionedState(
+            namespace.stringKey(),
+            StringSerializer.INSTANCE,
+            flinkStateDescriptor).clear();
+      } catch (Exception e) {
+        throw new RuntimeException("Error clearing state.", e);
+      }
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+
+      FlinkSetState<?> that = (FlinkSetState<?>) o;
+
+      return namespace.equals(that.namespace) && address.equals(that.address);
+
+    }
+
+    @Override
+    public int hashCode() {
+      int result = namespace.hashCode();
+      result = 31 * result + address.hashCode();
+      return result;
+    }
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/10b166b3/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
----------------------------------------------------------------------
diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
index e7564ec..b8d41de 100644
--- a/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
+++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/streaming/FlinkStateInternalsTest.java
@@ -63,21 +63,4 @@ public class FlinkStateInternalsTest extends StateInternalsTest {
     }
   }
 
-  ///////////////////////// Unsupported tests \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\
-
-  @Override
-  public void testSet() {}
-
-  @Override
-  public void testSetIsEmpty() {}
-
-  @Override
-  public void testMergeSetIntoSource() {}
-
-  @Override
-  public void testMergeSetIntoNewNamespace() {}
-
-  @Override
-  public void testMap() {}
-
 }

Reply via email to