Add SetState and MapState

Project: http://git-wip-us.apache.org/repos/asf/beam/repo
Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/a0702f5b
Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/a0702f5b
Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/a0702f5b

Branch: refs/heads/master
Commit: a0702f5bed3c7269e90b4702266945aa34dd1aea
Parents: 0f48321
Author: JingsongLi <lzljs3620...@aliyun.com>
Authored: Tue Feb 14 14:52:05 2017 +0800
Committer: Kenneth Knowles <k...@google.com>
Committed: Tue Feb 14 11:06:29 2017 -0800

----------------------------------------------------------------------
 .../translation/utils/ApexStateInternals.java   |  18 ++
 .../runners/core/InMemoryStateInternals.java    | 205 ++++++++++++++
 .../apache/beam/runners/core/StateMerging.java  |  44 +++
 .../org/apache/beam/runners/core/StateTag.java  |   8 +
 .../org/apache/beam/runners/core/StateTags.java |  30 ++
 .../core/InMemoryStateInternalsTest.java        | 280 +++++++++++++++++--
 .../apache/beam/runners/core/StateTagTest.java  |  33 +++
 .../CopyOnAccessInMemoryStateInternals.java     |  46 +++
 .../CopyOnAccessInMemoryStateInternalsTest.java |  58 ++++
 .../wrappers/streaming/FlinkStateInternals.java |  18 ++
 .../apache/beam/sdk/util/state/MapState.java    |  93 ++++++
 .../apache/beam/sdk/util/state/SetState.java    |  71 +++++
 .../apache/beam/sdk/util/state/StateBinder.java |   6 +
 .../apache/beam/sdk/util/state/StateSpecs.java  |  89 ++++++
 .../apache/beam/sdk/transforms/ParDoTest.java   |  94 +++++++
 15 files changed, 1063 insertions(+), 30 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java
----------------------------------------------------------------------
diff --git 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java
 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java
index 34d993f..7634366 100644
--- 
a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java
+++ 
b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java
@@ -45,7 +45,9 @@ import org.apache.beam.sdk.transforms.windowing.OutputTimeFn;
 import org.apache.beam.sdk.util.CombineFnUtil;
 import org.apache.beam.sdk.util.state.AccumulatorCombiningState;
 import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.MapState;
 import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.beam.sdk.util.state.SetState;
 import org.apache.beam.sdk.util.state.State;
 import org.apache.beam.sdk.util.state.StateContext;
 import org.apache.beam.sdk.util.state.StateContexts;
@@ -121,6 +123,22 @@ public class ApexStateInternals<K> implements 
StateInternals<K>, Serializable {
     }
 
     @Override
+    public <T> SetState<T> bindSet(
+        StateTag<? super K, SetState<T>> address,
+        Coder<T> elemCoder) {
+      throw new UnsupportedOperationException(
+          String.format("%s is not supported", 
SetState.class.getSimpleName()));
+    }
+
+    @Override
+    public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+        StateTag<? super K, MapState<KeyT, ValueT>> spec,
+        Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) {
+      throw new UnsupportedOperationException(
+          String.format("%s is not supported", 
MapState.class.getSimpleName()));
+    }
+
+    @Override
     public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, 
OutputT>
         bindCombiningValue(
             StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, 
OutputT>> address,

http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java
----------------------------------------------------------------------
diff --git 
a/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java
 
b/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java
index 6a181f3..b4b2b38 100644
--- 
a/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java
+++ 
b/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java
@@ -17,10 +17,16 @@
  */
 package org.apache.beam.runners.core;
 
+import static com.google.common.base.Preconditions.checkNotNull;
+
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 import javax.annotation.Nullable;
 import org.apache.beam.runners.core.StateTag.StateBinder;
 import org.apache.beam.sdk.annotations.Experimental;
@@ -34,7 +40,9 @@ import org.apache.beam.sdk.transforms.windowing.OutputTimeFn;
 import org.apache.beam.sdk.util.CombineFnUtil;
 import org.apache.beam.sdk.util.state.AccumulatorCombiningState;
 import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.MapState;
 import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.beam.sdk.util.state.SetState;
 import org.apache.beam.sdk.util.state.State;
 import org.apache.beam.sdk.util.state.StateContext;
 import org.apache.beam.sdk.util.state.StateContexts;
@@ -128,6 +136,18 @@ public class InMemoryStateInternals<K> implements 
StateInternals<K> {
     }
 
     @Override
+    public <T> SetState<T> bindSet(StateTag<? super K, SetState<T>> spec, 
Coder<T> elemCoder) {
+      return new InMemorySet<>();
+    }
+
+    @Override
+    public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+        StateTag<? super K, MapState<KeyT, ValueT>> spec,
+        Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) {
+      return new InMemoryMap<>();
+    }
+
+    @Override
     public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, 
OutputT>
         bindCombiningValue(
             StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, 
OutputT>> address,
@@ -435,4 +455,189 @@ public class InMemoryStateInternals<K> implements 
StateInternals<K> {
       return that;
     }
   }
