scwhittle commented on code in PR #23492:
URL: https://github.com/apache/beam/pull/23492#discussion_r1058832514


##########
runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java:
##########
@@ -2422,12 +2422,6 @@ static boolean 
useStreamingEngine(DataflowPipelineOptions options) {
 
   static void verifyDoFnSupported(
       DoFn<?, ?> fn, boolean streaming, DataflowPipelineOptions options) {
-    if (DoFnSignatures.usesMultimapState(fn)) {

Review Comment:
   From the PR description, this sounds streaming-only. If so, update this to 
still reject batch multimap usage. If not, update the PR description.
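
   A minimal sketch of the check I have in mind, assuming the feature is 
streaming-only. The class name, the boolean parameters and the message text are 
placeholders for illustration; the real check would stay inside 
verifyDoFnSupported and call DoFnSignatures.usesMultimapState(fn).

```java
// Hypothetical, self-contained stand-in for the check in verifyDoFnSupported.
final class BatchMultimapCheckSketch {
  static void verifyMultimapSupported(boolean usesMultimapState, boolean streaming) {
    // Multimap state stays rejected in batch; only streaming usage passes.
    if (usesMultimapState && !streaming) {
      throw new UnsupportedOperationException(
          "MultimapState is not supported in batch mode (placeholder message)");
    }
  }
}
```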



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -334,6 +358,28 @@ public <T> Future<Iterable<TimestampedValue<T>>> 
orderedListFuture(
         valuesToPagingIterableFuture(stateTag, elemCoder, 
this.stateFuture(stateTag, elemCoder)));
   }
 
+  public <T> Future<Iterable<Map.Entry<ByteString, Iterable<T>>>> 
multimapFetchAllFuture(
+      boolean omitValues, ByteString encodedTag, String stateFamily, Coder<T> 
elemCoder) {
+    StateTag<ByteString> stateTag =
+        StateTag.<ByteString>of(Kind.MULTIMAP_ALL, encodedTag, stateFamily)
+            .toBuilder()
+            .setOmitValues(omitValues)
+            .build();
+    return Preconditions.checkNotNull(
+        valuesToPagingIterableFuture(stateTag, elemCoder, 
this.stateFuture(stateTag, elemCoder)));
+  }
+
+  public <T> Future<Iterable<T>> multimapFetchSingleEntryFuture(
+      ByteString encodedKey, ByteString encodedTag, String stateFamily, 
Coder<T> elemCoder) {
+    StateTag<ByteString> stateTag =
+        StateTag.<ByteString>of(Kind.MULTIMAP_SINGLE_ENTRY, encodedTag, 
stateFamily)
+            .toBuilder()
+            .setMultimapKey(encodedKey)
+            .build();
+    return Preconditions.checkNotNull(

Review Comment:
   ditto



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -142,6 +154,14 @@ enum Kind {
     @Nullable
     abstract Range<Long> getSortedListRange();
 
+    /** For {@link Kind#MULTIMAP_SINGLE_ENTRY} kinds: the key in the multimap 
to fetch or delete. */
+    @Nullable
+    abstract ByteString getMultimapKey();
+
+    /** For {@link Kind#MULTIMAP_ALL} kinds: will return keys only if true. */

Review Comment:
   nit: will only return the keys of the multimap and not the values if true.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -466,20 +518,67 @@ private <ResultT, ContinuationT> 
Future<Iterable<ResultT>> valuesToPagingIterabl
     return Futures.lazyTransform(future, toIterable);
   }
 
+  private void delayUnbatchableMultimapFetches(
+      List<StateTag<?>> multimapTags, HashSet<StateTag<?>> toFetch) {
+    // Each KeyedGetDataRequest can have at most 1 TagMultimapFetchRequest, 
thus we need to delay
+    // unbatchable multimap requests of the same stateFamily and tag into 
later batches. There's no
+    // priority between get()/entries()/keys(), they will be fetched based on 
the order they occur
+    // in pendingLookups, so that all requests can eventually be fetched and 
none starves.
+
+    Map<String, Map<ByteString, List<StateTag<?>>>> groupedTags =
+        multimapTags.stream()
+            .collect(
+                Collectors.groupingBy(
+                    StateTag::getStateFamily, 
Collectors.groupingBy(StateTag::getTag)));
+
+    for (Map<ByteString, List<StateTag<?>>> familyTags : groupedTags.values()) 
{

Review Comment:
   Add a fast path for when there is just one request for a state family and 
tag? There is no need for all the maps and checks in that case.
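
   Roughly what I mean, as a simplified sketch. It assumes requests are already 
grouped under one key per state family and tag, and it uses a placeholder 
policy (the first request goes out now, the rest are delayed) instead of the 
real batching rules. All names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

final class FastPathSketch {
  /** Returns the requests that must wait for a later batch. */
  static <T> List<T> selectDelayed(Map<String, List<T>> requestsByFamilyAndTag) {
    List<T> delayed = new ArrayList<>();
    for (List<T> group : requestsByFamilyAndTag.values()) {
      if (group.size() <= 1) {
        // Fast path: a lone request never conflicts, so skip all further checks.
        continue;
      }
      // Placeholder policy: keep the first request, delay every other one.
      delayed.addAll(group.subList(1, group.size()));
    }
    return delayed;
  }
}
```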



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -679,6 +834,49 @@ private void consumeResponse(Windmill.KeyedGetDataResponse 
response, Set<StateTa
       consumeSortedList(sorted_list, stateTag);
     }
 
+    for (Windmill.TagMultimapFetchResponse tagMultimap : 
response.getTagMultimapsList()) {
+      // First check if it's keys()/entries()

Review Comment:
   Can we rely on responses being at the same index as their requests? That 
isn't documented yet, but I believe it is true in the Windmill implementation 
and we could just document it.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then 
the complete content of
+     * this key is cached: persisted values of this key in backing store are 
cached in
+     * cachedEntries, newly added values not yet written to backing store are 
cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} 
then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll 
need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in 
the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. 
Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and 
localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();

Review Comment:
   If you combine those other maps, it seems you could combine this one too. 
Then you end up with a single data structure
   Map<Object, KeyState>
   where KeyState handles the caching itself, the structural key, etc.
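
   A rough sketch of that shape, to make the suggestion concrete. KeyState and 
its fields are hypothetical names, and the put/remove semantics are simplified 
assumptions, not the behavior the PR must have.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class KeyStateSketch<K, V> {
  /** Per-key bookkeeping that replaces the parallel maps and sets (hypothetical). */
  static final class KeyState<K, V> {
    final K originalKey; // takes over the role of structuralKeyMapping
    List<V> cachedValues = null; // null: persisted values not cached yet
    final List<V> localAdditions = new ArrayList<>();
    boolean removedLocally = false;

    KeyState(K originalKey) {
      this.originalKey = originalKey;
    }
  }

  // Single map keyed by the structural key.
  private final Map<Object, KeyState<K, V>> keyStates = new HashMap<>();

  void put(Object structuralKey, K key, V value) {
    keyStates.computeIfAbsent(structuralKey, unused -> new KeyState<>(key))
        .localAdditions.add(value);
  }

  void remove(Object structuralKey, K key) {
    KeyState<K, V> state =
        keyStates.computeIfAbsent(structuralKey, unused -> new KeyState<>(key));
    // Persisted values are obsolete after a removal; later puts start fresh.
    state.cachedValues = new ArrayList<>();
    state.localAdditions.clear();
    state.removedLocally = true;
  }

  KeyState<K, V> stateOf(Object structuralKey) {
    return keyStates.get(structuralKey);
  }
}
```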



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then 
the complete content of
+     * this key is cached: persisted values of this key in backing store are 
cached in
+     * cachedEntries, newly added values not yet written to backing store are 
cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} 
then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll 
need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in 
the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. 
Any key in

Review Comment:
   Why are these maps keyed by the structural key instead of the actual key?
   Is it for consistent sorting? Either way, add a comment.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then 
the complete content of
+     * this key is cached: persisted values of this key in backing store are 
cached in
+     * cachedEntries, newly added values not yet written to backing store are 
cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} 
then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll 
need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in 
the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. 
Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and 
localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been 
sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care 
about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, 
Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.size(persistedValues) == 0) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill 
nor local additions.");
+              }
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              cachedEntries.put(structuralKey, new ConcatIterables<>());
+              ((ConcatIterables<V>) 
cachedEntries.get(structuralKey)).extendWith(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, 
localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest 
persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = 
WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(structuralKeyMapping.get(structuralKey), keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = 
builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        if (cachedEntries.containsKey(structuralKey)) {
+          // Move newly added values from localAdditions to cachedEntries as 
those new values are
+          // also persisted in Windmill.
+          ((ConcatIterables<V>) cachedEntries.get(structuralKey))

Review Comment:
   change type of cachedEntries to avoid casting, so that it is obviously safe
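
   For example, something along these lines. ConcatValues below is a 
hypothetical stand-in for ConcatIterables so that the snippet is 
self-contained; the point is only that the map's value type carries 
extendWith, so no cast is needed at the call site.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class TypedCacheSketch<V> {
  /** Hypothetical stand-in for ConcatIterables: an appendable view over parts. */
  static final class ConcatValues<V> {
    private final List<Iterable<V>> parts = new ArrayList<>();

    void extendWith(Iterable<V> more) {
      parts.add(more);
    }

    List<V> toList() {
      List<V> out = new ArrayList<>();
      for (Iterable<V> part : parts) {
        for (V value : part) {
          out.add(value);
        }
      }
      return out;
    }
  }

  // Declared with the concrete value type, so callers never cast.
  private final Map<Object, ConcatValues<V>> cachedEntries = new HashMap<>();

  void cache(Object structuralKey, Iterable<V> values) {
    cachedEntries.computeIfAbsent(structuralKey, unused -> new ConcatValues<>())
        .extendWith(values);
  }

  List<V> cached(Object structuralKey) {
    ConcatValues<V> values = cachedEntries.get(structuralKey);
    return values == null ? new ArrayList<>() : values.toList();
  }
}
```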



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -129,9 +140,10 @@ enum Kind {
     abstract String getStateFamily();
 
     /**
-     * For {@link Kind#BAG, Kind#ORDERED_LIST, Kind#VALUE_PREFIX} kinds: A 
previous
-     * 'continuation_position' returned by Windmill to signal the resulting 
bag was incomplete.
-     * Sending that position will request the next page of values. Null for 
first request.
+     * For {@link Kind#BAG, Kind#ORDERED_LIST, Kind#VALUE_PREFIX, 
KIND#MULTIMAP_SINGLE_ENTRY,
+     * KIND#MULTIMAP_ALL} kinds: A previous 'continuation_position' returned 
by Windmill to signal
+     * the resulting bag was incomplete.Sending that position will request the 
next page of values.

Review Comment:
   nit: rephrase so not bag specific
   
   to signal the result is incomplete and has remaining values.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -334,6 +358,28 @@ public <T> Future<Iterable<TimestampedValue<T>>> 
orderedListFuture(
         valuesToPagingIterableFuture(stateTag, elemCoder, 
this.stateFuture(stateTag, elemCoder)));
   }
 
+  public <T> Future<Iterable<Map.Entry<ByteString, Iterable<T>>>> 
multimapFetchAllFuture(
+      boolean omitValues, ByteString encodedTag, String stateFamily, Coder<T> 
elemCoder) {
+    StateTag<ByteString> stateTag =
+        StateTag.<ByteString>of(Kind.MULTIMAP_ALL, encodedTag, stateFamily)
+            .toBuilder()
+            .setOmitValues(omitValues)
+            .build();
+    return Preconditions.checkNotNull(

Review Comment:
   nit: remove the Preconditions check; it's not really a precondition, and 
valuesToPagingIterableFuture doesn't return null.
   
   You could remove the nullness suppression above and possibly fix other 
things up so that it is checked statically instead.



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -466,20 +518,67 @@ private <ResultT, ContinuationT> 
Future<Iterable<ResultT>> valuesToPagingIterabl
     return Futures.lazyTransform(future, toIterable);
   }
 
+  private void delayUnbatchableMultimapFetches(
+      List<StateTag<?>> multimapTags, HashSet<StateTag<?>> toFetch) {
+    // Each KeyedGetDataRequest can have at most 1 TagMultimapFetchRequest, 
thus we need to delay
+    // unbatchable multimap requests of the same stateFamily and tag into 
later batches. There's no
+    // priority between get()/entries()/keys(), they will be fetched based on 
the order they occur
+    // in pendingLookups, so that all requests can eventually be fetched and 
none starves.
+
+    Map<String, Map<ByteString, List<StateTag<?>>>> groupedTags =
+        multimapTags.stream()
+            .collect(
+                Collectors.groupingBy(

Review Comment:
   group by statefamily and tag at once?
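
   A sketch of a single-level grouping with a composite key. Tag is a 
hypothetical stand-in for StateTag, and it uses String for the tag where the 
real class uses ByteString; any value type with proper equals/hashCode would 
work as the key.

```java
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

final class GroupingSketch {
  /** Hypothetical stand-in for StateTag with only the two grouping fields. */
  static final class Tag {
    final String stateFamily;
    final String tag;

    Tag(String stateFamily, String tag) {
      this.stateFamily = stateFamily;
      this.tag = tag;
    }
  }

  // Composite key: entries compare equal when both family and tag are equal.
  static Map.Entry<String, String> familyAndTag(Tag t) {
    return new SimpleImmutableEntry<>(t.stateFamily, t.tag);
  }

  // One groupingBy pass instead of a nested Map<String, Map<..., List<...>>>.
  static Map<Map.Entry<String, String>, List<Tag>> groupByFamilyAndTag(List<Tag> tags) {
    return tags.stream().collect(Collectors.groupingBy(GroupingSketch::familyAndTag));
  }
}
```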



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -569,10 +669,65 @@ private Windmill.KeyedGetDataRequest 
createRequest(Iterable<StateTag<?>> toFetch
           }
           break;
 
+        case MULTIMAP_SINGLE_ENTRY:
+          multimapSingleEntryToFetch.add(stateTag);
+          break;
+
+        case MULTIMAP_ALL:
+          Windmill.TagMultimapFetchRequest.Builder multimapFetchBuilder =
+              keyedDataBuilder
+                  .addMultimapsToFetchBuilder()
+                  .setTag(stateTag.getTag())
+                  .setStateFamily(stateTag.getStateFamily())
+                  .setFetchEntryNamesOnly(stateTag.getOmitValues());
+          if (stateTag.getRequestPosition() == null) {
+            multimapFetchBuilder.setFetchMaxBytes(INITIAL_MAX_MULTIMAP_BYTES);
+          } else {
+            
multimapFetchBuilder.setFetchMaxBytes(CONTINUATION_MAX_MULTIMAP_BYTES);
+            multimapFetchBuilder.setRequestPosition((ByteString) 
stateTag.getRequestPosition());
+            continuation = true;
+          }
+          break;
+
         default:
           throw new RuntimeException("Unknown kind of tag requested: " + 
stateTag.getKind());
       }
     }
+    if (!multimapSingleEntryToFetch.isEmpty()) {
+      Map<String, Map<ByteString, List<StateTag<?>>>> multimapTags =
+          multimapSingleEntryToFetch.stream()
+              .collect(
+                  Collectors.groupingBy(

Review Comment:
   single group?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java:
##########
@@ -466,20 +518,67 @@ private <ResultT, ContinuationT> 
Future<Iterable<ResultT>> valuesToPagingIterabl
     return Futures.lazyTransform(future, toIterable);
   }
 
+  private void delayUnbatchableMultimapFetches(
+      List<StateTag<?>> multimapTags, HashSet<StateTag<?>> toFetch) {
+    // Each KeyedGetDataRequest can have at most 1 TagMultimapFetchRequest, 
thus we need to delay

Review Comment:
   the proto supports repeated TagMultimapFetchRequest, why is it limited to 1?
   
   Did you mean it supports at most 1 fetch for a tag/state_family?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then 
the complete content of
+     * this key is cached: persisted values of this key in backing store are 
cached in
+     * cachedEntries, newly added values not yet written to backing store are 
cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} 
then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll 
need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();

Review Comment:
   Can you combine cachedEntries, existKeyCache, and nonexistentKeyCache?
   It seems like you could store an Iterable if the values are cached, null if 
the key is known but its values are not, and an empty list if the key doesn't 
exist (i.e. has no entries).
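
   A small sketch of that encoding, assuming a plain HashMap (which permits 
null values). The class, enum and method names are made up for illustration, 
and it ignores localAdditions and localRemovals entirely.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class MergedKeyCacheSketch<V> {
  enum Status {
    UNKNOWN,
    EXISTS_VALUES_NOT_CACHED,
    VALUES_CACHED,
    NONEXISTENT
  }

  // absent: nothing known; null: key exists, values not cached;
  // empty list: key known not to exist; non-empty list: cached values.
  private final Map<Object, List<V>> entries = new HashMap<>();

  void markExists(Object structuralKey) {
    if (!entries.containsKey(structuralKey)) {
      entries.put(structuralKey, null);
    }
  }

  /** Caches fetched values; an empty list records that the key does not exist. */
  void cacheValues(Object structuralKey, List<V> values) {
    entries.put(structuralKey, values);
  }

  Status status(Object structuralKey) {
    if (!entries.containsKey(structuralKey)) {
      return Status.UNKNOWN;
    }
    List<V> values = entries.get(structuralKey);
    if (values == null) {
      return Status.EXISTS_VALUES_NOT_CACHED;
    }
    return values.isEmpty() ? Status.NONEXISTENT : Status.VALUES_CACHED;
  }
}
```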



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then 
the complete content of
+     * this key is cached: persisted values of this key in backing store are 
cached in
+     * cachedEntries, newly added values not yet written to backing store are 
cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} 
then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll 
need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in 
the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. 
Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and 
localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);

Review Comment:
   final?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then 
the complete content of
+     * this key is cached: persisted values of this key in backing store are 
cached in
+     * cachedEntries, newly added values not yet written to backing store are 
cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} 
then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll 
need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in 
the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. 
Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and 
localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been 
sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care 
about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, 
Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.size(persistedValues) == 0) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill 
nor local additions.");
+              }
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              cachedEntries.put(structuralKey, new ConcatIterables<>());
+              ((ConcatIterables<V>) 
cachedEntries.get(structuralKey)).extendWith(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, 
localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest 
persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();

Review Comment:
   Maybe we should cache if complete?



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then 
the complete content of
+     * this key is cached: persisted values of this key in backing store are 
cached in
+     * cachedEntries, newly added values not yet written to backing store are 
cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} 
then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll 
need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in 
the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. 
Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and 
localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been 
sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care 
about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, 
Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.size(persistedValues) == 0) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill 
nor local additions.");
+              }
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              cachedEntries.put(structuralKey, new ConcatIterables<>());
+              ((ConcatIterables<V>) 
cachedEntries.get(structuralKey)).extendWith(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, 
localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest 
persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = 
WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+      }
+      for (Object structuralKey : keysWithUpdates) {

Review Comment:
   See above. I think this would be simpler with a single key map to iterate over.
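
   One possible reading of this suggestion, as a rough sketch only: keep one
   map from structural key to a small per-key record, so persistDirectly can
   walk a single collection instead of building the union of localRemovals
   and localAdditions. The names KeyState, keyStates, and
   describePendingUpdates are hypothetical and do not come from the PR. Plain
   String keys stand in for the values of keyCoder.structuralValue(key).

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class SingleKeyMapSketch {
  /** Hypothetical per-key record that replaces the parallel sets and maps. */
  public static final class KeyState {
    final String originalKey;
    boolean removedLocally = false;
    final List<Integer> localAdditions = new ArrayList<>();

    KeyState(String originalKey) {
      this.originalKey = originalKey;
    }
  }

  // One map keyed by the (stand-in) structural key; insertion order is kept for readability.
  private final Map<Object, KeyState> keyStates = new LinkedHashMap<>();

  private KeyState stateFor(String key) {
    return keyStates.computeIfAbsent(key, k -> new KeyState(key));
  }

  public void put(String key, int value) {
    stateFor(key).localAdditions.add(value);
  }

  public void remove(String key) {
    KeyState state = stateFor(key);
    state.removedLocally = true;
    // Values added before the removal are dropped; later puts are applied after the delete.
    state.localAdditions.clear();
  }

  /** Stand-in for building the commit request: a single pass over one map. */
  public List<String> describePendingUpdates() {
    List<String> updates = new ArrayList<>();
    for (KeyState state : keyStates.values()) {
      updates.add(
          state.originalKey
              + (state.removedLocally ? " deleteAll" : "")
              + " add"
              + state.localAdditions);
    }
    return updates;
  }
}
```

   With this shape, the original key travels with its pending changes, so a
   separate structuralKeyMapping lookup would not be needed while iterating.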



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then 
the complete content of
+     * this key is cached: persisted values of this key in backing store are 
cached in
+     * cachedEntries, newly added values not yet written to backing store are 
cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} 
then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll 
need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in 
the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. 
Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and 
localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been 
sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care 
about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, 
Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.size(persistedValues) == 0) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill 
nor local additions.");
+              }
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              cachedEntries.put(structuralKey, new ConcatIterables<>());

Review Comment:
   It seems this insert always adds a new entry, given the check at line 1708
   for whether the map contains the key. In that case, can you just put
   persistedValues in the map?
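
   A simplified sketch of the point, using only java.util types: once an
   earlier guard has returned for keys that are already cached, the key is
   known to be absent, so a direct put is enough. CachePutSketch and its
   members are hypothetical; the sketch ignores the Weighted check and
   assumes nothing later needs the cached value to be a ConcatIterables.

```java
import java.util.HashMap;
import java.util.Map;

public class CachePutSketch {
  private final Map<String, Iterable<Integer>> cachedEntries = new HashMap<>();

  /** Returns the cached values for key, caching persistedValues on first read. */
  public Iterable<Integer> read(String key, Iterable<Integer> persistedValues) {
    if (cachedEntries.containsKey(key)) {
      return cachedEntries.get(key);
    }
    // The guard above already returned for cached keys, so the key is absent here
    // and there is no existing entry to extend: store the fetched values as they are.
    cachedEntries.put(key, persistedValues);
    return persistedValues;
  }

  public int cachedSize() {
    return cachedEntries.size();
  }
}
```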



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then 
the complete content of
+     * this key is cached: persisted values of this key in backing store are 
cached in
+     * cachedEntries, newly added values not yet written to backing store are 
cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} 
then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll 
need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in 
the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. 
Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and 
localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been 
sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care 
about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, 
Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.size(persistedValues) == 0) {

Review Comment:
   I think the size call is going to iterate through large maps, paginating
   from Windmill. Use Iterables.isEmpty instead.
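
   To illustrate the cost difference, here is a small sketch with a lazy
   iterable standing in for a paginated Windmill result. For inputs that are
   not a Collection, Guava's Iterables.size counts by walking the iterator,
   while Iterables.isEmpty only calls hasNext(). The helpers below mimic
   those two behaviours with plain Java so the sketch has no dependencies;
   CountingIterable, sizeByIterating, and isEmptyByPeeking are hypothetical
   names, not part of Beam or Guava.

```java
import java.util.Iterator;
import java.util.NoSuchElementException;

public class LazyEmptinessSketch {
  /** Stand-in for a paginated iterable; records how many elements were pulled. */
  public static final class CountingIterable implements Iterable<Integer> {
    private final int total;
    public int elementsPulled = 0;

    public CountingIterable(int total) {
      this.total = total;
    }

    @Override
    public Iterator<Integer> iterator() {
      return new Iterator<Integer>() {
        private int next = 0;

        @Override
        public boolean hasNext() {
          return next < total;
        }

        @Override
        public Integer next() {
          if (!hasNext()) {
            throw new NoSuchElementException();
          }
          elementsPulled++;
          return next++;
        }
      };
    }
  }

  /** Counts elements by consuming the whole iterable, as a size check must. */
  public static int sizeByIterating(Iterable<?> values) {
    int count = 0;
    for (Object ignored : values) {
      count++;
    }
    return count;
  }

  /** Checks emptiness by asking only whether a first element exists. */
  public static boolean isEmptyByPeeking(Iterable<?> values) {
    return !values.iterator().hasNext();
  }
}
```

   In this sketch the size-based check pulls every element before answering,
   whereas the emptiness check pulls none, which is the difference that
   matters when each batch of elements is a round trip to the backend.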



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then 
the complete content of
+     * this key is cached: persisted values of this key in backing store are 
cached in
+     * cachedEntries, newly added values not yet written to backing store are 
cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} 
then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll 
need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in 
the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. 
Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and 
localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been 
sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care 
about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, 
Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.size(persistedValues) == 0) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill 
nor local additions.");
+              }
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              cachedEntries.put(structuralKey, new ConcatIterables<>());
+              ((ConcatIterables<V>) 
cachedEntries.get(structuralKey)).extendWith(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, 
localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest 
persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = 
WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(structuralKeyMapping.get(structuralKey), keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = 
builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        if (cachedEntries.containsKey(structuralKey)) {
+          // Move newly added values from localAdditions to cachedEntries as 
those new values are
+          // also persisted in Windmill.
+          ((ConcatIterables<V>) cachedEntries.get(structuralKey))
+              .extendWith(localAdditions.get(structuralKey));
+        }
+      }
+
+      if (builder != null) {
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object removedKey : localRemovals) {
+        if (!nonexistentKeyCache.contains(removedKey)) {
+          structuralKeyMapping.remove(removedKey);
+        }
+      }
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      if (!structuralKeyMapping.containsKey(structuralKey)) {
+        structuralKeyMapping.put(structuralKey, key);
+      }
+      if (nonexistentKeyCache.contains(structuralKey)
+          || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+        return;
+      }
+      if (cachedEntries.containsKey(structuralKey) || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        cachedEntries.remove(structuralKey);
+      } // else: no data in windmill, deleting from local cache is sufficient.
+      localAdditions.removeAll(structuralKey);
+      existKeyCache.remove(structuralKey);
+      nonexistentKeyCache.add(structuralKey);
+    }
+
+    @Override
+    public void clear() {
+      cachedEntries = Maps.newHashMap();
+      existKeyCache = Sets.newHashSet();
+      nonexistentKeyCache = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+      localRemovals = Sets.newHashSet();
+      structuralKeyMapping = Maps.newHashMap();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(existKeyCache, structuralKeyMapping::get));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = 
getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = 
persistedData.get();
+            Iterable<K> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        return keyCoder.decode(entry.getKey().newInput());
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys, key -> 
!nonexistentKeyCache.contains(keyCoder.structuralValue(key)));
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  k -> {
+                    Object structuralKey = keyCoder.structuralValue(k);
+                    existKeyCache.add(structuralKey);
+                    structuralKeyMapping.put(structuralKey, k);
+                  });
+              allKeysKnown = true;
+              nonexistentKeyCache = Sets.newHashSet();
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(existKeyCache, 
structuralKeyMapping::get));
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached.
+                      Iterables.transform(existKeyCache, 
structuralKeyMapping::get),
+                      // This is the part of the keys returned from Windmill 
that are not cached.
+                      Iterables.filter(
+                          keys, e -> 
!existKeyCache.contains(keyCoder.structuralValue(e)))));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    private MultimapIterables<K, V> mergedCachedEntries() {
+      MultimapIterables<K, V> result = new MultimapIterables<>();
+      for (Entry<Object, Collection<V>> entry : 
localAdditions.asMap().entrySet()) {
+        K key = structuralKeyMapping.get(entry.getKey());
+        result.extendWith(key, entry.getValue());
+      }
+      for (Entry<Object, Iterable<V>> entry : cachedEntries.entrySet()) {
+        K key = structuralKeyMapping.get(entry.getKey());
+        result.extendWith(key, entry.getValue());
+      }
+      return result;
+    }
+
+    private static class MultimapIterables<K, V> implements Iterable<Entry<K, V>> {
+      Map<K, ConcatIterables<V>> map;
+
+      public MultimapIterables() {
+        this.map = new HashMap<>();
+      }
+
+      public void extendWith(K key, Iterable<V> iterable) {
+        if (!map.containsKey(key)) map.put(key, new ConcatIterables<>());
+        map.get(key).extendWith(iterable);
+      }
+
+      @Override
+      public Iterator<Entry<K, V>> iterator() {
+        return Iterators.concat(
+            Iterables.transform(
+                    map.keySet(),

Review Comment:
   If you iterate over `map.entrySet()` instead of `map.keySet()`, it seems you could avoid the `map.get(k)` lookup below.
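
   A minimal sketch of the idea, using plain JDK collections rather than the Guava iterables in the PR. The class name `EntrySetFlattenSketch` and the method `flatten` are hypothetical, and the result is built eagerly here for brevity, whereas the PR returns a lazy iterator:

```java
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

class EntrySetFlattenSketch {
  // Hypothetical stand-in for MultimapIterables.iterator(): flattens
  // key -> values into a list of (key, value) pairs.
  static <K, V> List<Entry<K, V>> flatten(Map<K, List<V>> map) {
    List<Entry<K, V>> out = new ArrayList<>();
    // entrySet() hands back each key together with its values, so there is
    // no second hash lookup via map.get(k).
    for (Entry<K, List<V>> e : map.entrySet()) {
      for (V v : e.getValue()) {
        out.add(new AbstractMap.SimpleEntry<>(e.getKey(), v));
      }
    }
    return out;
  }

  public static void main(String[] args) {
    Map<String, List<Integer>> map = new LinkedHashMap<>();
    map.put("a", Arrays.asList(1, 2));
    map.put("b", Arrays.asList(3));
    System.out.println(flatten(map));
  }
}
```

   The same swap should carry over to the `Iterables.transform` call: transform `map.entrySet()` and read `getKey()` / `getValue()` inside the lambda.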



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then 
the complete content of
+     * this key is cached: persisted values of this key in backing store are 
cached in
+     * cachedEntries, newly added values not yet written to backing store are 
cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} 
then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll 
need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in 
the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. 
Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and 
localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been 
sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care 
about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, 
Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.size(persistedValues) == 0) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill 
nor local additions.");
+              }
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              cachedEntries.put(structuralKey, new ConcatIterables<>());
+              ((ConcatIterables<V>) 
cachedEntries.get(structuralKey)).extendWith(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, 
localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest 
persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = 
WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(structuralKeyMapping.get(structuralKey), keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = 
builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        if (cachedEntries.containsKey(structuralKey)) {
+          // Move newly added values from localAdditions to cachedEntries as 
those new values are
+          // also persisted in Windmill.
+          ((ConcatIterables<V>) cachedEntries.get(structuralKey))
+              .extendWith(localAdditions.get(structuralKey));
+        }
+      }
+
+      if (builder != null) {
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object removedKey : localRemovals) {
+        if (!nonexistentKeyCache.contains(removedKey)) {
+          structuralKeyMapping.remove(removedKey);
+        }
+      }
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      if (!structuralKeyMapping.containsKey(structuralKey)) {
+        structuralKeyMapping.put(structuralKey, key);
+      }
+      if (nonexistentKeyCache.contains(structuralKey)
+          || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+        return;
+      }
+      if (cachedEntries.containsKey(structuralKey) || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        cachedEntries.remove(structuralKey);
+      } // else: no data in windmill, deleting from local cache is sufficient.
+      localAdditions.removeAll(structuralKey);
+      existKeyCache.remove(structuralKey);
+      nonexistentKeyCache.add(structuralKey);
+    }
+
+    @Override
+    public void clear() {
+      cachedEntries = Maps.newHashMap();
+      existKeyCache = Sets.newHashSet();
+      nonexistentKeyCache = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+      localRemovals = Sets.newHashSet();
+      structuralKeyMapping = Maps.newHashMap();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(existKeyCache, structuralKeyMapping::get));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = 
getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = 
persistedData.get();
+            Iterable<K> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        return keyCoder.decode(entry.getKey().newInput());
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys, key -> 
!nonexistentKeyCache.contains(keyCoder.structuralValue(key)));
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  k -> {
+                    Object structuralKey = keyCoder.structuralValue(k);
+                    existKeyCache.add(structuralKey);
+                    structuralKeyMapping.put(structuralKey, k);
+                  });
+              allKeysKnown = true;
+              nonexistentKeyCache = Sets.newHashSet();
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(existKeyCache, 
structuralKeyMapping::get));
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached.
+                      Iterables.transform(existKeyCache, 
structuralKeyMapping::get),
+                      // This is the part of the keys returned from Windmill 
that are not cached.
+                      Iterables.filter(
+                          keys, e -> 
!existKeyCache.contains(keyCoder.structuralValue(e)))));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    private MultimapIterables<K, V> mergedCachedEntries() {
+      MultimapIterables<K, V> result = new MultimapIterables<>();
+      for (Entry<Object, Collection<V>> entry : 
localAdditions.asMap().entrySet()) {
+        K key = structuralKeyMapping.get(entry.getKey());
+        result.extendWith(key, entry.getValue());
+      }
+      for (Entry<Object, Iterable<V>> entry : cachedEntries.entrySet()) {
+        K key = structuralKeyMapping.get(entry.getKey());
+        result.extendWith(key, entry.getValue());
+      }
+      return result;
+    }
+
+    private static class MultimapIterables<K, V> implements Iterable<Entry<K, V>> {
+      Map<K, ConcatIterables<V>> map;
+
+      public MultimapIterables() {
+        this.map = new HashMap<>();
+      }
+
+      public void extendWith(K key, Iterable<V> iterable) {
+        if (!map.containsKey(key)) map.put(key, new ConcatIterables<>());
+        map.get(key).extendWith(iterable);
+      }
+
+      @Override
+      public Iterator<Entry<K, V>> iterator() {
+        return Iterators.concat(
+            Iterables.transform(
+                    map.keySet(),
+                    k ->
+                        Iterables.transform(map.get(k), v -> new AbstractMap.SimpleEntry<>(k, v))
+                            .iterator())
+                .iterator());
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<Entry<K, V>>> entries() {
+      return new ReadableState<Iterable<Entry<K, V>>>() {
+        @Override
+        public Iterable<Entry<K, V>> read() {
+          if (complete) {
+            return Iterables.unmodifiableIterable(mergedCachedEntries());
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = 
getFuture(false);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = 
persistedData.get();
+            Map<Object, ConcatIterables<V>> entryMap = Maps.newHashMap();
+            entries.forEach(
+                entry -> {
+                  try {
+                    K key = keyCoder.decode(entry.getKey().newInput());
+                    Object structuralKey = keyCoder.structuralValue(key);
+                    structuralKeyMapping.put(structuralKey, key);
+                    if (nonexistentKeyCache.contains(structuralKey)) return;
+                    if (entryMap.containsKey(structuralKey)) {

Review Comment:
   Use `Map.merge` or a similar in-place update here instead of the `containsKey` check followed by a separate put/get.
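
   A rough sketch of both options with plain JDK types. The class `MapUpdateSketch` and its methods are hypothetical, and `List<V>` stands in for the PR's `ConcatIterables<V>`:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class MapUpdateSketch {
  // Option 1: computeIfAbsent creates the list on first use and returns the
  // existing one otherwise, so a single call replaces containsKey + put + get.
  static <K, V> void extendWith(Map<K, List<V>> map, K key, List<V> values) {
    map.computeIfAbsent(key, k -> new ArrayList<>()).addAll(values);
  }

  // Option 2: merge stores a copy of `values` when the key is absent and
  // otherwise appends to the list that is already there.
  static <K, V> void mergeWith(Map<K, List<V>> map, K key, List<V> values) {
    map.merge(
        key,
        new ArrayList<>(values),
        (existing, added) -> {
          existing.addAll(added);
          return existing;
        });
  }

  public static void main(String[] args) {
    Map<String, List<Integer>> map = new HashMap<>();
    extendWith(map, "a", Arrays.asList(1, 2));
    mergeWith(map, "a", Arrays.asList(3));
    mergeWith(map, "b", Arrays.asList(4));
    System.out.println(map);
  }
}
```

   `computeIfAbsent` is probably the closer fit when the stored value is a mutable accumulator that gets extended in place, since it avoids allocating a new container on every call.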



##########
runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java:
##########
@@ -1588,7 +1599,472 @@ private Future<Iterable<Map.Entry<ByteString, V>>> 
getFuture() {
         return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, 
valueCoder);
       }
     }
-  };
+  }
+
+  private static class WindmillMultimap<K, V> extends SimpleWindmillState
+      implements MultimapState<K, V> {
+
+    private final StateNamespace namespace;
+    private final StateTag<MultimapState<K, V>> address;
+    private final ByteString stateKey;
+    private final String stateFamily;
+    private final Coder<K> keyCoder;
+    private final Coder<V> valueCoder;
+
+    private boolean cleared = false;
+    /**
+     * For any given key, if it's contained in {@link #cachedEntries}, then 
the complete content of
+     * this key is cached: persisted values of this key in backing store are 
cached in
+     * cachedEntries, newly added values not yet written to backing store are 
cached in
+     * localAdditions. If a key is not contained in {@link #cachedEntries} 
then we don't know if
+     * Windmill contains additional values which also maps to this key, we'll 
need to read them if
+     * the work item actually wants the content.
+     */
+    private Map<Object, Iterable<V>> cachedEntries = Maps.newHashMap();
+    // Any key presents in existKeyCache is known to exist in the multimap.
+    private Set<Object> existKeyCache = Sets.newHashSet();
+    // If true, any key not in existKeyCache is known to be nonexistent.
+    private boolean allKeysKnown = false;
+    // Any key presents in nonexistentKeyCache is known to be nonexistent in 
the multimap.
+    private Set<Object> nonexistentKeyCache = Sets.newHashSet();
+
+    private boolean complete = false;
+    private Multimap<Object, V> localAdditions = ArrayListMultimap.create();
+    // All keys that are pending delete. If a key exist in both localRemovals 
and localAdditions:
+    // new values in localAdditions will be added after old values are removed.
+    private Set<Object> localRemovals = Sets.newHashSet();
+    // structuralKeyMapping maps from the structuralKeys to the actual keys. 
Any key in
+    // cachedEntries, existKeyCache, nonexistentKeyCache, localAdditions and 
localRemovals should be
+    // included in this mapping.
+    private Map<Object, K> structuralKeyMapping = Maps.newHashMap();
+
+    private WindmillMultimap(
+        StateNamespace namespace,
+        StateTag<MultimapState<K, V>> address,
+        String stateFamily,
+        Coder<K> keyCoder,
+        Coder<V> valueCoder,
+        boolean isNewShardingKey) {
+      this.namespace = namespace;
+      this.address = address;
+      this.stateKey = encodeKey(namespace, address);
+      this.stateFamily = stateFamily;
+      this.keyCoder = keyCoder;
+      this.valueCoder = valueCoder;
+      this.complete = isNewShardingKey;
+      this.allKeysKnown = isNewShardingKey;
+    }
+
+    @Override
+    public void put(K key, V value) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      localAdditions.put(structuralKey, value);
+      existKeyCache.add(structuralKey);
+      nonexistentKeyCache.remove(structuralKey);
+      structuralKeyMapping.put(structuralKey, key);
+    }
+
+    // Initiates a backend state read to fetch all entries if necessary.
+    private Future<Iterable<Map.Entry<ByteString, Iterable<V>>>> 
getFuture(boolean omitValues) {
+      if (complete) {
+        return Futures.immediateFuture(Collections.emptyList());
+      } else {
+        return reader.multimapFetchAllFuture(omitValues, stateKey, 
stateFamily, valueCoder);
+      }
+    }
+
+    // Initiates a backend state read to fetch a single entry if necessary.
+    private Future<Iterable<V>> getFutureForKey(K key) {
+      try {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(key, keyStream);
+        return reader.multimapFetchSingleEntryFuture(
+            keyStream.toByteString(), stateKey, stateFamily, valueCoder);
+      } catch (IOException e) {
+        throw new RuntimeException(e);
+      }
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> get(K key) {
+      return new ReadableState<Iterable<V>>() {
+        Object structuralKey = keyCoder.structuralValue(key);
+
+        @Override
+        public Iterable<V> read() {
+          if (nonexistentKeyCache.contains(structuralKey)
+              || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+            return Collections.emptyList();
+          }
+          if (localRemovals.contains(structuralKey)) {
+            // this key has been removed locally but the removal hasn't been 
sent to windmill,
+            // thus values in windmill(if any) are obsolete, and we only care 
about local values.
+            if (localAdditions.containsKey(structuralKey)) {
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            } else {
+              return Collections.emptyList();
+            }
+          }
+          if (cachedEntries.containsKey(structuralKey) || complete) {
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(
+                    cachedEntries.getOrDefault(structuralKey, 
Collections.emptyList()),
+                    localAdditions.get(structuralKey)));
+          }
+          Future<Iterable<V>> persistedData = getFutureForKey(key);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<V> persistedValues = persistedData.get();
+            if (Iterables.size(persistedValues) == 0) {
+              if (!localAdditions.containsKey(structuralKey)) {
+                nonexistentKeyCache.add(structuralKey);
+                Preconditions.checkState(
+                    !existKeyCache.contains(structuralKey),
+                    "Key "
+                        + key
+                        + " exists"
+                        + " in existKeyCache but no value in neither windmill 
nor local additions.");
+              }
+              return 
Iterables.unmodifiableIterable(localAdditions.get(structuralKey));
+            }
+            if (persistedValues instanceof Weighted) {
+              cachedEntries.put(structuralKey, new ConcatIterables<>());
+              ((ConcatIterables<V>) 
cachedEntries.get(structuralKey)).extendWith(persistedValues);
+            }
+            return Iterables.unmodifiableIterable(
+                Iterables.concat(persistedValues, 
localAdditions.get(structuralKey)));
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read Multimap state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<V>> readLater() {
+          WindmillMultimap.this.getFutureForKey(key);
+          return this;
+        }
+      };
+    }
+
+    @Override
+    protected WorkItemCommitRequest 
persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+        throws IOException {
+      if (!cleared && localAdditions.isEmpty() && localRemovals.isEmpty()) {
+        return WorkItemCommitRequest.newBuilder().buildPartial();
+      }
+      WorkItemCommitRequest.Builder commitBuilder = 
WorkItemCommitRequest.newBuilder();
+      Windmill.TagMultimapUpdateRequest.Builder builder = null;
+      if (cleared) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+        builder.setDeleteAll(true);
+        cleared = false;
+      }
+      Set<Object> keysWithUpdates = Sets.newHashSet();
+      keysWithUpdates.addAll(localRemovals);
+      keysWithUpdates.addAll(localAdditions.keySet());
+      if (!keysWithUpdates.isEmpty() && builder == null) {
+        builder = commitBuilder.addMultimapUpdatesBuilder();
+      }
+      for (Object structuralKey : keysWithUpdates) {
+        ByteStringOutputStream keyStream = new ByteStringOutputStream();
+        keyCoder.encode(structuralKeyMapping.get(structuralKey), keyStream);
+        ByteString encodedKey = keyStream.toByteString();
+        Windmill.TagMultimapEntry.Builder entryBuilder = 
builder.addUpdatesBuilder();
+        entryBuilder.setEntryName(encodedKey);
+        entryBuilder.setDeleteAll(localRemovals.contains(structuralKey));
+        for (V value : localAdditions.get(structuralKey)) {
+          ByteStringOutputStream valueStream = new ByteStringOutputStream();
+          valueCoder.encode(value, valueStream);
+          ByteString encodedValue = valueStream.toByteString();
+          entryBuilder.addValues(encodedValue);
+        }
+        if (cachedEntries.containsKey(structuralKey)) {
+          // Move newly added values from localAdditions to cachedEntries as 
those new values are
+          // also persisted in Windmill.
+          ((ConcatIterables<V>) cachedEntries.get(structuralKey))
+              .extendWith(localAdditions.get(structuralKey));
+        }
+      }
+
+      if (builder != null) {
+        builder.setTag(stateKey).setStateFamily(stateFamily);
+      }
+      for (Object removedKey : localRemovals) {
+        if (!nonexistentKeyCache.contains(removedKey)) {
+          structuralKeyMapping.remove(removedKey);
+        }
+      }
+      localRemovals = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+
+      cache.put(namespace, address, this, 1);
+
+      return commitBuilder.buildPartial();
+    }
+
+    @Override
+    public void remove(K key) {
+      Object structuralKey = keyCoder.structuralValue(key);
+      if (!structuralKeyMapping.containsKey(structuralKey)) {
+        structuralKeyMapping.put(structuralKey, key);
+      }
+      if (nonexistentKeyCache.contains(structuralKey)
+          || (allKeysKnown && !existKeyCache.contains(structuralKey))) {
+        return;
+      }
+      if (cachedEntries.containsKey(structuralKey) || !complete) {
+        // there may be data in windmill that need to be removed.
+        localRemovals.add(structuralKey);
+        cachedEntries.remove(structuralKey);
+      } // else: no data in windmill, deleting from local cache is sufficient.
+      localAdditions.removeAll(structuralKey);
+      existKeyCache.remove(structuralKey);
+      nonexistentKeyCache.add(structuralKey);
+    }
+
+    @Override
+    public void clear() {
+      cachedEntries = Maps.newHashMap();
+      existKeyCache = Sets.newHashSet();
+      nonexistentKeyCache = Sets.newHashSet();
+      localAdditions = ArrayListMultimap.create();
+      localRemovals = Sets.newHashSet();
+      structuralKeyMapping = Maps.newHashMap();
+      cleared = true;
+      complete = true;
+      allKeysKnown = true;
+    }
+
+    @Override
+    public ReadableState<Iterable<K>> keys() {
+      return new ReadableState<Iterable<K>>() {
+        @Override
+        public Iterable<K> read() {
+          if (allKeysKnown) {
+            return Iterables.unmodifiableIterable(
+                Iterables.transform(existKeyCache, structuralKeyMapping::get));
+          }
+          Future<Iterable<Entry<ByteString, Iterable<V>>>> persistedData = 
getFuture(true);
+          try (Closeable scope = scopedReadState()) {
+            Iterable<Entry<ByteString, Iterable<V>>> entries = 
persistedData.get();
+            Iterable<K> keys =
+                Iterables.transform(
+                    entries,
+                    entry -> {
+                      try {
+                        return keyCoder.decode(entry.getKey().newInput());
+                      } catch (IOException e) {
+                        throw new RuntimeException(e);
+                      }
+                    });
+            keys =
+                Iterables.filter(
+                    keys, key -> 
!nonexistentKeyCache.contains(keyCoder.structuralValue(key)));
+            if (entries instanceof Weighted) {
+              // This is a known amount of data, cache them all.
+              keys.forEach(
+                  k -> {
+                    Object structuralKey = keyCoder.structuralValue(k);
+                    existKeyCache.add(structuralKey);
+                    structuralKeyMapping.put(structuralKey, k);
+                  });
+              allKeysKnown = true;
+              nonexistentKeyCache = Sets.newHashSet();
+              return Iterables.unmodifiableIterable(
+                  Iterables.transform(existKeyCache, 
structuralKeyMapping::get));
+            } else {
+              return Iterables.unmodifiableIterable(
+                  Iterables.concat(
+                      // This is the part of keys that are cached.
+                      Iterables.transform(existKeyCache, 
structuralKeyMapping::get),
+                      // This is the part of the keys returned from Windmill 
that are not cached.
+                      Iterables.filter(
+                          keys, e -> 
!existKeyCache.contains(keyCoder.structuralValue(e)))));
+            }
+          } catch (InterruptedException | ExecutionException | IOException e) {
+            if (e instanceof InterruptedException) {
+              Thread.currentThread().interrupt();
+            }
+            throw new RuntimeException("Unable to read state", e);
+          }
+        }
+
+        @Override
+        @SuppressWarnings("FutureReturnValueIgnored")
+        public ReadableState<Iterable<K>> readLater() {
+          WindmillMultimap.this.getFuture(true);
+          return this;
+        }
+      };
+    }
+
+    private MultimapIterables<K, V> mergedCachedEntries() {
+      MultimapIterables<K, V> result = new MultimapIterables<>();
+      for (Entry<Object, Collection<V>> entry : 
localAdditions.asMap().entrySet()) {
+        K key = structuralKeyMapping.get(entry.getKey());
+        result.extendWith(key, entry.getValue());
+      }
+      for (Entry<Object, Iterable<V>> entry : cachedEntries.entrySet()) {
+        K key = structuralKeyMapping.get(entry.getKey());
+        result.extendWith(key, entry.getValue());
+      }
+      return result;
+    }
+
+    private static class MultimapIterables<K, V> implements Iterable<Entry<K, 
V>> {
+      Map<K, ConcatIterables<V>> map;
+
+      public MultimapIterables() {
+        this.map = new HashMap<>();
+      }
+
+      public void extendWith(K key, Iterable<V> iterable) {
+        if (!map.containsKey(key)) map.put(key, new ConcatIterables<>());

Review Comment:
   Use map.computeIfAbsent (or map.merge) instead of multiple operations on the same key.
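
   A minimal sketch of what this could look like, assuming the intent is a single lookup per key. Note that java.util.Map has no update method; computeIfAbsent, compute and merge are the single-call alternatives. The class name MultimapIterablesSketch and the use of List<V> in place of the PR's ConcatIterables<V> are stand-ins chosen only to keep the example self-contained, not the PR's actual types.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-in for the PR's MultimapIterables; List<V> replaces
// ConcatIterables<V> so the sketch compiles without Beam on the classpath.
public class MultimapIterablesSketch<K, V> {
  private final Map<K, List<V>> map = new HashMap<>();

  public void extendWith(K key, Iterable<V> iterable) {
    // One map operation: returns the existing list for the key,
    // or creates, stores, and returns a new one if the key is absent.
    List<V> values = map.computeIfAbsent(key, k -> new ArrayList<>());
    iterable.forEach(values::add);
  }

  public List<V> get(K key) {
    // Returns a fresh empty list for unknown keys instead of null.
    return map.getOrDefault(key, new ArrayList<>());
  }

  public static void main(String[] args) {
    MultimapIterablesSketch<String, Integer> m = new MultimapIterablesSketch<>();
    m.extendWith("a", Arrays.asList(1, 2));
    m.extendWith("a", Arrays.asList(3));
    System.out.println(m.get("a")); // prints [1, 2, 3]
  }
}
```

   With ConcatIterables the same shape would presumably be a computeIfAbsent call followed by extendWith on the returned value, which avoids the separate containsKey, put and get calls.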



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
