This is an automated email from the ASF dual-hosted git repository.

altay pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 4c6a827  Add MapState and SetState support
     new 20108fd  Merge pull request #15238 from kileys/beam-12588-multimapstate
4c6a827 is described below

commit 4c6a8271679ec003e195baa30f24541caa0bf9f0
Author: kileys <kiley...@google.com>
AuthorDate: Tue Jun 29 04:03:26 2021 +0000

    Add MapState and SetState support
---
 .../dataflow/BatchStatefulParDoOverrides.java      |   4 +-
 .../dataflow/DataflowPipelineTranslator.java       |   3 +-
 .../beam/runners/dataflow/DataflowRunner.java      |  22 +-
 .../beam/runners/dataflow/DataflowRunnerTest.java  |  17 +
 .../java/org/apache/beam/sdk/state/MapState.java   |  12 -
 .../java/org/apache/beam/sdk/state/SetState.java   |   5 +-
 .../beam/fn/harness/state/FnApiStateAccessor.java  | 239 ++++++++-
 .../beam/fn/harness/state/MultimapUserState.java   | 292 +++++++++++
 .../fn/harness/state/MultimapUserStateTest.java    | 557 +++++++++++++++++++++
 9 files changed, 1125 insertions(+), 26 deletions(-)

diff --git 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
index eb16ea2..229fdf6 100644
--- 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
+++ 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java
@@ -161,7 +161,7 @@ public class BatchStatefulParDoOverrides {
       verifyFnIsStateful(fn);
       DataflowPipelineOptions options =
           input.getPipeline().getOptions().as(DataflowPipelineOptions.class);
-      DataflowRunner.verifyDoFnSupported(fn, false, 
DataflowRunner.useStreamingEngine(options));
+      DataflowRunner.verifyDoFnSupported(fn, false, options);
       
DataflowRunner.verifyStateSupportForWindowingStrategy(input.getWindowingStrategy());
 
       PTransform<
@@ -189,7 +189,7 @@ public class BatchStatefulParDoOverrides {
       verifyFnIsStateful(fn);
       DataflowPipelineOptions options =
           input.getPipeline().getOptions().as(DataflowPipelineOptions.class);
-      DataflowRunner.verifyDoFnSupported(fn, false, 
DataflowRunner.useStreamingEngine(options));
+      DataflowRunner.verifyDoFnSupported(fn, false, options);
       
DataflowRunner.verifyStateSupportForWindowingStrategy(input.getWindowingStrategy());
 
       PTransform<
diff --git 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
index 8229a32..2bf9975 100644
--- 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
+++ 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java
@@ -1247,8 +1247,7 @@ public class DataflowPipelineTranslator {
     boolean isStateful = DoFnSignatures.isStateful(fn);
     if (isStateful) {
       DataflowPipelineOptions options = context.getPipelineOptions();
-      DataflowRunner.verifyDoFnSupported(
-          fn, options.isStreaming(), 
DataflowRunner.useStreamingEngine(options));
+      DataflowRunner.verifyDoFnSupported(fn, options.isStreaming(), options);
       DataflowRunner.verifyStateSupportForWindowingStrategy(windowingStrategy);
     }
 
diff --git 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index c1220b0..9cb8735 100644
--- 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -2293,27 +2293,45 @@ public class DataflowRunner extends 
PipelineRunner<DataflowPipelineJob> {
         || hasExperiment(options, GcpOptions.WINDMILL_SERVICE_EXPERIMENT);
   }
 
-  static void verifyDoFnSupported(DoFn<?, ?> fn, boolean streaming, boolean 
streamingEngine) {
+  /**
+   * Rejects a {@link DoFn} that uses features the Dataflow runner cannot execute with the given
+   * configuration.
+   *
+   * @param fn the {@link DoFn} being translated
+   * @param streaming whether the pipeline runs in streaming mode
+   * @param options pipeline options used to detect Streaming Engine and the unified worker
+   * @throws UnsupportedOperationException if {@code fn} requires time-sorted input in streaming
+   *     mode, or uses {@link SetState} or {@link MapState} while streaming on Streaming Engine or
+   *     the unified worker
+   */
+  static void verifyDoFnSupported(
+      DoFn<?, ?> fn, boolean streaming, DataflowPipelineOptions options) {
     if (streaming && DoFnSignatures.requiresTimeSortedInput(fn)) {
       throw new UnsupportedOperationException(
           String.format(
               "%s does not currently support @RequiresTimeSortedInput in 
streaming mode.",
               DataflowRunner.class.getSimpleName()));
     }
-    if (DoFnSignatures.usesSetState(fn)) {
-      if (streaming && streamingEngine) {
-        throw new UnsupportedOperationException(
-            String.format(
-                "%s does not currently support %s when using streaming engine",
-                DataflowRunner.class.getSimpleName(), SetState.class.getSimpleName()));
-      }
-    }
-    if (DoFnSignatures.usesMapState(fn)) {
-      if (streaming && streamingEngine) {
-        throw new UnsupportedOperationException(
-            String.format(
-                "%s does not currently support %s when using streaming engine",
-                DataflowRunner.class.getSimpleName(), MapState.class.getSimpleName()));
-      }
-    }
+
+    boolean streamingEngine = useStreamingEngine(options);
+    boolean isUnifiedWorker = useUnifiedWorker(options);
+    if (streaming && (isUnifiedWorker || streamingEngine)) {
+      String backend = isUnifiedWorker ? "streaming on unified worker" : "streaming engine";
+      // SetState is checked before MapState, so a DoFn using both reports SetState.
+      if (DoFnSignatures.usesSetState(fn)) {
+        throw unsupportedStateException(SetState.class, backend);
+      }
+      if (DoFnSignatures.usesMapState(fn)) {
+        throw unsupportedStateException(MapState.class, backend);
+      }
+    }
+  }
+
+  /** Creates the exception reported when {@code stateClass} is unsupported on {@code backend}. */
+  private static UnsupportedOperationException unsupportedStateException(
+      Class<?> stateClass, String backend) {
+    return new UnsupportedOperationException(
+        String.format(
+            "%s does not currently support %s when using %s",
+            DataflowRunner.class.getSimpleName(), stateClass.getSimpleName(), backend));
   }
diff --git 
a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
 
b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
index a212fd1..e7df0ba 100644
--- 
a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
+++ 
b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
@@ -1539,6 +1539,15 @@ public class DataflowRunnerTest implements Serializable {
     verifyMapStateUnsupported(options);
   }
 
+  @Test
+  public void testMapStateUnsupportedStreamingUnifiedWorker() throws Exception {
+    PipelineOptions options = buildPipelineOptions();
+    ExperimentalOptions.addExperiment(options.as(ExperimentalOptions.class), "use_unified_worker");
+    options.as(DataflowPipelineOptions.class).setStreaming(true);
+
+    verifyMapStateUnsupported(options);
+  }
+
   private void verifySetStateUnsupported(PipelineOptions options) throws 
Exception {
     Pipeline p = Pipeline.create(options);
     p.apply(Create.of(KV.of(13, 42)))
@@ -1566,6 +1575,14 @@ public class DataflowRunnerTest implements Serializable {
     verifySetStateUnsupported(options);
   }
 
+  @Test
+  public void testSetStateUnsupportedStreamingUnifiedWorker() throws Exception {
+    PipelineOptions options = buildPipelineOptions();
+    options.as(DataflowPipelineOptions.class).setStreaming(true);
+    ExperimentalOptions.addExperiment(options.as(ExperimentalOptions.class), "use_unified_worker");
+    verifySetStateUnsupported(options);
+  }
+
   /** Records all the composite transforms visited within the Pipeline. */
   private static class CompositeTransformRecorder extends 
PipelineVisitor.Defaults {
 
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/MapState.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/MapState.java
index 6c05ba8..bbbe6cd 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/MapState.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/MapState.java
@@ -56,12 +56,6 @@ public interface MapState<K, V> extends State {
    * <p>Changes will not be reflected in the results returned by previous 
calls to {@link
    * ReadableState#read} on the results any of the reading methods ({@link 
#get}, {@link #keys},
    * {@link #values}, and {@link #entries}).
-   *
-   * <p>Since the condition is not evaluated until {@link ReadableState#read} 
is called, a call to
-   * {@link #putIfAbsent} followed by a call to {@link #remove} followed by a 
read on the
-   * putIfAbsent return will result in the item being written to the map. 
Similarly, if there are
-   * multiple calls to {@link #putIfAbsent} for the same key, precedence will 
be given to the first
-   * one on which read is called.
    */
   default ReadableState<V> putIfAbsent(K key, V value) {
     return computeIfAbsent(key, k -> value);
@@ -79,12 +73,6 @@ public interface MapState<K, V> extends State {
    * <p>Changes will not be reflected in the results returned by previous 
calls to {@link
    * ReadableState#read} on the results any of the reading methods ({@link 
#get}, {@link #keys},
    * {@link #values}, and {@link #entries}).
-   *
-   * <p>Since the condition is not evaluated until {@link ReadableState#read} 
is called, a call to
-   * {@link #putIfAbsent} followed by a call to {@link #remove} followed by a 
read on the
-   * putIfAbsent return will result in the item being written to the map. 
Similarly, if there are
-   * multiple calls to {@link #putIfAbsent} for the same key, precedence will 
be given to the first
-   * one on which read is called.
    */
   ReadableState<V> computeIfAbsent(K key, Function<? super K, ? extends V> 
mappingFunction);
 
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/SetState.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/SetState.java
index 2ca7226..b4b7bf7 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/SetState.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/SetState.java
@@ -30,7 +30,10 @@ import org.apache.beam.sdk.annotations.Experimental.Kind;
  */
 @Experimental(Kind.STATE)
 public interface SetState<T> extends GroupingState<T, Iterable<T>> {
-  /** Returns true if this set contains the specified element. */
+  /**
+   * Returns a {@link ReadableState} whose {@link ReadableState#read} returns true if this set
+   * contains the specified element at the point when that read call returns.
+   */
   ReadableState<Boolean> contains(T t);
 
   /**
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
index 5a931c5..517be05 100644
--- 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java
@@ -31,6 +31,7 @@ import 
org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
 import org.apache.beam.runners.core.SideInputReader;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.VoidCoder;
 import org.apache.beam.sdk.function.ThrowingRunnable;
 import org.apache.beam.sdk.options.PipelineOptions;
 import org.apache.beam.sdk.state.BagState;
@@ -38,6 +39,7 @@ import org.apache.beam.sdk.state.CombiningState;
 import org.apache.beam.sdk.state.MapState;
 import org.apache.beam.sdk.state.OrderedListState;
 import org.apache.beam.sdk.state.ReadableState;
+import org.apache.beam.sdk.state.ReadableStates;
 import org.apache.beam.sdk.state.SetState;
 import org.apache.beam.sdk.state.StateBinder;
 import org.apache.beam.sdk.state.StateContext;
@@ -54,6 +56,7 @@ import org.apache.beam.sdk.values.PCollectionView;
 import org.apache.beam.sdk.values.TupleTag;
 import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
 import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
 import org.checkerframework.checker.nullness.qual.Nullable;
 
@@ -324,7 +327,87 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
 
   @Override
   public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, 
Coder<T> elemCoder) {
-    throw new UnsupportedOperationException("TODO: Add support for a map state 
to the Fn API.");
+    return (SetState<T>)
+        stateKeyObjectCache.computeIfAbsent(
+            createMultimapUserStateKey(id),
+            new Function<StateKey, Object>() {
+              @Override
+              public Object apply(StateKey stateKey) {
+                return new SetState<T>() {
+                  private final MultimapUserState<T, Void> multimap =
+                      createMultimapUserState(id, elemCoder, VoidCoder.of());
+
+                  @Override
+                  public Iterable<T> read() {
+                    return multimap.keys();
+                  }
+
+                  @Override
+                  public SetState<T> readLater() {
+                    // TODO: Support prefetching; this is currently a no-op.
+                    return this;
+                  }
+
+                  @Override
+                  public ReadableState<Boolean> contains(T t) {
+                    return new ReadableState<Boolean>() {
+                      @Override
+                      public Boolean read() {
+                        return !Iterables.isEmpty(multimap.get(t));
+                      }
+
+                      @Override
+                      public ReadableState<Boolean> readLater() {
+                        // TODO: Support prefetching; membership is looked up in read().
+                        return this;
+                      }
+                    };
+                  }
+
+                  @Override
+                  public ReadableState<Boolean> isEmpty() {
+                    return new ReadableState<Boolean>() {
+                      @Override
+                      public Boolean read() {
+                        return Iterables.isEmpty(multimap.keys());
+                      }
+
+                      @Override
+                      public ReadableState<Boolean> readLater() {
+                        // TODO: Support prefetching; the key listing happens in read().
+                        return this;
+                      }
+                    };
+                  }
+
+                  @Override
+                  public void add(T value) {
+                    multimap.remove(value);
+                    multimap.put(value, null);
+                  }
+
+                  @Override
+                  public ReadableState<Boolean> addIfAbsent(T t) {
+                    boolean absent = Iterables.isEmpty(multimap.get(t));
+                    if (absent) {
+                      multimap.put(t, null);
+                    }
+                    // TODO: Support prefetching; the membership check above runs eagerly.
+                    return ReadableStates.immediate(absent);
+                  }
+
+                  @Override
+                  public void remove(T t) {
+                    multimap.remove(t);
+                  }
+
+                  @Override
+                  public void clear() {
+                    multimap.clear();
+                  }
+                };
+              }
+            });
   }
 
   @Override
@@ -333,7 +416,133 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
       StateSpec<MapState<KeyT, ValueT>> spec,
       Coder<KeyT> mapKeyCoder,
       Coder<ValueT> mapValueCoder) {
-    throw new UnsupportedOperationException("TODO: Add support for a map state 
to the Fn API.");
+    return (MapState<KeyT, ValueT>)
+        stateKeyObjectCache.computeIfAbsent(
+            createMultimapUserStateKey(id),
+            new Function<StateKey, Object>() {
+              @Override
+              public Object apply(StateKey key) {
+                return new MapState<KeyT, ValueT>() {
+                  private final MultimapUserState<KeyT, ValueT> impl =
+                      createMultimapUserState(id, mapKeyCoder, mapValueCoder);
+
+                  @Override
+                  public void clear() {
+                    impl.clear();
+                  }
+
+                  @Override
+                  public void put(KeyT key, ValueT value) {
+                    impl.remove(key);
+                    impl.put(key, value);
+                  }
+
+                  @Override
+                  public ReadableState<ValueT> computeIfAbsent(
+                      KeyT key, Function<? super KeyT, ? extends ValueT> 
mappingFunction) {
+                    Iterable<ValueT> values = impl.get(key);
+                    if (Iterables.isEmpty(values)) {
+                      impl.put(key, mappingFunction.apply(key));
+                    }
+                    return 
ReadableStates.immediate(Iterables.getOnlyElement(values, null));
+                  }
+
+                  @Override
+                  public void remove(KeyT key) {
+                    impl.remove(key);
+                  }
+
+                  @Override
+                  public ReadableState<ValueT> get(KeyT key) {
+                    return getOrDefault(key, null);
+                  }
+
+                  @Override
+                  public ReadableState<ValueT> getOrDefault(
+                      KeyT key, @Nullable ValueT defaultValue) {
+                    return new ReadableState<ValueT>() {
+                      @Override
+                      public @Nullable ValueT read() {
+                        Iterable<ValueT> values = impl.get(key);
+                        return Iterables.getOnlyElement(values, defaultValue);
+                      }
+
+                      @Override
+                      public ReadableState<ValueT> readLater() {
+                        // TODO: Support prefetching.
+                        return this;
+                      }
+                    };
+                  }
+
+                  @Override
+                  public ReadableState<Iterable<KeyT>> keys() {
+                    return new ReadableState<Iterable<KeyT>>() {
+                      @Override
+                      public Iterable<KeyT> read() {
+                        return impl.keys();
+                      }
+
+                      @Override
+                      public ReadableState<Iterable<KeyT>> readLater() {
+                        // TODO: Support prefetching.
+                        return this;
+                      }
+                    };
+                  }
+
+                  @Override
+                  public ReadableState<Iterable<ValueT>> values() {
+                    return new ReadableState<Iterable<ValueT>>() {
+                      @Override
+                      public Iterable<ValueT> read() {
+                        return Iterables.transform(entries().read(), e -> 
e.getValue());
+                      }
+
+                      @Override
+                      public ReadableState<Iterable<ValueT>> readLater() {
+                        // TODO: Support prefetching.
+                        return this;
+                      }
+                    };
+                  }
+
+                  @Override
+                  public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> 
entries() {
+                    return new ReadableState<Iterable<Map.Entry<KeyT, 
ValueT>>>() {
+                      @Override
+                      public Iterable<Map.Entry<KeyT, ValueT>> read() {
+                        Iterable<KeyT> keys = keys().read();
+                        return Iterables.transform(
+                            keys, key -> Maps.immutableEntry(key, 
get(key).read()));
+                      }
+
+                      @Override
+                      public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> 
readLater() {
+                        // TODO: Support prefetching.
+                        return this;
+                      }
+                    };
+                  }
+
+                  @Override
+                  public ReadableState<Boolean> isEmpty() {
+                    return new ReadableState<Boolean>() {
+                      @Override
+                      public Boolean read() {
+                        return Iterables.isEmpty(keys().read());
+                      }
+
+                      @Override
+                      public ReadableState<Boolean> readLater() {
+                        // TODO: Support prefetching.
+                        return this;
+                      }
+                    };
+                  }
+                };
+              }
+            });
   }
 
   @Override
@@ -481,6 +690,26 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
    throw new UnsupportedOperationException("WatermarkHoldState is unsupported 
by the Fn API.");
   }
 
+  /**
+   * Creates a {@link MultimapUserState} for {@code stateId} bound to the current instruction,
+   * window and key, and registers its {@code asyncClose} with {@code stateFinalizers}.
+   */
+  private <KeyT, ValueT> MultimapUserState<KeyT, ValueT> createMultimapUserState(
+      String stateId, Coder<KeyT> keyCoder, Coder<ValueT> valueCoder) {
+    MultimapUserState<KeyT, ValueT> rval =
+        new MultimapUserState<>(
+            beamFnStateClient,
+            processBundleInstructionId.get(),
+            ptransformId,
+            stateId,
+            encodedCurrentWindowSupplier.get(),
+            encodedCurrentKeySupplier.get(),
+            keyCoder,
+            valueCoder);
+    stateFinalizers.add(rval::asyncClose);
+    return rval;
+  }
+
   private <T> BagUserState<T> createBagUserState(String stateId, Coder<T> 
valueCoder) {
     BagUserState<T> rval =
         new BagUserState<>(
@@ -506,6 +735,22 @@ public class FnApiStateAccessor<K> implements 
SideInputReader, StateBinder {
     return builder.build();
   }
 
+  /**
+   * Builds the cache key for the multimap-backed state {@code stateId}. The current user key is
+   * included because each cached state object is bound to one key; omitting it would hand one
+   * key's state to another key in the same window.
+   */
+  private StateKey createMultimapUserStateKey(String stateId) {
+    StateKey.Builder builder = StateKey.newBuilder();
+    builder
+        .getMultimapKeysUserStateBuilder()
+        .setWindow(encodedCurrentWindowSupplier.get())
+        .setKey(encodedCurrentKeySupplier.get())
+        .setTransformId(ptransformId)
+        .setUserStateId(stateId);
+    return builder.build();
+  }
+
   public void finalizeState() {
     // Persist all dirty state cells
     try {
diff --git 
a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
new file mode 100644
index 0000000..49efa35
--- /dev/null
+++ 
b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java
@@ -0,0 +1,303 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.state;
+
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateAppendRequest;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest;
+import org.apache.beam.sdk.coders.Coder;
+import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;
+import org.checkerframework.checker.nullness.qual.Nullable;
+
+/**
+ * An implementation of a multimap user state that utilizes the Beam Fn State API to fetch, clear
+ * and persist values.
+ *
+ * <p>Calling {@link #asyncClose()} schedules any required persistence changes. This object should
+ * no longer be used after it is closed.
+ *
+ * <p>Map keys are tracked in {@link HashMap}, {@link HashSet} and {@link Multimap} instances, so
+ * they are compared with {@code K}'s {@code equals}/{@code hashCode} rather than by their encoded
+ * form. NOTE(review): key types without value-based equality (e.g. {@code byte[]}) will not
+ * behave as map keys; consider keying by {@link Coder#structuralValue} instead.
+ *
+ * <p>TODO: Move to an async persist model where persistence is signalled based upon cache memory
+ * pressure and its need to flush.
+ *
+ * <p>TODO: Support block level caching and prefetch.
+ */
+public class MultimapUserState<K, V> {
+
+  private final BeamFnStateClient beamFnStateClient;
+  private final Coder<K> mapKeyCoder;
+  private final Coder<V> valueCoder;
+  private final String stateId;
+  private final StateRequest keysStateRequest;
+  private final StateRequest userStateRequest;
+
+  private boolean isClosed;
+  private boolean isCleared;
+  // Pending updates to persistent storage
+  private HashSet<K> pendingRemoves = Sets.newHashSet();
+  private HashMap<K, List<V>> pendingAdds = Maps.newHashMap();
+  // Map keys with no values in persistent storage
+  private HashSet<K> negativeCache = Sets.newHashSet();
+  // Values retrieved from persistent storage
+  private Multimap<K, V> persistedValues = ArrayListMultimap.create();
+  private @Nullable Iterable<K> persistedKeys = null;
+
+  /**
+   * Creates a multimap user state for {@code stateId} scoped to the given instruction, transform,
+   * window and key, encoding map keys with {@code mapKeyCoder} and values with {@code valueCoder}.
+   */
+  public MultimapUserState(
+      BeamFnStateClient beamFnStateClient,
+      String instructionId,
+      String pTransformId,
+      String stateId,
+      ByteString encodedWindow,
+      ByteString encodedKey,
+      Coder<K> mapKeyCoder,
+      Coder<V> valueCoder) {
+    this.beamFnStateClient = beamFnStateClient;
+    this.mapKeyCoder = mapKeyCoder;
+    this.valueCoder = valueCoder;
+    this.stateId = stateId;
+
+    StateRequest.Builder keysStateRequestBuilder = StateRequest.newBuilder();
+    keysStateRequestBuilder
+        .setInstructionId(instructionId)
+        .getStateKeyBuilder()
+        .getMultimapKeysUserStateBuilder()
+        .setTransformId(pTransformId)
+        .setUserStateId(stateId)
+        .setKey(encodedKey)
+        .setWindow(encodedWindow);
+    keysStateRequest = keysStateRequestBuilder.build();
+
+    StateRequest.Builder userStateRequestBuilder = StateRequest.newBuilder();
+    userStateRequestBuilder
+        .setInstructionId(instructionId)
+        .getStateKeyBuilder()
+        .getMultimapUserStateBuilder()
+        .setTransformId(pTransformId)
+        .setUserStateId(stateId)
+        .setWindow(encodedWindow)
+        .setKey(encodedKey);
+    userStateRequest = userStateRequestBuilder.build();
+  }
+
+  /** Marks the whole multimap as cleared; the clear request is sent by {@link #asyncClose()}. */
+  public void clear() {
+    checkState(
+        !isClosed,
+        "Multimap user state is no longer usable because it is closed for %s",
+        keysStateRequest.getStateKey());
+
+    isCleared = true;
+    persistedValues = ArrayListMultimap.create();
+    persistedKeys = null;
+    pendingRemoves = Sets.newHashSet();
+    pendingAdds = Maps.newHashMap();
+    negativeCache = Sets.newHashSet();
+  }
+
+  /**
+   * Returns an iterable of the values associated with {@code key} in this multimap, if any. If
+   * there are no values, this returns an empty iterable, not null.
+   */
+  public Iterable<V> get(K key) {
+    checkState(
+        !isClosed,
+        "Multimap user state is no longer usable because it is closed for %s",
+        keysStateRequest.getStateKey());
+
+    List<V> pendingAddValues = pendingAdds.getOrDefault(key, Collections.emptyList());
+    // Copy the pending values: a subList view of the live list would throw
+    // ConcurrentModificationException once a later put() appends to that list.
+    Collection<V> pendingValues = Collections.unmodifiableList(new ArrayList<>(pendingAddValues));
+    if (isCleared || pendingRemoves.contains(key)) {
+      return pendingValues;
+    }
+
+    Iterable<V> persisted = getPersistedValues(key);
+    return Iterables.concat(persisted, pendingValues);
+  }
+
+  /** Returns an iterable containing all distinct keys in this multimap. */
+  @SuppressWarnings({
+    "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-12687)
+  })
+  public Iterable<K> keys() {
+    checkState(
+        !isClosed,
+        "Multimap user state is no longer usable because it is closed for %s",
+        keysStateRequest.getStateKey());
+    if (isCleared) {
+      return Collections.unmodifiableCollection(Lists.newArrayList(pendingAdds.keySet()));
+    }
+
+    Set<K> keys = Sets.newHashSet(getPersistedKeys());
+    keys.removeAll(pendingRemoves);
+    keys.addAll(pendingAdds.keySet());
+    return Collections.unmodifiableCollection(keys);
+  }
+
+  /** Stores a key-value pair in the multimap. Duplicate key-value pairs are allowed. */
+  public void put(K key, V value) {
+    checkState(
+        !isClosed,
+        "Multimap user state is no longer usable because it is closed for %s",
+        keysStateRequest.getStateKey());
+    pendingAdds.computeIfAbsent(key, unused -> new ArrayList<>()).add(value);
+  }
+
+  /** Removes all values for {@code key} in the multimap. */
+  public void remove(K key) {
+    checkState(
+        !isClosed,
+        "Multimap user state is no longer usable because it is closed for %s",
+        keysStateRequest.getStateKey());
+    pendingAdds.remove(key);
+    if (!isCleared) {
+      pendingRemoves.add(key);
+    }
+  }
+
+  /**
+   * Closes this state and sends any pending clear and append requests through the state client.
+   * Each request is given a new {@link CompletableFuture} that is not retained, so responses
+   * (including failures) are not observed here.
+   */
+  @SuppressWarnings({
+    "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-12687)
+  })
+  public void asyncClose() throws Exception {
+    checkState(
+        !isClosed,
+        "Multimap user state is no longer usable because it is closed for %s",
+        keysStateRequest.getStateKey());
+    isClosed = true;
+    // Nothing to persist
+    if (!isCleared && pendingRemoves.isEmpty() && pendingAdds.isEmpty()) {
+      return;
+    }
+
+    // Clear currently persisted key-values
+    if (isCleared) {
+      beamFnStateClient.handle(
+          keysStateRequest.toBuilder().setClear(StateClearRequest.getDefaultInstance()),
+          new CompletableFuture<>());
+    } else if (!pendingRemoves.isEmpty()) {
+      for (K key : pendingRemoves) {
+        beamFnStateClient.handle(
+            createUserStateRequest(key)
+                .toBuilder()
+                .setClear(StateClearRequest.getDefaultInstance()),
+            new CompletableFuture<>());
+      }
+    }
+
+    // Persist pending key-values
+    if (!pendingAdds.isEmpty()) {
+      for (Map.Entry<K, List<V>> entry : pendingAdds.entrySet()) {
+        beamFnStateClient.handle(
+            createUserStateRequest(entry.getKey())
+                .toBuilder()
+                .setAppend(StateAppendRequest.newBuilder().setData(encodeValues(entry.getValue()))),
+            new CompletableFuture<>());
+      }
+    }
+  }
+
+  /** Encodes {@code values} back to back with {@code valueCoder} into one {@link ByteString}. */
+  private ByteString encodeValues(Iterable<V> values) {
+    try {
+      ByteString.Output output = ByteString.newOutput();
+      for (V value : values) {
+        valueCoder.encode(value, output);
+      }
+      return output.toByteString();
+    } catch (IOException e) {
+      throw new IllegalStateException(
+          String.format("Failed to encode values for multimap user state id %s.", stateId), e);
+    }
+  }
+
+  /** Returns {@code userStateRequest} narrowed to the encoded map key {@code key}. */
+  private StateRequest createUserStateRequest(K key) {
+    try {
+      ByteString.Output output = ByteString.newOutput();
+      mapKeyCoder.encode(key, output);
+      StateRequest.Builder request = userStateRequest.toBuilder();
+      request.getStateKeyBuilder().getMultimapUserStateBuilder().setMapKey(output.toByteString());
+      return request.build();
+    } catch (IOException e) {
+      throw new IllegalStateException(
+          String.format("Failed to encode key for multimap user state id %s.", stateId), e);
+    }
+  }
+
+  /** Returns the persisted values for {@code key}, fetching and caching them on first use. */
+  private Iterable<V> getPersistedValues(K key) {
+    if (negativeCache.contains(key)) {
+      return Collections.emptyList();
+    }
+
+    if (persistedValues.get(key).isEmpty()) {
+      Iterable<V> values =
+          StateFetchingIterators.readAllAndDecodeStartingFrom(
+              beamFnStateClient, createUserStateRequest(key), valueCoder);
+      if (Iterables.isEmpty(values)) {
+        negativeCache.add(key);
+      }
+      persistedValues.putAll(key, values);
+    }
+    return Iterables.unmodifiableIterable(persistedValues.get(key));
+  }
+
+  /** Returns the persisted map keys, fetched once and cached; must not be called once cleared. */
+  private Iterable<K> getPersistedKeys() {
+    checkState(!isCleared);
+    if (persistedKeys == null) {
+      Iterable<K> keys =
+          StateFetchingIterators.readAllAndDecodeStartingFrom(
+              beamFnStateClient, keysStateRequest, mapKeyCoder);
+      persistedKeys = Iterables.unmodifiableIterable(keys);
+    }
+    return persistedKeys;
+  }
+}
diff --git 
a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
new file mode 100644
index 0000000..23f1be4
--- /dev/null
+++ 
b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java
@@ -0,0 +1,557 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.fn.harness.state;
+
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.emptyIterable;
+import static org.hamcrest.core.Is.is;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Map;
+import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey;
+import org.apache.beam.sdk.coders.NullableCoder;
+import org.apache.beam.sdk.coders.StringUtf8Coder;
+import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
+import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Tests for {@link MultimapUserState} backed by a {@link FakeBeamFnStateClient}. */
+@RunWith(JUnit4.class)
+public class MultimapUserStateTest {
+
+  private final String instructionId = "instructionId";
+  private final String pTransformId = "pTransformId";
+  private final String stateId = "stateId";
+  private final String encodedKey = "encodedKey";
+  private final String encodedWindow = "encodedWindow";
+
+  @Test
+  public void testNoPersistedValues() throws Exception {
+    MultimapUserState<String, String> userState =
+        createUserState(new FakeBeamFnStateClient(Collections.emptyMap()));
+    assertThat(userState.keys(), is(emptyIterable()));
+  }
+
+  @Test
+  public void testGet() throws Exception {
+    MultimapUserState<String, String> userState =
+        createUserState(createClientWithPersistedValues("A1", "V1", "V2"));
+
+    Iterable<String> initValues = userState.get("A1");
+    userState.put("A1", "V3");
+    // An iterable obtained before the put must not reflect the put.
+    assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues, String.class));
+    assertArrayEquals(
+        new String[] {"V1", "V2", "V3"}, Iterables.toArray(userState.get("A1"), String.class));
+    assertArrayEquals(new String[] {}, Iterables.toArray(userState.get("A2"), String.class));
+    userState.asyncClose();
+    assertThrows(IllegalStateException.class, () -> userState.get("A1"));
+  }
+
+  @Test
+  public void testClear() throws Exception {
+    MultimapUserState<String, String> userState =
+        createUserState(createClientWithPersistedValues("A1", "V1", "V2"));
+
+    Iterable<String> initValues = userState.get("A1");
+    userState.clear();
+    assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues, String.class));
+    assertThat(userState.get("A1"), is(emptyIterable()));
+    assertThat(userState.keys(), is(emptyIterable()));
+
+    userState.put("A1", "V1");
+    userState.clear();
+    assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues, String.class));
+    assertThat(userState.get("A1"), is(emptyIterable()));
+    assertThat(userState.keys(), is(emptyIterable()));
+
+    userState.asyncClose();
+    assertThrows(IllegalStateException.class, () -> userState.clear());
+  }
+
+  @Test
+  public void testKeys() throws Exception {
+    MultimapUserState<String, String> userState =
+        createUserState(createClientWithPersistedValues("A1", "V1", "V2"));
+
+    userState.put("A2", "V1");
+    Iterable<String> initKeys = userState.keys();
+    userState.put("A3", "V1");
+    userState.put("A1", "V3");
+    // Keys captured before later puts must not reflect those puts.
+    assertArrayEquals(new String[] {"A1", "A2"}, Iterables.toArray(initKeys, String.class));
+    assertArrayEquals(
+        new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.keys(), String.class));
+
+    userState.clear();
+    assertArrayEquals(new String[] {"A1", "A2"}, Iterables.toArray(initKeys, String.class));
+    assertArrayEquals(new String[] {}, Iterables.toArray(userState.keys(), String.class));
+    userState.asyncClose();
+    assertThrows(IllegalStateException.class, () -> userState.keys());
+  }
+
+  @Test
+  public void testPut() throws Exception {
+    MultimapUserState<String, String> userState =
+        createUserState(createClientWithPersistedValues("A1", "V1", "V2"));
+
+    Iterable<String> initValues = userState.get("A1");
+    userState.put("A1", "V3");
+    assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues, String.class));
+    assertArrayEquals(
+        new String[] {"V1", "V2", "V3"}, Iterables.toArray(userState.get("A1"), String.class));
+    userState.asyncClose();
+    assertThrows(IllegalStateException.class, () -> userState.put("A1", "V2"));
+  }
+
+  @Test
+  public void testPutAfterRemove() throws Exception {
+    FakeBeamFnStateClient fakeClient = createClientWithPersistedValues("A0", "V1");
+    MultimapUserState<String, String> userState = createUserState(fakeClient);
+    userState.remove("A0");
+    userState.put("A0", "V2");
+    assertArrayEquals(new String[] {"V2"}, Iterables.toArray(userState.get("A0"), String.class));
+    userState.asyncClose();
+    Map<StateKey, ByteString> data = fakeClient.getData();
+    assertEquals(encode("V2"), data.get(createMultimapValueStateKey("A0")));
+  }
+
+  @Test
+  public void testPutAfterClear() throws Exception {
+    MultimapUserState<String, String> userState =
+        createUserState(createClientWithPersistedValues("A0", "V1"));
+    userState.clear();
+    userState.put("A0", "V2");
+    assertArrayEquals(new String[] {"V2"}, Iterables.toArray(userState.get("A0"), String.class));
+  }
+
+  @Test
+  public void testRemoveBeforeClear() throws Exception {
+    FakeBeamFnStateClient fakeClient = createClientWithPersistedValues("A0", "V1");
+    MultimapUserState<String, String> userState = createUserState(fakeClient);
+    userState.remove("A0");
+    userState.clear();
+    userState.asyncClose();
+    // Clear takes precedence over specific key remove
+    assertThat(fakeClient.getCallCount(), is(1));
+  }
+
+  @Test
+  public void testPutBeforeClear() throws Exception {
+    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(Collections.emptyMap());
+    MultimapUserState<String, String> userState = createUserState(fakeClient);
+    userState.put("A0", "V0");
+    userState.put("A1", "V1");
+    Iterable<String> values = userState.get("A1"); // fakeClient call = 1
+    userState.clear(); // fakeClient call = 2
+    assertArrayEquals(new String[] {"V1"}, Iterables.toArray(values, String.class));
+    userState.asyncClose();
+    // Clear takes precedence over puts
+    assertThat(fakeClient.getCallCount(), is(2));
+  }
+
+  @Test
+  public void testPutBeforeRemove() throws Exception {
+    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(Collections.emptyMap());
+    MultimapUserState<String, String> userState = createUserState(fakeClient);
+    userState.put("A0", "V0");
+    userState.put("A1", "V1");
+    Iterable<String> values = userState.get("A1"); // fakeClient call = 1
+    userState.remove("A0"); // fakeClient call = 2
+    userState.remove("A1"); // fakeClient call = 3
+    assertArrayEquals(new String[] {"V1"}, Iterables.toArray(values, String.class));
+    userState.asyncClose();
+    assertThat(fakeClient.getCallCount(), is(3));
+    assertNull(fakeClient.getData().get(createMultimapValueStateKey("A0")));
+    assertNull(fakeClient.getData().get(createMultimapValueStateKey("A1")));
+  }
+
+  @Test
+  public void testRemove() throws Exception {
+    MultimapUserState<String, String> userState =
+        createUserState(createClientWithPersistedValues("A1", "V1", "V2"));
+
+    Iterable<String> initValues = userState.get("A1");
+    userState.put("A1", "V3");
+
+    userState.remove("A1");
+    assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues, String.class));
+    assertThat(userState.keys(), is(emptyIterable()));
+    userState.asyncClose();
+    assertThrows(IllegalStateException.class, () -> userState.remove("A1"));
+  }
+
+  @Test
+  public void testImmutableKeys() throws Exception {
+    MultimapUserState<String, String> userState =
+        createUserState(createClientWithPersistedValues("A1", "V1", "V2"));
+    Iterable<String> keys = userState.keys();
+    assertThrows(
+        UnsupportedOperationException.class, () -> Iterables.removeAll(keys, Arrays.asList("A1")));
+  }
+
+  @Test
+  public void testImmutableValues() throws Exception {
+    MultimapUserState<String, String> userState =
+        createUserState(createClientWithPersistedValues("A1", "V1", "V2"));
+    Iterable<String> values = userState.get("A1");
+    assertThrows(
+        UnsupportedOperationException.class,
+        () -> Iterables.removeAll(values, Arrays.asList("V1")));
+  }
+
+  @Test
+  public void testClearAsyncClose() throws Exception {
+    FakeBeamFnStateClient fakeClient = createClientWithPersistedValues("A1", "V1", "V2");
+    MultimapUserState<String, String> userState = createUserState(fakeClient);
+    userState.clear();
+    userState.asyncClose();
+    Map<StateKey, ByteString> data = fakeClient.getData();
+    assertEquals(1, data.size());
+    assertNull(data.get(createMultimapKeyStateKey()));
+  }
+
+  @Test
+  public void testNoopAsyncClose() throws Exception {
+    FakeBeamFnStateClient fakeClient = createClientWithPersistedValues("A1", "V1", "V2");
+    MultimapUserState<String, String> userState = createUserState(fakeClient);
+    userState.asyncClose();
+    assertThrows(IllegalStateException.class, () -> userState.keys());
+    // Closing untouched state must not issue any state requests.
+    assertEquals(0, fakeClient.getCallCount());
+  }
+
+  @Test
+  public void testAsyncClose() throws Exception {
+    FakeBeamFnStateClient fakeClient =
+        new FakeBeamFnStateClient(
+            ImmutableMap.of(
+                createMultimapKeyStateKey(),
+                encode("A0", "A1"),
+                createMultimapValueStateKey("A0"),
+                encode("V1"),
+                createMultimapValueStateKey("A1"),
+                encode("V1", "V2")));
+    MultimapUserState<String, String> userState = createUserState(fakeClient);
+    userState.remove("A0");
+    userState.put("A1", "V3");
+    userState.put("A2", "V1");
+    userState.put("A3", "V1");
+    userState.remove("A3");
+    userState.asyncClose();
+    Map<StateKey, ByteString> data = fakeClient.getData();
+    assertNull(data.get(createMultimapValueStateKey("A0")));
+    assertEquals(encode("V1", "V2", "V3"), data.get(createMultimapValueStateKey("A1")));
+    assertEquals(encode("V1"), data.get(createMultimapValueStateKey("A2")));
+  }
+
+  @Test
+  public void testNullKeysAndValues() throws Exception {
+    // Built explicitly because this test needs nullable coders rather than the default ones.
+    MultimapUserState<String, String> userState =
+        new MultimapUserState<>(
+            createClientWithPersistedValues("A1", "V1", "V2"),
+            instructionId,
+            pTransformId,
+            stateId,
+            encode(encodedWindow),
+            encode(encodedKey),
+            NullableCoder.of(StringUtf8Coder.of()),
+            NullableCoder.of(StringUtf8Coder.of()));
+    userState.put(null, null);
+    userState.put(null, null);
+    userState.put(null, "V1");
+    assertArrayEquals(
+        new String[] {null, null, "V1"}, Iterables.toArray(userState.get(null), String.class));
+  }
+
+  @Test
+  public void testNegativeCache() throws Exception {
+    FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(Collections.emptyMap());
+    MultimapUserState<String, String> userState = createUserState(fakeClient);
+    userState.get("A1");
+    userState.get("A1");
+    // The second lookup of a key with no persisted values must not hit the client again.
+    assertThat(fakeClient.getCallCount(), is(1));
+  }
+
+  /**
+   * Creates the {@link MultimapUserState} under test for this class's fixed transform, state id,
+   * window and key, using {@link StringUtf8Coder} for both map keys and values.
+   */
+  private MultimapUserState<String, String> createUserState(FakeBeamFnStateClient client)
+      throws IOException {
+    return new MultimapUserState<>(
+        client,
+        instructionId,
+        pTransformId,
+        stateId,
+        encode(encodedWindow),
+        encode(encodedKey),
+        StringUtf8Coder.of(),
+        StringUtf8Coder.of());
+  }
+
+  /**
+   * Creates a fake state client whose data holds exactly one persisted map key, {@code mapKey},
+   * associated with {@code values} (in order).
+   */
+  private FakeBeamFnStateClient createClientWithPersistedValues(String mapKey, String... values)
+      throws IOException {
+    return new FakeBeamFnStateClient(
+        ImmutableMap.of(
+            createMultimapKeyStateKey(),
+            encode(mapKey),
+            createMultimapValueStateKey(mapKey),
+            encode(values)));
+  }
+
+  /** Returns the state key under which the set of map keys is stored. */
+  private StateKey createMultimapKeyStateKey() throws IOException {
+    return StateKey.newBuilder()
+        .setMultimapKeysUserState(
+            StateKey.MultimapKeysUserState.newBuilder()
+                .setWindow(encode(encodedWindow))
+                .setKey(encode(encodedKey))
+                .setTransformId(pTransformId)
+                .setUserStateId(stateId))
+        .build();
+  }
+
+  /** Returns the state key under which the values for map key {@code key} are stored. */
+  private StateKey createMultimapValueStateKey(String key) throws IOException {
+    return StateKey.newBuilder()
+        .setMultimapUserState(
+            StateKey.MultimapUserState.newBuilder()
+                .setTransformId(pTransformId)
+                .setUserStateId(stateId)
+                .setWindow(encode(encodedWindow))
+                .setKey(encode(encodedKey))
+                .setMapKey(encode(key)))
+        .build();
+  }
+
+  /** Concatenates the {@link StringUtf8Coder} encodings of {@code values}, in order. */
+  private ByteString encode(String... values) throws IOException {
+    ByteString.Output out = ByteString.newOutput();
+    for (String value : values) {
+      StringUtf8Coder.of().encode(value, out);
+    }
+    return out.toByteString();
+  }
+}

Reply via email to