+
+  /**
+   * An {@link InMemoryState} implementation of {@link SetState}.
+   */
+  public static final class InMemorySet<T> implements SetState<T>, 
InMemoryState<InMemorySet<T>> {
+    private Set<T> contents = new HashSet<>();
+
+    @Override
+    public void clear() {
+      contents = new HashSet<>();
+    }
+
+    @Override
+    public boolean contains(T t) {
+      return contents.contains(t);
+    }
+
+    @Override
+    public boolean addIfAbsent(T t) {
+      return contents.add(t);
+    }
+
+    @Override
+    public void remove(T t) {
+      contents.remove(t);
+    }
+
+    @Override
+    public SetState<T> readLater(Iterable<T> elements) {
+      return this;
+    }
+
+    @Override
+    public boolean containsAny(Iterable<T> elements) {
+      elements = checkNotNull(elements);
+      for (T t : elements) {
+        if (contents.contains(t)) {
+          return true;
+        }
+      }
+      return false;
+    }
+
+    @Override
+    public boolean containsAll(Iterable<T> elements) {
+      elements = checkNotNull(elements);
+      for (T t : elements) {
+        if (!contents.contains(t)) {
+          return false;
+        }
+      }
+      return true;
+    }
+
+    @Override
+    public InMemorySet<T> readLater() {
+      return this;
+    }
+
+    @Override
+    public Iterable<T> read() {
+      return contents;
+    }
+
+    @Override
+    public void add(T input) {
+      contents.add(input);
+    }
+
+    @Override
+    public boolean isCleared() {
+      return contents.isEmpty();
+    }
+
+    @Override
+    public ReadableState<Boolean> isEmpty() {
+      return new ReadableState<Boolean>() {
+        @Override
+        public ReadableState<Boolean> readLater() {
+          return this;
+        }
+
+        @Override
+        public Boolean read() {
+          return contents.isEmpty();
+        }
+      };
+    }
+
+    @Override
+    public InMemorySet<T> copy() {
+      InMemorySet<T> that = new InMemorySet<>();
+      that.contents.addAll(this.contents);
+      return that;
+    }
+  }
+
+  /**
+   * An {@link InMemoryState} implementation of {@link MapState}.
+   */
+  public static final class InMemoryMap<K, V> implements
+      MapState<K, V>, InMemoryState<InMemoryMap<K, V>> {
+    private Map<K, V> contents = new HashMap<>();
+
+    @Override
+    public void clear() {
+      contents = new HashMap<>();
+    }
+
+    @Override
+    public V get(K key) {
+      return contents.get(key);
+    }
+
+    @Override
+    public void put(K key, V value) {
+      contents.put(key, value);
+    }
+
+    @Override
+    public V putIfAbsent(K key, V value) {
+      V v = contents.get(key);
+      if (v == null) {
+        v = contents.put(key, value);
+      }
+
+      return v;
+    }
+
+    @Override
+    public void remove(K key) {
+      contents.remove(key);
+    }
+
+    @Override
+    public Iterable<V> get(Iterable<K> keys) {
+      List<V> values = new ArrayList<>();
+      for (K k : keys) {
+        values.add(contents.get(k));
+      }
+      return values;
+    }
+
+    @Override
+    public MapState<K, V> getLater(K k) {
+      return this;
+    }
+
+    @Override
+    public MapState<K, V> getLater(Iterable<K> keys) {
+      return this;
+    }
+
+    @Override
+    public Iterable<K> keys() {
+      return contents.keySet();
+    }
+
+    @Override
+    public Iterable<V> values() {
+      return contents.values();
+    }
+
+    @Override
+    public MapState<K, V> iterateLater() {
+      return this;
+    }
+
+    @Override
+    public Iterable<Map.Entry<K, V>> iterate() {
+      return contents.entrySet();
+    }
+
+    @Override
+    public boolean isCleared() {
+      return contents.isEmpty();
+    }
+
+    @Override
+    public InMemoryMap<K, V> copy() {
+      InMemoryMap<K, V> that = new InMemoryMap<>();
+      that.contents.putAll(this.contents);
+      return that;
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateMerging.java
----------------------------------------------------------------------
diff --git 
a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateMerging.java
 
b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateMerging.java
index c533f83..e98d098 100644
--- 
a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateMerging.java
+++ 
b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateMerging.java
@@ -28,6 +28,7 @@ import 
org.apache.beam.sdk.util.state.AccumulatorCombiningState;
 import org.apache.beam.sdk.util.state.BagState;
 import org.apache.beam.sdk.util.state.CombiningState;
 import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.beam.sdk.util.state.SetState;
 import org.apache.beam.sdk.util.state.State;
 import org.apache.beam.sdk.util.state.WatermarkHoldState;
 import org.joda.time.Instant;
@@ -112,6 +113,49 @@ public class StateMerging {
   }
 
   /**
+   * Merge all set state in {@code address} across all windows under merge.
+   */
+  public static <K, T, W extends BoundedWindow> void mergeSets(
+      MergingStateAccessor<K, W> context, StateTag<? super K, SetState<T>> 
address) {
+    mergeSets(context.accessInEachMergingWindow(address).values(), 
context.access(address));
+  }
+
+  /**
+   * Merge all set state in {@code sources} (which may include {@code result}) 
into {@code result}.
+   */
+  public static <T, W extends BoundedWindow> void mergeSets(
+      Collection<SetState<T>> sources, SetState<T> result) {
+    if (sources.isEmpty()) {
+      // Nothing to merge.
+      return;
+    }
+    // Prefetch everything except what's already in result.
+    List<ReadableState<Iterable<T>>> futures = new ArrayList<>(sources.size());
+    for (SetState<T> source : sources) {
+      if (!source.equals(result)) {
+        prefetchRead(source);
+        futures.add(source);
+      }
+    }
+    if (futures.isEmpty()) {
+      // Result already holds all the values.
+      return;
+    }
+    // Transfer from sources to result.
+    for (ReadableState<Iterable<T>> future : futures) {
+      for (T element : future.read()) {
+        result.add(element);
+      }
+    }
+    // Clear sources except for result.
+    for (SetState<T> source : sources) {
+      if (!source.equals(result)) {
+        source.clear();
+      }
+    }
+  }
+
+  /**
    * Prefetch all combining value state for {@code address} across all merging 
windows in {@code
    * context}.
    */

http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java
----------------------------------------------------------------------
diff --git 
a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java 
b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java
index a3d703f..802aede 100644
--- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java
+++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java
@@ -30,6 +30,8 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.OutputTimeFn;
 import org.apache.beam.sdk.util.state.AccumulatorCombiningState;
 import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.MapState;
+import org.apache.beam.sdk.util.state.SetState;
 import org.apache.beam.sdk.util.state.State;
 import org.apache.beam.sdk.util.state.StateSpec;
 import org.apache.beam.sdk.util.state.ValueState;
@@ -86,6 +88,12 @@ public interface StateTag<K, StateT extends State> extends 
Serializable {
 
     <T> BagState<T> bindBag(StateTag<? super K, BagState<T>> spec, Coder<T> 
elemCoder);
 
+    <T> SetState<T> bindSet(StateTag<? super K, SetState<T>> spec, Coder<T> 
elemCoder);
+
+    <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+        StateTag<? super K, MapState<KeyT, ValueT>> spec,
+        Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder);
+
     <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, 
OutputT> bindCombiningValue(
         StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, 
OutputT>> spec,
         Coder<AccumT> accumCoder,

http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java
----------------------------------------------------------------------
diff --git 
a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java 
b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java
index cf7c236..1c70dff 100644
--- 
a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java
+++ 
b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java
@@ -32,6 +32,8 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.OutputTimeFn;
 import org.apache.beam.sdk.util.state.AccumulatorCombiningState;
 import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.MapState;
+import org.apache.beam.sdk.util.state.SetState;
 import org.apache.beam.sdk.util.state.State;
 import org.apache.beam.sdk.util.state.StateBinder;
 import org.apache.beam.sdk.util.state.StateSpec;
@@ -68,6 +70,19 @@ public class StateTags {
       }
 
       @Override
+      public <T> SetState<T> bindSet(
+          String id, StateSpec<? super K, SetState<T>> spec, Coder<T> 
elemCoder) {
+        return binder.bindSet(tagForSpec(id, spec), elemCoder);
+      }
+
+      @Override
+      public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+          String id, StateSpec<? super K, MapState<KeyT, ValueT>> spec,
+          Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) {
+        return binder.bindMap(tagForSpec(id, spec), mapKeyCoder, 
mapValueCoder);
+      }
+
+      @Override
       public <InputT, AccumT, OutputT>
           AccumulatorCombiningState<InputT, AccumT, OutputT> 
bindCombiningValue(
               String id,
@@ -200,6 +215,21 @@ public class StateTags {
   }
 
   /**
+   * Create a state tag that supports {@link java.util.Set}-like access patterns.
+   */
+  public static <T> StateTag<Object, SetState<T>> set(String id, Coder<T> 
elemCoder) {
+    return new SimpleStateTag<>(new StructuredId(id), 
StateSpecs.set(elemCoder));
+  }
+
+  /**
+   * Create a state tag that supports {@link java.util.Map}-like access patterns.
+   */
+  public static <K, V> StateTag<Object, MapState<K, V>> map(
+      String id, Coder<K> keyCoder, Coder<V> valueCoder) {
+    return new SimpleStateTag<>(new StructuredId(id), StateSpecs.map(keyCoder, 
valueCoder));
+  }
+
+  /**
    * Create a state tag for holding the watermark.
    */
   public static <W extends BoundedWindow> StateTag<Object, 
WatermarkHoldState<W>>

http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java
----------------------------------------------------------------------
diff --git 
a/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java
 
b/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java
index 8ea9abc..1da946f 100644
--- 
a/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java
+++ 
b/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java
@@ -17,11 +17,18 @@
  */
 package org.apache.beam.runners.core;
 
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.not;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
 
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.Map;
+import java.util.Objects;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.transforms.Sum;
@@ -31,7 +38,9 @@ import org.apache.beam.sdk.transforms.windowing.OutputTimeFns;
 import org.apache.beam.sdk.util.state.AccumulatorCombiningState;
 import org.apache.beam.sdk.util.state.BagState;
 import org.apache.beam.sdk.util.state.CombiningState;
+import org.apache.beam.sdk.util.state.MapState;
 import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.beam.sdk.util.state.SetState;
 import org.apache.beam.sdk.util.state.ValueState;
 import org.apache.beam.sdk.util.state.WatermarkHoldState;
 import org.hamcrest.Matchers;
@@ -57,6 +66,10 @@ public class InMemoryStateInternalsTest {
           "sumInteger", VarIntCoder.of(), Sum.ofIntegers());
   private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR =
       StateTags.bag("stringBag", StringUtf8Coder.of());
+  private static final StateTag<Object, SetState<String>> STRING_SET_ADDR =
+      StateTags.set("stringSet", StringUtf8Coder.of());
+  private static final StateTag<Object, MapState<String, Integer>> 
STRING_MAP_ADDR =
+      StateTags.map("stringMap", StringUtf8Coder.of(), VarIntCoder.of());
   private static final StateTag<Object, WatermarkHoldState<BoundedWindow>>
       WATERMARK_EARLIEST_ADDR =
       StateTags.watermarkStateInternal("watermark", 
OutputTimeFns.outputAtEarliestInputTimestamp());
@@ -80,9 +93,9 @@ public class InMemoryStateInternalsTest {
 
     assertThat(value.read(), Matchers.nullValue());
     value.write("hello");
-    assertThat(value.read(), Matchers.equalTo("hello"));
+    assertThat(value.read(), equalTo("hello"));
     value.write("world");
-    assertThat(value.read(), Matchers.equalTo("world"));
+    assertThat(value.read(), equalTo("world"));
 
     value.clear();
     assertThat(value.read(), Matchers.nullValue());
@@ -94,8 +107,8 @@ public class InMemoryStateInternalsTest {
     BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR);
 
     // State instances are cached, but depend on the namespace.
-    assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR));
-    assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)));
+    assertThat(value, equalTo(underTest.state(NAMESPACE_1, STRING_BAG_ADDR)));
+    assertThat(value, not(equalTo(underTest.state(NAMESPACE_2, 
STRING_BAG_ADDR))));
 
     assertThat(value.read(), Matchers.emptyIterable());
     value.add("hello");
@@ -157,6 +170,213 @@ public class InMemoryStateInternalsTest {
   }
 
   @Test
+  public void testSet() throws Exception {
+    SetState<String> value = underTest.state(NAMESPACE_1, STRING_SET_ADDR);
+
+    // State instances are cached, but depend on the namespace.
+    assertThat(value, equalTo(underTest.state(NAMESPACE_1, STRING_SET_ADDR)));
+    assertThat(value, not(equalTo(underTest.state(NAMESPACE_2, 
STRING_SET_ADDR))));
+
+    // empty
+    assertThat(value.read(), Matchers.emptyIterable());
+    assertFalse(value.contains("A"));
+    assertFalse(value.containsAny(Collections.singletonList("A")));
+
+    // add
+    value.add("A");
+    value.add("B");
+    value.add("A");
+    assertFalse(value.addIfAbsent("B"));
+    assertThat(value.read(), Matchers.containsInAnyOrder("A", "B"));
+
+    // remove
+    value.remove("A");
+    assertThat(value.read(), Matchers.containsInAnyOrder("B"));
+    value.remove("C");
+    assertThat(value.read(), Matchers.containsInAnyOrder("B"));
+
+    // contains
+    assertFalse(value.contains("A"));
+    assertTrue(value.contains("B"));
+    value.add("C");
+    value.add("D");
+
+    // containsAny
+    assertTrue(value.containsAny(Arrays.asList("A", "C")));
+    assertFalse(value.containsAny(Arrays.asList("A", "E")));
+
+    // containsAll
+    assertTrue(value.containsAll(Arrays.asList("B", "C")));
+    assertFalse(value.containsAll(Arrays.asList("A", "B")));
+
+    // readLater
+    assertThat(value.readLater().read(), Matchers.containsInAnyOrder("B", "C", 
"D"));
+    SetState<String> later = value.readLater(Arrays.asList("A", "C", "D"));
+    assertTrue(later.containsAll(Arrays.asList("C", "D")));
+    assertFalse(later.contains("A"));
+
+    // clear
+    value.clear();
+    assertThat(value.read(), Matchers.emptyIterable());
+    assertThat(underTest.state(NAMESPACE_1, STRING_SET_ADDR), 
Matchers.sameInstance(value));
+
+  }
+
+  @Test
+  public void testSetIsEmpty() throws Exception {
+    SetState<String> value = underTest.state(NAMESPACE_1, STRING_SET_ADDR);
+
+    assertThat(value.isEmpty().read(), Matchers.is(true));
+    ReadableState<Boolean> readFuture = value.isEmpty();
+    value.add("hello");
+    assertThat(readFuture.read(), Matchers.is(false));
+
+    value.clear();
+    assertThat(readFuture.read(), Matchers.is(true));
+  }
+
+  @Test
+  public void testMergeSetIntoSource() throws Exception {
+    SetState<String> set1 = underTest.state(NAMESPACE_1, STRING_SET_ADDR);
+    SetState<String> set2 = underTest.state(NAMESPACE_2, STRING_SET_ADDR);
+
+    set1.add("Hello");
+    set2.add("Hello");
+    set2.add("World");
+    set1.add("!");
+
+    StateMerging.mergeSets(Arrays.asList(set1, set2), set1);
+
+    // Reading the merged set returns the contents of both source sets.
+    assertThat(set1.read(), Matchers.containsInAnyOrder("Hello", "World", 
"!"));
+    assertThat(set2.read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testMergeSetIntoNewNamespace() throws Exception {
+    SetState<String> set1 = underTest.state(NAMESPACE_1, STRING_SET_ADDR);
+    SetState<String> set2 = underTest.state(NAMESPACE_2, STRING_SET_ADDR);
+    SetState<String> set3 = underTest.state(NAMESPACE_3, STRING_SET_ADDR);
+
+    set1.add("Hello");
+    set2.add("Hello");
+    set2.add("World");
+    set1.add("!");
+
+    StateMerging.mergeSets(Arrays.asList(set1, set2, set3), set3);
+
+    // Reading the merged set returns the contents of both source sets.
+    assertThat(set3.read(), Matchers.containsInAnyOrder("Hello", "World", 
"!"));
+    assertThat(set1.read(), Matchers.emptyIterable());
+    assertThat(set2.read(), Matchers.emptyIterable());
+  }
+
+  // Minimal Map.Entry implementation used to build the expected entries in testMap.
+  private static class MapEntry<K, V> implements Map.Entry<K, V> {
+    private K key;
+    private V value;
+
+    private MapEntry(K key, V value) {
+      this.key = key;
+      this.value = value;
+    }
+
+    static <K, V> Map.Entry<K, V> of(K k, V v) {
+      return new MapEntry<>(k, v);
+    }
+
+    public final K getKey() {
+      return key;
+    }
+    public final V getValue() {
+      return value;
+    }
+
+    public final String toString() {
+      return key + "=" + value;
+    }
+
+    public final int hashCode() {
+      return Objects.hashCode(key) ^ Objects.hashCode(value);
+    }
+
+    public final V setValue(V newValue) {
+      V oldValue = value;
+      value = newValue;
+      return oldValue;
+    }
+
+    public final boolean equals(Object o) {
+      if (o == this) {
+        return true;
+      }
+      if (o instanceof Map.Entry) {
+        Map.Entry<?, ?> e = (Map.Entry<?, ?>) o;
+        if (Objects.equals(key, e.getKey())
+            && Objects.equals(value, e.getValue())) {
+          return true;
+        }
+      }
+      return false;
+    }
+  }
+
+  @Test
+  public void testMap() throws Exception {
+    MapState<String, Integer> value = underTest.state(NAMESPACE_1, 
STRING_MAP_ADDR);
+
+    // State instances are cached, but depend on the namespace.
+    assertThat(value, equalTo(underTest.state(NAMESPACE_1, STRING_MAP_ADDR)));
+    assertThat(value, not(equalTo(underTest.state(NAMESPACE_2, 
STRING_MAP_ADDR))));
+
+    // put
+    assertThat(value.iterate(), Matchers.emptyIterable());
+    value.put("A", 1);
+    value.put("B", 2);
+    value.put("A", 11);
+    assertThat(value.putIfAbsent("B", 22), equalTo(2));
+    assertThat(value.iterate(), Matchers.containsInAnyOrder(MapEntry.of("A", 
11),
+        MapEntry.of("B", 2)));
+
+    // remove
+    value.remove("A");
+    assertThat(value.iterate(), Matchers.containsInAnyOrder(MapEntry.of("B", 
2)));
+    value.remove("C");
+    assertThat(value.iterate(), Matchers.containsInAnyOrder(MapEntry.of("B", 
2)));
+
+    // get
+    assertNull(value.get("A"));
+    assertThat(value.get("B"), equalTo(2));
+    value.put("C", 3);
+    value.put("D", 4);
+    assertThat(value.get("C"), equalTo(3));
+    assertThat(value.get(Collections.singletonList("D")), 
Matchers.containsInAnyOrder(4));
+    assertThat(value.get(Arrays.asList("B", "C")), 
Matchers.containsInAnyOrder(2, 3));
+
+    // iterate
+    value.put("E", 5);
+    value.remove("C");
+    assertThat(value.keys(), Matchers.containsInAnyOrder("B", "D", "E"));
+    assertThat(value.values(), Matchers.containsInAnyOrder(2, 4, 5));
+    assertThat(value.iterate(), Matchers.containsInAnyOrder(
+        MapEntry.of("B", 2), MapEntry.of("D", 4), MapEntry.of("E", 5)));
+
+    // readLater
+    assertThat(value.getLater("B").get("B"), equalTo(2));
+    assertNull(value.getLater("A").get("A"));
+    MapState<String, Integer> later = value.getLater(Arrays.asList("C", "D"));
+    assertNull(later.get("C"));
+    assertThat(later.get("D"), equalTo(4));
+    assertThat(value.iterateLater().iterate(), Matchers.containsInAnyOrder(
+        MapEntry.of("B", 2), MapEntry.of("D", 4), MapEntry.of("E", 5)));
+
+    // clear
+    value.clear();
+    assertThat(value.iterate(), Matchers.emptyIterable());
+    assertThat(underTest.state(NAMESPACE_1, STRING_MAP_ADDR), 
Matchers.sameInstance(value));
+  }
+
+  @Test
   public void testCombiningValue() throws Exception {
     CombiningState<Integer, Integer> value = underTest.state(NAMESPACE_1, 
SUM_INTEGER_ADDR);
 
@@ -164,15 +384,15 @@ public class InMemoryStateInternalsTest {
     assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR));
     assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR)));
 
-    assertThat(value.read(), Matchers.equalTo(0));
+    assertThat(value.read(), equalTo(0));
     value.add(2);
-    assertThat(value.read(), Matchers.equalTo(2));
+    assertThat(value.read(), equalTo(2));
 
     value.add(3);
-    assertThat(value.read(), Matchers.equalTo(5));
+    assertThat(value.read(), equalTo(5));
 
     value.clear();
-    assertThat(value.read(), Matchers.equalTo(0));
+    assertThat(value.read(), equalTo(0));
     assertThat(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), 
Matchers.sameInstance(value));
   }
 
@@ -200,14 +420,14 @@ public class InMemoryStateInternalsTest {
     value2.add(10);
     value1.add(6);
 
-    assertThat(value1.read(), Matchers.equalTo(11));
-    assertThat(value2.read(), Matchers.equalTo(10));
+    assertThat(value1.read(), equalTo(11));
+    assertThat(value2.read(), equalTo(10));
 
     // Merging clears the old values and updates the result value.
     StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1);
 
-    assertThat(value1.read(), Matchers.equalTo(21));
-    assertThat(value2.read(), Matchers.equalTo(0));
+    assertThat(value1.read(), equalTo(21));
+    assertThat(value2.read(), equalTo(0));
   }
 
   @Test
@@ -226,9 +446,9 @@ public class InMemoryStateInternalsTest {
     StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3);
 
     // Merging clears the old values and updates the result value.
-    assertThat(value1.read(), Matchers.equalTo(0));
-    assertThat(value2.read(), Matchers.equalTo(0));
-    assertThat(value3.read(), Matchers.equalTo(21));
+    assertThat(value1.read(), equalTo(0));
+    assertThat(value2.read(), equalTo(0));
+    assertThat(value3.read(), equalTo(21));
   }
 
   @Test
@@ -242,16 +462,16 @@ public class InMemoryStateInternalsTest {
 
     assertThat(value.read(), Matchers.nullValue());
     value.add(new Instant(2000));
-    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+    assertThat(value.read(), equalTo(new Instant(2000)));
 
     value.add(new Instant(3000));
-    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+    assertThat(value.read(), equalTo(new Instant(2000)));
 
     value.add(new Instant(1000));
-    assertThat(value.read(), Matchers.equalTo(new Instant(1000)));
+    assertThat(value.read(), equalTo(new Instant(1000)));
 
     value.clear();
-    assertThat(value.read(), Matchers.equalTo(null));
+    assertThat(value.read(), equalTo(null));
     assertThat(underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR), 
Matchers.sameInstance(value));
   }
 
@@ -266,16 +486,16 @@ public class InMemoryStateInternalsTest {
 
     assertThat(value.read(), Matchers.nullValue());
     value.add(new Instant(2000));
-    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+    assertThat(value.read(), equalTo(new Instant(2000)));
 
     value.add(new Instant(3000));
-    assertThat(value.read(), Matchers.equalTo(new Instant(3000)));
+    assertThat(value.read(), equalTo(new Instant(3000)));
 
     value.add(new Instant(1000));
-    assertThat(value.read(), Matchers.equalTo(new Instant(3000)));
+    assertThat(value.read(), equalTo(new Instant(3000)));
 
     value.clear();
-    assertThat(value.read(), Matchers.equalTo(null));
+    assertThat(value.read(), equalTo(null));
     assertThat(underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR), 
Matchers.sameInstance(value));
   }
 
@@ -289,10 +509,10 @@ public class InMemoryStateInternalsTest {
 
     assertThat(value.read(), Matchers.nullValue());
     value.add(new Instant(2000));
-    assertThat(value.read(), Matchers.equalTo(new Instant(2000)));
+    assertThat(value.read(), equalTo(new Instant(2000)));
 
     value.clear();
-    assertThat(value.read(), Matchers.equalTo(null));
+    assertThat(value.read(), equalTo(null));
     assertThat(underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR), 
Matchers.sameInstance(value));
   }
 
@@ -325,8 +545,8 @@ public class InMemoryStateInternalsTest {
     // Merging clears the old values and updates the merged value.
     StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value1, 
WINDOW_1);
 
-    assertThat(value1.read(), Matchers.equalTo(new Instant(2000)));
-    assertThat(value2.read(), Matchers.equalTo(null));
+    assertThat(value1.read(), equalTo(new Instant(2000)));
+    assertThat(value2.read(), equalTo(null));
   }
 
   @Test
@@ -347,8 +567,8 @@ public class InMemoryStateInternalsTest {
     StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value3, 
WINDOW_1);
 
     // Merging clears the old values and updates the result value.
-    assertThat(value3.read(), Matchers.equalTo(new Instant(5000)));
-    assertThat(value1.read(), Matchers.equalTo(null));
-    assertThat(value2.read(), Matchers.equalTo(null));
+    assertThat(value3.read(), equalTo(new Instant(5000)));
+    assertThat(value1.read(), equalTo(null));
+    assertThat(value2.read(), equalTo(null));
   }
 }

http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateTagTest.java
----------------------------------------------------------------------
diff --git 
a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateTagTest.java
 
b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateTagTest.java
index 9a04628..0584643 100644
--- 
a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateTagTest.java
+++ 
b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateTagTest.java
@@ -23,6 +23,7 @@ import static org.junit.Assert.assertNotEquals;
 import org.apache.beam.sdk.coders.BigEndianIntegerCoder;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.CoderRegistry;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.Max;
@@ -63,6 +64,38 @@ public class StateTagTest {
   }
 
   @Test
+  public void testSetEquality() {
+    StateTag<?, ?> fooVarInt1 = StateTags.set("foo", VarIntCoder.of());
+    StateTag<?, ?> fooVarInt2 = StateTags.set("foo", VarIntCoder.of());
+    StateTag<?, ?> fooBigEndian = StateTags.set("foo", 
BigEndianIntegerCoder.of());
+    StateTag<?, ?> barVarInt = StateTags.set("bar", VarIntCoder.of());
+
+    assertEquals(fooVarInt1, fooVarInt2);
+    assertNotEquals(fooVarInt1, fooBigEndian);
+    assertNotEquals(fooVarInt1, barVarInt);
+  }
+
+  @Test
+  public void testMapEquality() {
+    StateTag<?, ?> fooStringVarInt1 =
+        StateTags.map("foo", StringUtf8Coder.of(), VarIntCoder.of());
+    StateTag<?, ?> fooStringVarInt2 =
+        StateTags.map("foo", StringUtf8Coder.of(), VarIntCoder.of());
+    StateTag<?, ?> fooStringBigEndian =
+        StateTags.map("foo", StringUtf8Coder.of(), BigEndianIntegerCoder.of());
+    StateTag<?, ?> fooVarIntBigEndian =
+        StateTags.map("foo", VarIntCoder.of(), BigEndianIntegerCoder.of());
+    StateTag<?, ?> barStringVarInt =
+        StateTags.map("bar", StringUtf8Coder.of(), VarIntCoder.of());
+
+    assertEquals(fooStringVarInt1, fooStringVarInt2);
+    assertNotEquals(fooStringVarInt1, fooStringBigEndian);
+    assertNotEquals(fooStringBigEndian, fooVarIntBigEndian);
+    assertNotEquals(fooStringVarInt1, fooVarIntBigEndian);
+    assertNotEquals(fooStringVarInt1, barStringVarInt);
+  }
+
+  @Test
   public void testWatermarkBagEquality() {
     StateTag<?, ?> foo1 = StateTags.watermarkStateInternal(
         "foo", OutputTimeFns.outputAtEarliestInputTimestamp());

http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java
 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java
index 47c0251..ff5c23c 100644
--- 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java
+++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java
@@ -27,6 +27,8 @@ import java.util.Map;
 import javax.annotation.Nullable;
 import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryBag;
 import 
org.apache.beam.runners.core.InMemoryStateInternals.InMemoryCombiningValue;
+import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryMap;
+import org.apache.beam.runners.core.InMemoryStateInternals.InMemorySet;
 import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryState;
 import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryStateBinder;
 import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryValue;
@@ -45,6 +47,8 @@ import org.apache.beam.sdk.transforms.windowing.OutputTimeFn;
 import org.apache.beam.sdk.util.CombineFnUtil;
 import org.apache.beam.sdk.util.state.AccumulatorCombiningState;
 import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.MapState;
+import org.apache.beam.sdk.util.state.SetState;
 import org.apache.beam.sdk.util.state.State;
 import org.apache.beam.sdk.util.state.StateContext;
 import org.apache.beam.sdk.util.state.StateContexts;
@@ -334,6 +338,35 @@ public class CopyOnAccessInMemoryStateInternals<K> 
implements StateInternals<K>
           }
 
           @Override
+          public <T> SetState<T> bindSet(
+              StateTag<? super K, SetState<T>> address, Coder<T> elemCoder) {
+            if (containedInUnderlying(namespace, address)) {
+              @SuppressWarnings("unchecked")
+              InMemoryState<? extends SetState<T>> existingState =
+                  (InMemoryState<? extends SetState<T>>)
+                      underlying.get().get(namespace, address, c);
+              return existingState.copy();
+            } else {
+              return new InMemorySet<>();
+            }
+          }
+
+          @Override
+          public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+              StateTag<? super K, MapState<KeyT, ValueT>> address,
+              Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) {
+            if (containedInUnderlying(namespace, address)) {
+              @SuppressWarnings("unchecked")
+              InMemoryState<? extends MapState<KeyT, ValueT>> existingState =
+                  (InMemoryState<? extends MapState<KeyT, ValueT>>)
+                      underlying.get().get(namespace, address, c);
+              return existingState.copy();
+            } else {
+              return new InMemoryMap<>();
+            }
+          }
+
+          @Override
           public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, 
AccumT, OutputT>
               bindKeyedCombiningValue(
                   StateTag<? super K, AccumulatorCombiningState<InputT, 
AccumT, OutputT>> address,
@@ -430,6 +463,19 @@ public class CopyOnAccessInMemoryStateInternals<K> 
implements StateInternals<K>
           }
 
           @Override
+          public <T> SetState<T> bindSet(
+              StateTag<? super K, SetState<T>> address, Coder<T> elemCoder) {
+            return underlying.get(namespace, address, c);
+          }
+
+          @Override
+          public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+              StateTag<? super K, MapState<KeyT, ValueT>> address,
+              Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) {
+            return underlying.get(namespace, address, c);
+          }
+
+          @Override
           public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, 
AccumT, OutputT>
               bindKeyedCombiningValue(
                   StateTag<? super K, AccumulatorCombiningState<InputT, 
AccumT, OutputT>> address,

http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternalsTest.java
----------------------------------------------------------------------
diff --git 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternalsTest.java
 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternalsTest.java
index c8eb66e..c7409bb 100644
--- 
a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternalsTest.java
+++ 
b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternalsTest.java
@@ -23,6 +23,7 @@ import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.nullValue;
 import static org.hamcrest.Matchers.theInstance;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertThat;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.spy;
@@ -47,6 +48,8 @@ import org.apache.beam.sdk.transforms.windowing.OutputTimeFns;
 import org.apache.beam.sdk.util.state.AccumulatorCombiningState;
 import org.apache.beam.sdk.util.state.BagState;
 import org.apache.beam.sdk.util.state.CombiningState;
+import org.apache.beam.sdk.util.state.MapState;
+import org.apache.beam.sdk.util.state.SetState;
 import org.apache.beam.sdk.util.state.ValueState;
 import org.apache.beam.sdk.util.state.WatermarkHoldState;
 import org.joda.time.Instant;
@@ -164,6 +167,61 @@ public class CopyOnAccessInMemoryStateInternalsTest {
   }
 
   @Test
+  public void testSetStateWithUnderlying() {
+    CopyOnAccessInMemoryStateInternals<String> underlying =
+        CopyOnAccessInMemoryStateInternals.withUnderlying(key, null);
+
+    StateNamespace namespace = new StateNamespaceForTest("foo");
+    StateTag<Object, SetState<Integer>> valueTag = StateTags.set("foo", 
VarIntCoder.of());
+    SetState<Integer> underlyingValue = underlying.state(namespace, valueTag);
+    assertThat(underlyingValue.read(), emptyIterable());
+
+    underlyingValue.add(1);
+    assertThat(underlyingValue.read(), containsInAnyOrder(1));
+
+    CopyOnAccessInMemoryStateInternals<String> internals =
+        CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying);
+    SetState<Integer> copyOnAccessState = internals.state(namespace, valueTag);
+    assertThat(copyOnAccessState.read(), containsInAnyOrder(1));
+
+    copyOnAccessState.add(4);
+    assertThat(copyOnAccessState.read(), containsInAnyOrder(4, 1));
+    assertThat(underlyingValue.read(), containsInAnyOrder(1));
+
+    SetState<Integer> reReadUnderlyingValue = underlying.state(namespace, 
valueTag);
+    assertThat(underlyingValue.read(), equalTo(reReadUnderlyingValue.read()));
+  }
+
+  @Test
+  public void testMapStateWithUnderlying() {
+    CopyOnAccessInMemoryStateInternals<String> underlying =
+        CopyOnAccessInMemoryStateInternals.withUnderlying(key, null);
+
+    StateNamespace namespace = new StateNamespaceForTest("foo");
+    StateTag<Object, MapState<String, Integer>> valueTag =
+        StateTags.map("foo", StringUtf8Coder.of(), VarIntCoder.of());
+    MapState<String, Integer> underlyingValue = underlying.state(namespace, 
valueTag);
+    assertThat(underlyingValue.iterate(), emptyIterable());
+
+    underlyingValue.put("hello", 1);
+    assertThat(underlyingValue.get("hello"), equalTo(1));
+
+    CopyOnAccessInMemoryStateInternals<String> internals =
+        CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying);
+    MapState<String, Integer> copyOnAccessState = internals.state(namespace, 
valueTag);
+    assertThat(copyOnAccessState.get("hello"), equalTo(1));
+
+    copyOnAccessState.put("world", 4);
+    assertThat(copyOnAccessState.get("hello"), equalTo(1));
+    assertThat(copyOnAccessState.get("world"), equalTo(4));
+    assertThat(underlyingValue.get("hello"), equalTo(1));
+    assertNull(underlyingValue.get("world"));
+
+    MapState<String, Integer> reReadUnderlyingValue = 
underlying.state(namespace, valueTag);
+    assertThat(underlyingValue.iterate(), 
equalTo(reReadUnderlyingValue.iterate()));
+  }
+
+  @Test
   public void testAccumulatorCombiningStateWithUnderlying() throws 
CannotProvideCoderException {
     CopyOnAccessInMemoryStateInternals<String> underlying =
         CopyOnAccessInMemoryStateInternals.withUnderlying(key, null);

http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkStateInternals.java
----------------------------------------------------------------------
diff --git 
a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkStateInternals.java
 
b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkStateInternals.java
index eaededb..4183067 100644
--- 
a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkStateInternals.java
+++ 
b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkStateInternals.java
@@ -37,7 +37,9 @@ import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.util.CombineContextFactory;
 import org.apache.beam.sdk.util.state.AccumulatorCombiningState;
 import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.MapState;
 import org.apache.beam.sdk.util.state.ReadableState;
+import org.apache.beam.sdk.util.state.SetState;
 import org.apache.beam.sdk.util.state.State;
 import org.apache.beam.sdk.util.state.StateContext;
 import org.apache.beam.sdk.util.state.StateContexts;
@@ -125,6 +127,22 @@ public class FlinkStateInternals<K> implements 
StateInternals<K> {
       }
 
       @Override
+      public <T> SetState<T> bindSet(
+          StateTag<? super K, SetState<T>> address,
+          Coder<T> elemCoder) {
+        throw new UnsupportedOperationException(
+            String.format("%s is not supported", 
SetState.class.getSimpleName()));
+      }
+
+      @Override
+      public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+          StateTag<? super K, MapState<KeyT, ValueT>> spec,
+          Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) {
+        throw new UnsupportedOperationException(
+            String.format("%s is not supported", 
MapState.class.getSimpleName()));
+      }
+
+      @Override
       public <InputT, AccumT, OutputT>
           AccumulatorCombiningState<InputT, AccumT, OutputT>
       bindCombiningValue(

http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/MapState.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/MapState.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/MapState.java
new file mode 100644
index 0000000..85d99d6
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/MapState.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.util.state;
+
+import java.util.Map;
+
+/**
+ * An object that maps keys to values.
+ * A map cannot contain duplicate keys;
+ * each key can map to at most one value.
+ *
+ * @param <K> the type of keys maintained by this map
+ * @param <V> the type of mapped values
+ */
+public interface MapState<K, V> extends State {
+
+  /**
+   * Returns the value to which the specified key is mapped in the state.
+   */
+  V get(K key);
+
+  /**
+   * Associates the specified value with the specified key in this state.
+   */
+  void put(K key, V value);
+
+  /**
+   * If the specified key is not already associated with a value (or is mapped
+   * to {@code null}) associates it with the given value and returns
+   * {@code null}, else returns the current value.
+   */
+  V putIfAbsent(K key, V value);
+
+  /**
+   * Removes the mapping for a key from this map if it is present.
+   */
+  void remove(K key);
+
+  /**
+   * A bulk get.
+   * @param keys the keys to search for
+   * @return an iterable view of the values; some of the values may be null.
+   * The order of values corresponds to the order of the keys.
+   */
+  Iterable<V> get(Iterable<K> keys);
+
+  /**
+   * Indicates that the specified key will be read later.
+   */
+  MapState<K, V> getLater(K k);
+
+  /**
+   * Indicates that the specified batch of keys will be read later.
+   */
+  MapState<K, V> getLater(Iterable<K> keys);
+
+  /**
+   * Returns an iterable view of the keys contained in this map.
+   */
+  Iterable<K> keys();
+
+  /**
+   * Returns an iterable view of the values contained in this map.
+   */
+  Iterable<V> values();
+
+  /**
+   * Indicates that all key-value entries will be read later.
+   */
+  MapState<K, V> iterateLater();
+
+  /**
+   * Returns an iterable view of all key-value entries.
+   */
+  Iterable<Map.Entry<K, V>> iterate();
+
+}
+

http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/SetState.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/SetState.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/SetState.java
new file mode 100644
index 0000000..93058b2
--- /dev/null
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/SetState.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.util.state;
+
+/**
+ * State containing no duplicate elements.
+ * Items can be added to the set and the contents read out.
+ *
+ * @param <T> The type of elements in the set.
+ */
+public interface SetState<T> extends CombiningState<T, Iterable<T>> {
+  /**
+   * Returns true if this set contains the specified element.
+   */
+  boolean contains(T t);
+
+  /**
+   * Adds a value to the set if it is not already present.
+   * If this set already contains the element, the call leaves the set
+   * unchanged and returns false.
+   */
+  boolean addIfAbsent(T t);
+
+  /**
+   * Removes the specified element from this set if it is present.
+   */
+  void remove(T t);
+
+  /**
+   * Indicate that elements will be read later.
+   * @param elements to be read later
+   * @return this for convenient chaining
+   */
+  SetState<T> readLater(Iterable<T> elements);
+
+  /**
+   * <p>Checks if SetState contains any given elements.</p>
+   *
+   * @param elements the elements to search for
+   * @return {@code true} if any of the elements are found,
+   * {@code false} if no match
+   */
+  boolean containsAny(Iterable<T> elements);
+
+  /**
+   * <p>Checks if SetState contains all given elements.</p>
+   *
+   * @param elements the elements to find
+   * @return true if the SetState contains all elements,
+   *  false if not
+   */
+  boolean containsAll(Iterable<T> elements);
+
+  @Override
+  SetState<T> readLater();
+}

http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateBinder.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateBinder.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateBinder.java
index 0521e15..fbfb475 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateBinder.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateBinder.java
@@ -33,6 +33,12 @@ public interface StateBinder<K> {
 
   <T> BagState<T> bindBag(String id, StateSpec<? super K, BagState<T>> spec, 
Coder<T> elemCoder);
 
+  <T> SetState<T> bindSet(String id, StateSpec<? super K, SetState<T>> spec, 
Coder<T> elemCoder);
+
+  <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+      String id, StateSpec<? super K, MapState<KeyT, ValueT>> spec,
+      Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder);
+
   <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> 
bindCombiningValue(
       String id,
       StateSpec<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> 
spec,

http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java
index 08c3a12..8912993 100644
--- 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java
+++ 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java
@@ -125,6 +125,21 @@ public class StateSpecs {
     return new BagStateSpec<T>(elemCoder);
   }
 
+  /**
+   * Create a state spec that supports {@link java.util.Set}-like access 
patterns.
+   */
+  public static <T> StateSpec<Object, SetState<T>> set(Coder<T> elemCoder) {
+    return new SetStateSpec<>(elemCoder);
+  }
+
+  /**
+   * Create a state spec that supports {@link java.util.Map}-like access 
patterns.
+   */
+  public static <K, V> StateSpec<Object, MapState<K, V>> map(Coder<K> keyCoder,
+                                                             Coder<V> 
valueCoder) {
+    return new MapStateSpec<>(keyCoder, valueCoder);
+  }
+
   /** Create a state spec for holding the watermark. */
   public static <W extends BoundedWindow>
       StateSpec<Object, WatermarkHoldState<W>> watermarkStateInternal(
@@ -346,6 +361,80 @@ public class StateSpecs {
     }
   }
 
+  private static class MapStateSpec<K, V> implements StateSpec<Object, 
MapState<K, V>> {
+
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private MapStateSpec(Coder<K> keyCoder, Coder<V> valueCoder) {
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+    }
+
+    @Override
+    public MapState<K, V> bind(String id, StateBinder<?> visitor) {
+      return visitor.bindMap(id, this, keyCoder, valueCoder);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (obj == this) {
+        return true;
+      }
+
+      if (!(obj instanceof MapStateSpec)) {
+        return false;
+      }
+
+      MapStateSpec<?, ?> that = (MapStateSpec<?, ?>) obj;
+      return Objects.equals(this.keyCoder, that.keyCoder)
+          && Objects.equals(this.valueCoder, that.valueCoder);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(getClass(), keyCoder, valueCoder);
+    }
+  }
+
+  /**
+   * A specification for a state cell supporting set-like access patterns.
+   *
+   * <p>Includes the coder for the element type {@code T}</p>
+   */
+  private static class SetStateSpec<T> implements StateSpec<Object, 
SetState<T>> {
+
+    private final Coder<T> elemCoder;
+
+    private SetStateSpec(Coder<T> elemCoder) {
+      this.elemCoder = elemCoder;
+    }
+
+    @Override
+    public SetState<T> bind(String id, StateBinder<?> visitor) {
+      return visitor.bindSet(id, this, elemCoder);
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+      if (obj == this) {
+        return true;
+      }
+
+      if (!(obj instanceof SetStateSpec)) {
+        return false;
+      }
+
+      SetStateSpec<?> that = (SetStateSpec<?>) obj;
+      return Objects.equals(this.elemCoder, that.elemCoder);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(getClass(), elemCoder);
+    }
+  }
+
   /**
    * A specification for a state cell tracking a combined watermark hold.
    *

http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
----------------------------------------------------------------------
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index 75c39cc..f40bbe1 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -40,6 +40,7 @@ import com.fasterxml.jackson.annotation.JsonCreator;
 import com.google.common.base.MoreObjects;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
@@ -48,6 +49,8 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
 import org.apache.beam.sdk.Pipeline.PipelineExecutionException;
 import org.apache.beam.sdk.coders.AtomicCoder;
 import org.apache.beam.sdk.coders.CoderException;
@@ -60,6 +63,8 @@ import org.apache.beam.sdk.testing.PAssert;
 import org.apache.beam.sdk.testing.RunnableOnService;
 import org.apache.beam.sdk.testing.TestPipeline;
 import org.apache.beam.sdk.testing.TestStream;
+import org.apache.beam.sdk.testing.UsesMapState;
+import org.apache.beam.sdk.testing.UsesSetState;
 import org.apache.beam.sdk.testing.UsesStatefulParDo;
 import org.apache.beam.sdk.testing.UsesTestStream;
 import org.apache.beam.sdk.testing.UsesTimersInParDo;
@@ -83,6 +88,8 @@ import org.apache.beam.sdk.util.TimerSpecs;
 import org.apache.beam.sdk.util.common.ElementByteSizeObserver;
 import org.apache.beam.sdk.util.state.AccumulatorCombiningState;
 import org.apache.beam.sdk.util.state.BagState;
+import org.apache.beam.sdk.util.state.MapState;
+import org.apache.beam.sdk.util.state.SetState;
 import org.apache.beam.sdk.util.state.StateSpec;
 import org.apache.beam.sdk.util.state.StateSpecs;
 import org.apache.beam.sdk.util.state.ValueState;
@@ -1681,6 +1688,93 @@ public class ParDoTest implements Serializable {
   }
 
   @Test
+  @Category({RunnableOnService.class, UsesStatefulParDo.class, 
UsesSetState.class})
+  public void testSetState() {
+    final String stateId = "foo";
+    final String countStateId = "count";
+
+    DoFn<KV<String, Integer>, Set<Integer>> fn =
+        new DoFn<KV<String, Integer>, Set<Integer>>() {
+
+          @StateId(stateId)
+          private final StateSpec<Object, SetState<Integer>> setState =
+              StateSpecs.set(VarIntCoder.of());
+          @StateId(countStateId)
+          private final StateSpec<Object, AccumulatorCombiningState<Integer, 
int[], Integer>>
+              countState = 
StateSpecs.combiningValueFromInputInternal(VarIntCoder.of(),
+              Sum.ofIntegers());
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext c,
+              @StateId(stateId) SetState<Integer> state,
+              @StateId(countStateId) AccumulatorCombiningState<Integer, int[], 
Integer>
+                  count) {
+            state.add(c.element().getValue());
+            count.add(1);
+            if (count.read() >= 4) {
+              Set<Integer> set = Sets.newHashSet(state.read());
+              c.output(set);
+            }
+          }
+        };
+
+    PCollection<Set<Integer>> output =
+        pipeline.apply(
+            Create.of(
+                KV.of("hello", 97), KV.of("hello", 42), KV.of("hello", 42), 
KV.of("hello", 12)))
+            .apply(ParDo.of(fn));
+
+    PAssert.that(output).containsInAnyOrder(Sets.newHashSet(97, 42, 12));
+    pipeline.run();
+  }
+
+  @Test
+  @Category({RunnableOnService.class, UsesStatefulParDo.class, 
UsesMapState.class})
+  public void testMapState() {
+    final String stateId = "foo";
+    final String countStateId = "count";
+
+    DoFn<KV<String, KV<String, Integer>>, KV<String, Integer>> fn =
+        new DoFn<KV<String, KV<String, Integer>>, KV<String, Integer>>() {
+
+          @StateId(stateId)
+          private final StateSpec<Object, MapState<String, Integer>> mapState =
+              StateSpecs.map(StringUtf8Coder.of(), VarIntCoder.of());
+          @StateId(countStateId)
+          private final StateSpec<Object, AccumulatorCombiningState<Integer, 
int[], Integer>>
+              countState = 
StateSpecs.combiningValueFromInputInternal(VarIntCoder.of(),
+              Sum.ofIntegers());
+
+          @ProcessElement
+          public void processElement(
+              ProcessContext c, @StateId(stateId) MapState<String, Integer> 
state,
+              @StateId(countStateId) AccumulatorCombiningState<Integer, int[], 
Integer>
+                  count) {
+            KV<String, Integer> value = c.element().getValue();
+            state.put(value.getKey(), value.getValue());
+            count.add(1);
+            if (count.read() >= 4) {
+              Iterable<Map.Entry<String, Integer>> iterate = state.iterate();
+              for (Map.Entry<String, Integer> entry : iterate) {
+                c.output(KV.of(entry.getKey(), entry.getValue()));
+              }
+            }
+          }
+        };
+
+    PCollection<KV<String, Integer>> output =
+        pipeline.apply(
+            Create.of(
+                KV.of("hello", KV.of("a", 97)), KV.of("hello", KV.of("b", 42)),
+                KV.of("hello", KV.of("b", 42)), KV.of("hello", KV.of("c", 
12))))
+            .apply(ParDo.of(fn));
+
+    PAssert.that(output).containsInAnyOrder(KV.of("a", 97), KV.of("b", 42), 
KV.of("c", 12));
+    pipeline.run();
+  }
+
+  @Test
   @Category({RunnableOnService.class, UsesStatefulParDo.class})
   public void testCombiningState() {
     final String stateId = "foo";

Reply via email to