This is an automated email from the ASF dual-hosted git repository.

scwhittle pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new c08afeae60d Enable MapState and SetState for dataflow streaming engine 
pipelines with legacy runner by building on top of MultimapState. (#31453)
c08afeae60d is described below

commit c08afeae60dfb1a15a0f4c8669085662a847249f
Author: Sam Whittle <scwhit...@users.noreply.github.com>
AuthorDate: Thu Jul 4 22:22:21 2024 +0200

    Enable MapState and SetState for dataflow streaming engine pipelines with 
legacy runner by building on top of MultimapState. (#31453)
---
 CHANGES.md                                         |   1 +
 .../org/apache/beam/runners/core/StateTags.java    |   8 +
 .../beam/runners/dataflow/DataflowRunner.java      |  35 +---
 .../beam/runners/dataflow/DataflowRunnerTest.java  |  59 ------
 .../dataflow/worker/StreamingDataflowWorker.java   |  11 +-
 .../worker/windmill/state/AbstractWindmillMap.java |  23 +++
 .../worker/windmill/state/CachingStateTable.java   |  53 +++--
 .../worker/windmill/state/WindmillMap.java         |  24 +--
 .../windmill/state/WindmillMapViaMultimap.java     | 164 +++++++++++++++
 .../worker/windmill/state/WindmillMultimap.java    |   4 +-
 .../worker/windmill/state/WindmillSet.java         |  36 +---
 .../worker/windmill/state/WindmillStateCache.java  |  46 +++--
 .../windmill/state/WindmillStateInternals.java     |  14 +-
 .../worker/StreamingModeExecutionContextTest.java  |   5 +-
 .../dataflow/worker/WindmillStateTestUtils.java    |   2 +-
 .../dataflow/worker/WorkerCustomSourcesTest.java   |   5 +-
 .../windmill/state/WindmillStateCacheTest.java     |   2 +-
 .../windmill/state/WindmillStateInternalsTest.java | 225 ++++++++++++++++++++-
 .../refresh/DispatchedActiveWorkRefresherTest.java |   2 +-
 .../java/org/apache/beam/sdk/state/StateSpecs.java |  23 +++
 .../org/apache/beam/sdk/transforms/ParDoTest.java  |  28 ++-
 21 files changed, 573 insertions(+), 197 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index 38fa6e44b73..0a620038f11 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -68,6 +68,7 @@
 
 * Multiple RunInference instances can now share the same model instance by 
setting the model_identifier parameter (Python) 
([#31665](https://github.com/apache/beam/issues/31665)).
 * Removed a 3rd party LGPL dependency from the Go SDK 
([#31765](https://github.com/apache/beam/issues/31765)).
+* Support for MapState and SetState when using Dataflow Runner v1 with 
Streaming Engine (Java) 
([#18200](https://github.com/apache/beam/issues/18200)).
 
 ## Breaking Changes
 
diff --git 
a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java 
b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java
index 7ffb10c85c0..6ed7f8525fd 100644
--- 
a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java
+++ 
b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java
@@ -257,6 +257,14 @@ public class StateTags {
         new StructuredId(setTag.getId()), 
StateSpecs.convertToMapSpecInternal(setTag.getSpec()));
   }
 
+  public static <KeyT, ValueT> StateTag<MultimapState<KeyT, ValueT>> 
convertToMultiMapTagInternal(
+      StateTag<MapState<KeyT, ValueT>> mapTag) {
+    StateSpec<MapState<KeyT, ValueT>> spec = mapTag.getSpec();
+    StateSpec<MultimapState<KeyT, ValueT>> multimapSpec =
+        StateSpecs.convertToMultimapSpecInternal(spec);
+    return new SimpleStateTag<>(new StructuredId(mapTag.getId()), 
multimapSpec);
+  }
+
   private static class StructuredId implements Serializable {
     private final StateKind kind;
     private final String rawId;
diff --git 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
index de566599bf8..708c6341326 100644
--- 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
+++ 
b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java
@@ -2564,11 +2564,6 @@ public class DataflowRunner extends 
PipelineRunner<DataflowPipelineJob> {
         || hasExperiment(options, "use_portable_job_submission");
   }
 
-  static boolean useStreamingEngine(DataflowPipelineOptions options) {
-    return hasExperiment(options, GcpOptions.STREAMING_ENGINE_EXPERIMENT)
-        || hasExperiment(options, GcpOptions.WINDMILL_SERVICE_EXPERIMENT);
-  }
-
   static void verifyDoFnSupported(
       DoFn<?, ?> fn, boolean streaming, DataflowPipelineOptions options) {
     if (!streaming && DoFnSignatures.usesMultimapState(fn)) {
@@ -2583,8 +2578,6 @@ public class DataflowRunner extends 
PipelineRunner<DataflowPipelineJob> {
               "%s does not currently support @RequiresTimeSortedInput in 
streaming mode.",
               DataflowRunner.class.getSimpleName()));
     }
-
-    boolean streamingEngine = useStreamingEngine(options);
     boolean isUnifiedWorker = useUnifiedWorker(options);
 
     if (DoFnSignatures.usesMultimapState(fn) && isUnifiedWorker) {
@@ -2593,25 +2586,17 @@ public class DataflowRunner extends 
PipelineRunner<DataflowPipelineJob> {
               "%s does not currently support %s running using streaming on 
unified worker",
               DataflowRunner.class.getSimpleName(), 
MultimapState.class.getSimpleName()));
     }
-    if (DoFnSignatures.usesSetState(fn)) {
-      if (streaming && (isUnifiedWorker || streamingEngine)) {
-        throw new UnsupportedOperationException(
-            String.format(
-                "%s does not currently support %s when using %s",
-                DataflowRunner.class.getSimpleName(),
-                SetState.class.getSimpleName(),
-                isUnifiedWorker ? "streaming on unified worker" : "streaming 
engine"));
-      }
+    if (DoFnSignatures.usesSetState(fn) && streaming && isUnifiedWorker) {
+      throw new UnsupportedOperationException(
+          String.format(
+              "%s does not currently support %s when using streaming on 
unified worker",
+              DataflowRunner.class.getSimpleName(), 
SetState.class.getSimpleName()));
     }
-    if (DoFnSignatures.usesMapState(fn)) {
-      if (streaming && (isUnifiedWorker || streamingEngine)) {
-        throw new UnsupportedOperationException(
-            String.format(
-                "%s does not currently support %s when using %s",
-                DataflowRunner.class.getSimpleName(),
-                MapState.class.getSimpleName(),
-                isUnifiedWorker ? "streaming on unified worker" : "streaming 
engine"));
-      }
+    if (DoFnSignatures.usesMapState(fn) && streaming && isUnifiedWorker) {
+      throw new UnsupportedOperationException(
+          String.format(
+              "%s does not currently support %s when using streaming on 
unified worker",
+              DataflowRunner.class.getSimpleName(), 
MapState.class.getSimpleName()));
     }
     if (DoFnSignatures.usesBundleFinalizer(fn) && !isUnifiedWorker) {
       throw new UnsupportedOperationException(
diff --git 
a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
 
b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
index 55bfc44ee62..cf1066e41d2 100644
--- 
a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
+++ 
b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java
@@ -131,8 +131,6 @@ import org.apache.beam.sdk.runners.AppliedPTransform;
 import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
 import org.apache.beam.sdk.runners.TransformHierarchy;
 import org.apache.beam.sdk.runners.TransformHierarchy.Node;
-import org.apache.beam.sdk.state.MapState;
-import org.apache.beam.sdk.state.SetState;
 import org.apache.beam.sdk.state.StateSpec;
 import org.apache.beam.sdk.state.StateSpecs;
 import org.apache.beam.sdk.state.ValueState;
@@ -1880,63 +1878,6 @@ public class DataflowRunnerTest implements Serializable {
     }
   }
 
-  private void verifyMapStateUnsupported(PipelineOptions options) throws 
Exception {
-    Pipeline p = Pipeline.create(options);
-    p.apply(Create.of(KV.of(13, 42)))
-        .apply(
-            ParDo.of(
-                new DoFn<KV<Integer, Integer>, Void>() {
-
-                  @StateId("fizzle")
-                  private final StateSpec<MapState<Void, Void>> voidState = 
StateSpecs.map();
-
-                  @ProcessElement
-                  public void process() {}
-                }));
-
-    thrown.expectMessage("MapState");
-    thrown.expect(UnsupportedOperationException.class);
-    p.run();
-  }
-
-  @Test
-  public void testMapStateUnsupportedStreamingEngine() throws Exception {
-    PipelineOptions options = buildPipelineOptions();
-    ExperimentalOptions.addExperiment(
-        options.as(ExperimentalOptions.class), 
GcpOptions.STREAMING_ENGINE_EXPERIMENT);
-    options.as(DataflowPipelineOptions.class).setStreaming(true);
-
-    verifyMapStateUnsupported(options);
-  }
-
-  private void verifySetStateUnsupported(PipelineOptions options) throws 
Exception {
-    Pipeline p = Pipeline.create(options);
-    p.apply(Create.of(KV.of(13, 42)))
-        .apply(
-            ParDo.of(
-                new DoFn<KV<Integer, Integer>, Void>() {
-
-                  @StateId("fizzle")
-                  private final StateSpec<SetState<Void>> voidState = 
StateSpecs.set();
-
-                  @ProcessElement
-                  public void process() {}
-                }));
-
-    thrown.expectMessage("SetState");
-    thrown.expect(UnsupportedOperationException.class);
-    p.run();
-  }
-
-  @Test
-  public void testSetStateUnsupportedStreamingEngine() throws Exception {
-    PipelineOptions options = buildPipelineOptions();
-    ExperimentalOptions.addExperiment(
-        options.as(ExperimentalOptions.class), 
GcpOptions.STREAMING_ENGINE_EXPERIMENT);
-    options.as(DataflowPipelineOptions.class).setStreaming(true);
-    verifySetStateUnsupported(options);
-  }
-
   /** Records all the composite transforms visited within the Pipeline. */
   private static class CompositeTransformRecorder extends 
PipelineVisitor.Defaults {
 
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
index 59819db88a0..0e46e7e4687 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java
@@ -324,7 +324,10 @@ public class StreamingDataflowWorker {
     BoundedQueueExecutor workExecutor = createWorkUnitExecutor(options);
     AtomicInteger maxWorkItemCommitBytes = new 
AtomicInteger(Integer.MAX_VALUE);
     WindmillStateCache windmillStateCache =
-        WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb());
+        WindmillStateCache.builder()
+            .setSizeMb(options.getWorkerCacheMb())
+            .setSupportMapViaMultimap(options.isEnableStreamingEngine())
+            .build();
     Function<String, ScheduledExecutorService> executorSupplier =
         threadName ->
             Executors.newSingleThreadScheduledExecutor(
@@ -478,7 +481,11 @@ public class StreamingDataflowWorker {
     ConcurrentMap<String, StageInfo> stageInfo = new ConcurrentHashMap<>();
     AtomicInteger maxWorkItemCommitBytes = new 
AtomicInteger(maxWorkItemCommitBytesOverrides);
     BoundedQueueExecutor workExecutor = createWorkUnitExecutor(options);
-    WindmillStateCache stateCache = 
WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb());
+    WindmillStateCache stateCache =
+        WindmillStateCache.builder()
+            .setSizeMb(options.getWorkerCacheMb())
+            .setSupportMapViaMultimap(options.isEnableStreamingEngine())
+            .build();
     ComputationConfig.Fetcher configFetcher =
         options.isEnableStreamingEngine()
             ? StreamingEngineComputationConfigFetcher.forTesting(
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/AbstractWindmillMap.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/AbstractWindmillMap.java
new file mode 100644
index 00000000000..e144d5cf8c3
--- /dev/null
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/AbstractWindmillMap.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.state;
+
+import org.apache.beam.sdk.state.MapState;
+
+public abstract class AbstractWindmillMap<K, V> extends SimpleWindmillState
+    implements MapState<K, V> {}
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java
index bcaf8bf21a2..c026aac4f96 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java
@@ -24,17 +24,9 @@ import org.apache.beam.runners.core.StateNamespace;
 import org.apache.beam.runners.core.StateTable;
 import org.apache.beam.runners.core.StateTag;
 import org.apache.beam.runners.core.StateTags;
+import org.apache.beam.sdk.coders.BooleanCoder;
 import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.state.BagState;
-import org.apache.beam.sdk.state.CombiningState;
-import org.apache.beam.sdk.state.MapState;
-import org.apache.beam.sdk.state.MultimapState;
-import org.apache.beam.sdk.state.OrderedListState;
-import org.apache.beam.sdk.state.SetState;
-import org.apache.beam.sdk.state.State;
-import org.apache.beam.sdk.state.StateContext;
-import org.apache.beam.sdk.state.ValueState;
-import org.apache.beam.sdk.state.WatermarkHoldState;
+import org.apache.beam.sdk.state.*;
 import org.apache.beam.sdk.transforms.Combine;
 import org.apache.beam.sdk.transforms.CombineWithContext;
 import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
@@ -50,6 +42,7 @@ final class CachingStateTable extends StateTable {
   private final Supplier<Closeable> scopedReadStateSupplier;
   private final @Nullable StateTable derivedStateTable;
   private final boolean isNewKey;
+  private final boolean mapStateViaMultimapState;
 
   private CachingStateTable(Builder builder) {
     this.stateFamily = builder.stateFamily;
@@ -59,6 +52,7 @@ final class CachingStateTable extends StateTable {
     this.isNewKey = builder.isNewKey;
     this.scopedReadStateSupplier = builder.scopedReadStateSupplier;
     this.derivedStateTable = builder.derivedStateTable;
+    this.mapStateViaMultimapState = builder.mapStateViaMultimapState;
 
     if (this.isSystemTable) {
       Preconditions.checkState(derivedStateTable == null);
@@ -103,30 +97,39 @@ final class CachingStateTable extends StateTable {
 
       @Override
       public <T> SetState<T> bindSet(StateTag<SetState<T>> spec, Coder<T> 
elemCoder) {
+        StateTag<MapState<T, Boolean>> internalMapAddress = 
StateTags.convertToMapTagInternal(spec);
         WindmillSet<T> result =
-            new WindmillSet<>(namespace, spec, stateFamily, elemCoder, cache, 
isNewKey);
+            new WindmillSet<>(bindMap(internalMapAddress, elemCoder, 
BooleanCoder.of()));
         result.initializeForWorkItem(reader, scopedReadStateSupplier);
         return result;
       }
 
       @Override
-      public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
+      public <KeyT, ValueT> AbstractWindmillMap<KeyT, ValueT> bindMap(
           StateTag<MapState<KeyT, ValueT>> spec, Coder<KeyT> keyCoder, 
Coder<ValueT> valueCoder) {
-        WindmillMap<KeyT, ValueT> result =
-            cache
-                .get(namespace, spec)
-                .map(mapState -> (WindmillMap<KeyT, ValueT>) mapState)
-                .orElseGet(
-                    () ->
-                        new WindmillMap<>(
-                            namespace, spec, stateFamily, keyCoder, 
valueCoder, isNewKey));
-
+        AbstractWindmillMap<KeyT, ValueT> result;
+        if (mapStateViaMultimapState) {
+          StateTag<MultimapState<KeyT, ValueT>> internalMultimapAddress =
+              StateTags.convertToMultiMapTagInternal(spec);
+          result =
+              new WindmillMapViaMultimap<>(
+                  bindMultimap(internalMultimapAddress, keyCoder, valueCoder));
+        } else {
+          result =
+              cache
+                  .get(namespace, spec)
+                  .map(mapState -> (AbstractWindmillMap<KeyT, ValueT>) 
mapState)
+                  .orElseGet(
+                      () ->
+                          new WindmillMap<>(
+                              namespace, spec, stateFamily, keyCoder, 
valueCoder, isNewKey));
+        }
         result.initializeForWorkItem(reader, scopedReadStateSupplier);
         return result;
       }
 
       @Override
-      public <KeyT, ValueT> MultimapState<KeyT, ValueT> bindMultimap(
+      public <KeyT, ValueT> WindmillMultimap<KeyT, ValueT> bindMultimap(
           StateTag<MultimapState<KeyT, ValueT>> spec,
           Coder<KeyT> keyCoder,
           Coder<ValueT> valueCoder) {
@@ -246,6 +249,7 @@ final class CachingStateTable extends StateTable {
     private final boolean isNewKey;
     private boolean isSystemTable;
     private @Nullable StateTable derivedStateTable;
+    private boolean mapStateViaMultimapState = false;
 
     private Builder(
         String stateFamily,
@@ -268,6 +272,11 @@ final class CachingStateTable extends StateTable {
       return this;
     }
 
+    Builder withMapStateViaMultimapState() {
+      this.mapStateViaMultimapState = true;
+      return this;
+    }
+
     CachingStateTable build() {
       return new CachingStateTable(this);
     }
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java
index 9f027af0a87..aed03f33e6d 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java
@@ -21,10 +21,7 @@ import static 
org.apache.beam.runners.dataflow.worker.windmill.state.WindmillSta
 
 import java.io.Closeable;
 import java.io.IOException;
-import java.util.AbstractMap;
-import java.util.Collections;
-import java.util.Map;
-import java.util.Set;
+import java.util.*;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.function.Function;
@@ -40,6 +37,8 @@ import org.apache.beam.sdk.util.ByteStringOutputStream;
 import org.apache.beam.sdk.util.Weighted;
 import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps;
 import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
@@ -51,7 +50,7 @@ import 
org.checkerframework.checker.nullness.qual.UnknownKeyFor;
 @SuppressWarnings({
   "nullness" // TODO(https://github.com/apache/beam/issues/20497)
 })
-public class WindmillMap<K, V> extends SimpleWindmillState implements 
MapState<K, V> {
+public class WindmillMap<K, V> extends AbstractWindmillMap<K, V> {
   private final StateNamespace namespace;
   private final StateTag<MapState<K, V>> address;
   private final ByteString stateKeyPrefix;
@@ -327,7 +326,7 @@ public class WindmillMap<K, V> extends SimpleWindmillState 
implements MapState<K
     @Override
     public Iterable<Map.Entry<K, V>> read() {
       if (complete) {
-        return Iterables.unmodifiableIterable(cachedValues.entrySet());
+        return ImmutableMap.copyOf(cachedValues).entrySet();
       }
       Future<Iterable<Map.Entry<ByteString, V>>> persistedData = getFuture();
       try (Closeable scope = scopedReadState()) {
@@ -352,20 +351,22 @@ public class WindmillMap<K, V> extends 
SimpleWindmillState implements MapState<K
                 cachedValues.putIfAbsent(e.getKey(), e.getValue());
               });
           complete = true;
-          return Iterables.unmodifiableIterable(cachedValues.entrySet());
+          return ImmutableMap.copyOf(cachedValues).entrySet();
         } else {
+          ImmutableMap<K, V> cachedCopy = ImmutableMap.copyOf(cachedValues);
+          ImmutableSet<K> removalCopy = ImmutableSet.copyOf(localRemovals);
           // This means that the result might be too large to cache, so don't 
add it to the
           // local cache. Instead merge the iterables, giving priority to any 
local additions
-          // (represented in cachedValued and localRemovals) that may not have 
been committed
+          // (represented in cachedCopy and removalCopy) that may not have 
been committed
           // yet.
           return Iterables.unmodifiableIterable(
               Iterables.concat(
-                  cachedValues.entrySet(),
+                  cachedCopy.entrySet(),
                   Iterables.filter(
                       transformedData,
                       e ->
-                          !cachedValues.containsKey(e.getKey())
-                              && !localRemovals.contains(e.getKey()))));
+                          !cachedCopy.containsKey(e.getKey())
+                              && !removalCopy.contains(e.getKey()))));
         }
 
       } catch (InterruptedException | ExecutionException | IOException e) {
@@ -428,7 +429,6 @@ public class WindmillMap<K, V> extends SimpleWindmillState 
implements MapState<K
           negativeCache.add(key);
           return defaultValue;
         }
-        // TODO: Don't do this if it was already in cache.
         cachedValues.put(key, persistedValue);
         return persistedValue;
       } catch (InterruptedException | ExecutionException | IOException e) {
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMapViaMultimap.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMapViaMultimap.java
new file mode 100644
index 00000000000..0ee508a53ba
--- /dev/null
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMapViaMultimap.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.runners.dataflow.worker.windmill.state;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.function.Function;
+import javax.annotation.Nullable;
+import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
+import org.apache.beam.sdk.state.ReadableState;
+import org.apache.beam.sdk.state.ReadableStates;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables;
+import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators;
+
+public class WindmillMapViaMultimap<KeyT, ValueT> extends 
AbstractWindmillMap<KeyT, ValueT> {
+  final WindmillMultimap<KeyT, ValueT> multimap;
+
+  WindmillMapViaMultimap(WindmillMultimap<KeyT, ValueT> multimap) {
+    this.multimap = multimap;
+  }
+
+  @Override
+  protected Windmill.WorkItemCommitRequest 
persistDirectly(WindmillStateCache.ForKeyAndFamily cache)
+      throws IOException {
+    return multimap.persistDirectly(cache);
+  }
+
+  @Override
+  void initializeForWorkItem(
+      WindmillStateReader reader, Supplier<Closeable> scopedReadStateSupplier) 
{
+    super.initializeForWorkItem(reader, scopedReadStateSupplier);
+    multimap.initializeForWorkItem(reader, scopedReadStateSupplier);
+  }
+
+  @Override
+  void cleanupAfterWorkItem() {
+    super.cleanupAfterWorkItem();
+    multimap.cleanupAfterWorkItem();
+  }
+
+  @Override
+  public void put(KeyT key, ValueT value) {
+    multimap.remove(key);
+    multimap.put(key, value);
+  }
+
+  @Override
+  public ReadableState<ValueT> computeIfAbsent(
+      KeyT key, Function<? super KeyT, ? extends ValueT> mappingFunction) {
+    // Note that the computeIfAbsent documentation indicates the read is 
lazy, but this
+    // matches the existing eager behavior of WindmillMap.
+    Iterable<ValueT> existingValues = multimap.get(key).read();
+    if (Iterables.isEmpty(existingValues)) {
+      ValueT inserted = mappingFunction.apply(key);
+      multimap.put(key, inserted);
+      return ReadableStates.immediate(inserted);
+    } else {
+      return 
ReadableStates.immediate(Iterables.getOnlyElement(existingValues));
+    }
+  }
+
+  @Override
+  public void remove(KeyT key) {
+    multimap.remove(key);
+  }
+
+  private static class SingleValueIterableAdaptor<T> implements 
ReadableState<T> {
+    final ReadableState<Iterable<T>> wrapped;
+    final @Nullable T defaultValue;
+
+    SingleValueIterableAdaptor(ReadableState<Iterable<T>> wrapped, @Nullable T 
defaultValue) {
+      this.wrapped = wrapped;
+      this.defaultValue = defaultValue;
+    }
+
+    @Override
+    public T read() {
+      Iterator<T> iterator = wrapped.read().iterator();
+      if (!iterator.hasNext()) {
+        return null;
+      }
+      return Iterators.getOnlyElement(iterator);
+    }
+
+    @Override
+    public ReadableState<T> readLater() {
+      wrapped.readLater();
+      return this;
+    }
+  }
+
+  @Override
+  public ReadableState<ValueT> get(KeyT key) {
+    return getOrDefault(key, null);
+  }
+
+  @Override
+  public ReadableState<ValueT> getOrDefault(KeyT key, @Nullable ValueT 
defaultValue) {
+    return new SingleValueIterableAdaptor<>(multimap.get(key), defaultValue);
+  }
+
+  @Override
+  public ReadableState<Iterable<KeyT>> keys() {
+    return multimap.keys();
+  }
+
+  private static class RemoveKeyAdaptor<K, V> implements 
ReadableState<Iterable<V>> {
+    final ReadableState<Iterable<Map.Entry<K, V>>> wrapped;
+
+    RemoveKeyAdaptor(ReadableState<Iterable<Map.Entry<K, V>>> wrapped) {
+      this.wrapped = wrapped;
+    }
+
+    @Override
+    public Iterable<V> read() {
+      return Iterables.transform(wrapped.read(), Map.Entry::getValue);
+    }
+
+    @Override
+    public ReadableState<Iterable<V>> readLater() {
+      wrapped.readLater();
+      return this;
+    }
+  }
+
+  @Override
+  public ReadableState<Iterable<ValueT>> values() {
+    return new RemoveKeyAdaptor<>(multimap.entries());
+  }
+
+  @Override
+  public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> entries() {
+    return multimap.entries();
+  }
+
+  @Override
+  public ReadableState<Boolean> isEmpty() {
+    return multimap.isEmpty();
+  }
+
+  @Override
+  public void clear() {
+    multimap.clear();
+  }
+}
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java
index 75f33e69e0b..19c79a497d4 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java
@@ -216,8 +216,8 @@ public class WindmillMultimap<K, V> extends 
SimpleWindmillState implements Multi
     if (keyState == null || keyState.existence == 
KeyExistence.KNOWN_NONEXISTENT) {
       return;
     }
-    if (keyState.valuesCached && keyState.valuesSize == 0) {
-      // no data in windmill, deleting from local cache is sufficient.
+    if (keyState.valuesCached && keyState.valuesSize == 0 && 
!keyState.removedLocally) {
+      // no data in windmill and no need to keep state, deleting from local 
cache is sufficient.
       keyStateMap.remove(structuralKey);
     } else {
       // there may be data in windmill that need to be removed.
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java
index 4afb879e722..ee7e6862c7a 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java
@@ -20,13 +20,7 @@ package 
org.apache.beam.runners.dataflow.worker.windmill.state;
 import java.io.Closeable;
 import java.io.IOException;
 import java.util.Optional;
-import org.apache.beam.runners.core.StateNamespace;
-import org.apache.beam.runners.core.StateTag;
-import org.apache.beam.runners.core.StateTags;
 import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
-import org.apache.beam.sdk.coders.BooleanCoder;
-import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.state.MapState;
 import org.apache.beam.sdk.state.ReadableState;
 import org.apache.beam.sdk.state.SetState;
 import 
org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier;
@@ -35,30 +29,10 @@ import org.checkerframework.checker.nullness.qual.NonNull;
 import org.checkerframework.checker.nullness.qual.UnknownKeyFor;
 
 public class WindmillSet<K> extends SimpleWindmillState implements SetState<K> 
{
-  private final WindmillMap<K, Boolean> windmillMap;
-
-  WindmillSet(
-      StateNamespace namespace,
-      StateTag<SetState<K>> address,
-      String stateFamily,
-      Coder<K> keyCoder,
-      WindmillStateCache.ForKeyAndFamily cache,
-      boolean isNewKey) {
-    StateTag<MapState<K, Boolean>> internalMapAddress = 
StateTags.convertToMapTagInternal(address);
-
-    this.windmillMap =
-        cache
-            .get(namespace, internalMapAddress)
-            .map(map -> (WindmillMap<K, Boolean>) map)
-            .orElseGet(
-                () ->
-                    new WindmillMap<>(
-                        namespace,
-                        internalMapAddress,
-                        stateFamily,
-                        keyCoder,
-                        BooleanCoder.of(),
-                        isNewKey));
+  private final AbstractWindmillMap<K, Boolean> windmillMap;
+
+  WindmillSet(AbstractWindmillMap<K, Boolean> windmillMap) {
+    this.windmillMap = windmillMap;
   }
 
   @Override
@@ -117,11 +91,13 @@ public class WindmillSet<K> extends SimpleWindmillState 
implements SetState<K> {
   @Override
   void initializeForWorkItem(
       WindmillStateReader reader, Supplier<Closeable> scopedReadStateSupplier) 
{
+    super.initializeForWorkItem(reader, scopedReadStateSupplier);
     windmillMap.initializeForWorkItem(reader, scopedReadStateSupplier);
   }
 
   @Override
   void cleanupAfterWorkItem() {
+    super.cleanupAfterWorkItem();
     windmillMap.cleanupAfterWorkItem();
   }
 
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
index c6c49134bcb..64eb9dd941b 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java
@@ -17,6 +17,7 @@
  */
 package org.apache.beam.runners.dataflow.worker.windmill.state;
 
+import com.google.auto.value.AutoBuilder;
 import java.io.IOException;
 import java.io.PrintWriter;
 import java.util.HashMap;
@@ -29,9 +30,7 @@ import javax.servlet.http.HttpServletResponse;
 import org.apache.beam.runners.core.StateNamespace;
 import org.apache.beam.runners.core.StateTag;
 import org.apache.beam.runners.core.StateTags;
-import org.apache.beam.runners.dataflow.worker.StreamingDataflowWorker;
-import org.apache.beam.runners.dataflow.worker.Weighers;
-import org.apache.beam.runners.dataflow.worker.WindmillComputationKey;
+import org.apache.beam.runners.dataflow.worker.*;
 import org.apache.beam.runners.dataflow.worker.status.BaseStatusServlet;
 import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider;
 import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey;
@@ -76,26 +75,33 @@ public class WindmillStateCache implements 
StatusDataProvider {
   // entries inaccessible. They will be evicted through normal cache operation.
   private final ConcurrentMap<WindmillComputationKey, ForKey> keyIndex;
   private final long workerCacheBytes; // Copy workerCacheMb and convert to 
bytes.
+  private final boolean supportMapViaMultimap;
 
-  private WindmillStateCache(
-      long workerCacheMb,
-      ConcurrentMap<WindmillComputationKey, ForKey> keyIndex,
-      Cache<StateId, StateCacheEntry> stateCache) {
-    this.workerCacheBytes = workerCacheMb * MEGABYTES;
-    this.stateCache = stateCache;
-    this.keyIndex = keyIndex;
-  }
-
-  public static WindmillStateCache ofSizeMbs(long workerCacheMb) {
-    return new WindmillStateCache(
-        workerCacheMb,
-        new 
MapMaker().weakValues().concurrencyLevel(STATE_CACHE_CONCURRENCY_LEVEL).makeMap(),
+  WindmillStateCache(long sizeMb, boolean supportMapViaMultimap) {
+    this.workerCacheBytes = sizeMb * MEGABYTES;
+    this.stateCache =
         CacheBuilder.newBuilder()
-            .maximumWeight(workerCacheMb * MEGABYTES)
+            .maximumWeight(workerCacheBytes)
             .recordStats()
             .weigher(Weighers.weightedKeysAndValues())
             .concurrencyLevel(STATE_CACHE_CONCURRENCY_LEVEL)
-            .build());
+            .build();
+    this.keyIndex =
+        new 
MapMaker().weakValues().concurrencyLevel(STATE_CACHE_CONCURRENCY_LEVEL).makeMap();
+    this.supportMapViaMultimap = supportMapViaMultimap;
+  }
+
+  @AutoBuilder(ofClass = WindmillStateCache.class)
+  public interface Builder {
+    Builder setSizeMb(long sizeMb);
+
+    Builder setSupportMapViaMultimap(boolean supportMapViaMultimap);
+
+    WindmillStateCache build();
+  }
+
+  public static Builder builder() {
+    return new 
AutoBuilder_WindmillStateCache_Builder().setSupportMapViaMultimap(false);
   }
 
   private EntryStats calculateEntryStats() {
@@ -399,6 +405,10 @@ public class WindmillStateCache implements 
StatusDataProvider {
       return stateFamily;
     }
 
+    public boolean supportMapStateViaMultimapState() {
+      return supportMapViaMultimap;
+    }
+
     public <T extends State> Optional<T> get(StateNamespace namespace, 
StateTag<T> address) {
       @SuppressWarnings("nullness")
       // the mapping function for localCache.computeIfAbsent (i.e 
stateCache.getIfPresent) is
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java
 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java
index c900228e86b..f757db991fa 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java
@@ -66,13 +66,13 @@ public class WindmillStateInternals<K> implements 
StateInternals {
     this.key = key;
     this.cache = cache;
     this.scopedReadStateSupplier = scopedReadStateSupplier;
-    this.workItemDerivedState =
-        CachingStateTable.builder(stateFamily, reader, cache, isNewKey, 
scopedReadStateSupplier)
-            .build();
-    this.workItemState =
-        CachingStateTable.builder(stateFamily, reader, cache, isNewKey, 
scopedReadStateSupplier)
-            .withDerivedState(workItemDerivedState)
-            .build();
+    CachingStateTable.Builder builder =
+        CachingStateTable.builder(stateFamily, reader, cache, isNewKey, 
scopedReadStateSupplier);
+    if (cache.supportMapStateViaMultimapState()) {
+      builder = builder.withMapStateViaMultimapState();
+    }
+    this.workItemDerivedState = builder.build();
+    this.workItemState = 
builder.withDerivedState(workItemDerivedState).build();
   }
 
   @Override
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
index 2193f20f3fe..6c46bda5acf 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java
@@ -112,7 +112,10 @@ public class StreamingModeExecutionContextTest {
             COMPUTATION_ID,
             new ReaderCache(Duration.standardMinutes(1), 
Executors.newCachedThreadPool()),
             stateNameMap,
-            
WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()).forComputation("comp"),
+            WindmillStateCache.builder()
+                .setSizeMb(options.getWorkerCacheMb())
+                .build()
+                .forComputation("comp"),
             StreamingStepMetricsContainer.createRegistry(),
             new DataflowExecutionStateTracker(
                 ExecutionStateSampler.newForTest(),
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java
index 17da531d452..8708b9f502d 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java
@@ -66,8 +66,8 @@ public class WindmillStateTestUtils {
 
           boolean accessible = f.isAccessible();
           try {
-            f.setAccessible(true);
             path.add(thisClazz.getName() + "#" + f.getName());
+            f.setAccessible(true);
             assertNoReference(f.get(obj), clazz, path, visited);
           } finally {
             path.remove(path.size() - 1);
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
index 9f97c9835dd..5d8ebd53400 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java
@@ -964,7 +964,10 @@ public class WorkerCustomSourcesTest {
             COMPUTATION_ID,
             new ReaderCache(Duration.standardMinutes(1), Runnable::run),
             /*stateNameMap=*/ ImmutableMap.of(),
-            
WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()).forComputation(COMPUTATION_ID),
+            WindmillStateCache.builder()
+                .setSizeMb(options.getWorkerCacheMb())
+                .build()
+                .forComputation(COMPUTATION_ID),
             StreamingStepMetricsContainer.createRegistry(),
             new DataflowExecutionStateTracker(
                 ExecutionStateSampler.newForTest(),
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java
index 446a34f73de..ce8da106b0c 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java
@@ -148,7 +148,7 @@ public class WindmillStateCacheTest {
   @Before
   public void setUp() {
     options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class);
-    cache = WindmillStateCache.ofSizeMbs(400);
+    cache = WindmillStateCache.builder().setSizeMb(400).build();
     assertEquals(0, cache.getWeight());
   }
 
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java
index a53240d6453..33e47623cd0 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java
@@ -20,11 +20,7 @@ package 
org.apache.beam.runners.dataflow.worker.windmill.state;
 import static 
org.apache.beam.runners.dataflow.worker.DataflowMatchers.ByteStringMatcher.byteStringEq;
 import static org.apache.beam.sdk.testing.SystemNanoTimeSleeper.sleepMillis;
 import static org.hamcrest.MatcherAssert.assertThat;
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNull;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.*;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.Matchers.eq;
@@ -130,7 +126,9 @@ public class WindmillStateInternalsTest {
   @Mock private WindmillStateReader mockReader;
   private WindmillStateInternals<String> underTest;
   private WindmillStateInternals<String> underTestNewKey;
+  private WindmillStateInternals<String> underTestMapViaMultimap;
   private WindmillStateCache cache;
+  private WindmillStateCache cacheViaMultimap;
   @Mock private Supplier<Closeable> readStateSupplier;
 
   private static ByteString key(StateNamespace namespace, String addrId) {
@@ -206,7 +204,12 @@ public class WindmillStateInternalsTest {
   public void setUp() {
     MockitoAnnotations.initMocks(this);
     options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class);
-    cache = WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb());
+    cache = 
WindmillStateCache.builder().setSizeMb(options.getWorkerCacheMb()).build();
+    cacheViaMultimap =
+        WindmillStateCache.builder()
+            .setSizeMb(options.getWorkerCacheMb())
+            .setSupportMapViaMultimap(true)
+            .build();
     resetUnderTest();
   }
 
@@ -242,6 +245,21 @@ public class WindmillStateInternalsTest {
                     workToken)
                 .forFamily(STATE_FAMILY),
             readStateSupplier);
+    underTestMapViaMultimap =
+        new WindmillStateInternals<String>(
+            "dummyNewKey",
+            STATE_FAMILY,
+            mockReader,
+            false,
+            cacheViaMultimap
+                .forComputation("comp")
+                .forKey(
+                    WindmillComputationKey.create(
+                        "comp", ByteString.copyFrom("dummyNewKey", 
Charsets.UTF_8), 123),
+                    17L,
+                    workToken)
+                .forFamily(STATE_FAMILY),
+            readStateSupplier);
   }
 
   @After
@@ -249,6 +267,7 @@ public class WindmillStateInternalsTest {
     // Make sure no WindmillStateReader (a per-WorkItem object) escapes into 
the cache
     // (a global object).
     WindmillStateTestUtils.assertNoReference(cache, WindmillStateReader.class);
+    WindmillStateTestUtils.assertNoReference(cacheViaMultimap, 
WindmillStateReader.class);
   }
 
   private <T> void waitAndSet(final SettableFuture<T> future, final T value, 
final long millis) {
@@ -741,6 +760,38 @@ public class WindmillStateInternalsTest {
     assertThat(result.read(), Matchers.containsInAnyOrder(1, 2, 3));
   }
 
+  @Test
+  public void testMapViaMultimapGet() {
+    final String tag = "map";
+    StateTag<MapState<byte[], Integer>> addr =
+        StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MapState<byte[], Integer> mapViaMultiMapState = 
underTestMapViaMultimap.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+    SettableFuture<Iterable<Integer>> future1 = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key1, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(future1);
+    SettableFuture<Iterable<Integer>> future2 = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key2, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(future2);
+
+    ReadableState<Integer> result1 = 
mapViaMultiMapState.get(dup(key1)).readLater();
+    ReadableState<Integer> result2 = 
mapViaMultiMapState.get(dup(key2)).readLater();
+    waitAndSet(future1, Collections.singletonList(1), 30);
+    waitAndSet(future2, Collections.emptyList(), 1);
+    assertEquals(Integer.valueOf(1), result1.read());
+    assertNull(result2.read());
+  }
+
   @Test
   public void testMultimapPutAndGet() {
     final String tag = "multimap";
@@ -761,6 +812,41 @@ public class WindmillStateInternalsTest {
     ReadableState<Iterable<Integer>> result = 
multimapState.get(dup(key)).readLater();
     waitAndSet(future, Arrays.asList(1, 2, 3), 30);
     assertThat(result.read(), Matchers.containsInAnyOrder(1, 1, 2, 3));
+
+    multimapState.remove(key);
+    multimapState.put(key, 4);
+    multimapState.remove(key);
+    multimapState.put(key, 5);
+    assertThat(result.read(), Matchers.containsInAnyOrder(5));
+    multimapState.clear();
+    assertThat(multimapState.get(key).read(), Matchers.emptyIterable());
+  }
+
+  @Test
+  public void testMapViaMultimapPutAndGet() {
+    final String tag = "map";
+    StateTag<MapState<byte[], Integer>> addr =
+        StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MapState<byte[], Integer> mapViaMultiMapState = 
underTestMapViaMultimap.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+    SettableFuture<Iterable<Integer>> future = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(future);
+
+    mapViaMultiMapState.put(key, 1);
+    ReadableState<Integer> result = 
mapViaMultiMapState.get(dup(key)).readLater();
+    waitAndSet(future, Collections.singletonList(2), 30);
+    assertEquals(Integer.valueOf(1), result.read());
+
+    mapViaMultiMapState.put(key, 3);
+    assertEquals(Integer.valueOf(3), mapViaMultiMapState.get(key).read());
+    mapViaMultiMapState.clear();
+    assertNull(mapViaMultiMapState.get(key).read());
   }
 
   @Test
@@ -791,6 +877,33 @@ public class WindmillStateInternalsTest {
     assertThat(result2.read(), Matchers.emptyIterable());
   }
 
+  @Test
+  public void testMapViaMultimapRemoveAndGet() {
+    final String tag = "map";
+    StateTag<MapState<byte[], Integer>> addr =
+        StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MapState<byte[], Integer> mapViaMultiMapState = 
underTestMapViaMultimap.state(NAMESPACE, addr);
+
+    final byte[] key = "key".getBytes(StandardCharsets.UTF_8);
+    SettableFuture<Iterable<Integer>> future = SettableFuture.create();
+    when(mockReader.multimapFetchSingleEntryFuture(
+            encodeWithCoder(key, ByteArrayCoder.of()),
+            key(NAMESPACE, tag),
+            STATE_FAMILY,
+            VarIntCoder.of()))
+        .thenReturn(future);
+
+    ReadableState<Integer> result1 = mapViaMultiMapState.get(key).readLater();
+    ReadableState<Integer> result2 = 
mapViaMultiMapState.get(dup(key)).readLater();
+    waitAndSet(future, Collections.singletonList(1), 30);
+
+    assertEquals(Integer.valueOf(1), result1.read());
+
+    mapViaMultiMapState.remove(key);
+    assertNull(mapViaMultiMapState.get(dup(key)).read());
+    assertNull(result2.read());
+  }
+
   @Test
   public void testMultimapRemoveThenPut() {
     final String tag = "multimap";
@@ -1030,6 +1143,64 @@ public class WindmillStateInternalsTest {
     assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3));
   }
 
+  @Test
+  public void testMapViaMultimapEntriesAndKeysMergeLocalAddRemoveClear() {
+    final String tag = "map";
+    StateTag<MapState<byte[], Integer>> addr =
+        StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MapState<byte[], Integer> mapState = 
underTestMapViaMultimap.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+    final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8);
+    final byte[] key4 = "key4".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> 
entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFuture);
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> 
keysFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(keysFuture);
+
+    ReadableState<Iterable<Map.Entry<byte[], Integer>>> entriesResult =
+        mapState.entries().readLater();
+    ReadableState<Iterable<byte[]>> keysResult = mapState.keys().readLater();
+    waitAndSet(entriesFuture, Arrays.asList(multimapEntry(key1, 3), 
multimapEntry(key2, 4)), 30);
+    waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), 
multimapEntry(key2)), 30);
+
+    mapState.put(key1, 7);
+    mapState.put(dup(key3), 8);
+    mapState.put(key4, 1);
+    mapState.remove(key4);
+
+    Iterable<Map.Entry<byte[], Integer>> entries = entriesResult.read();
+    assertEquals(3, Iterables.size(entries));
+    assertThat(
+        entries,
+        Matchers.containsInAnyOrder(
+            multimapEntryMatcher(key1, 7),
+            multimapEntryMatcher(key2, 4),
+            multimapEntryMatcher(key3, 8)));
+
+    Iterable<byte[]> keys = keysResult.read();
+    assertEquals(3, Iterables.size(keys));
+    assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3));
+    assertFalse(mapState.isEmpty().read());
+
+    mapState.clear();
+    assertTrue(mapState.isEmpty().read());
+    assertTrue(Iterables.isEmpty(mapState.keys().read()));
+    assertTrue(Iterables.isEmpty(mapState.entries().read()));
+
+    // Previously read iterable should still have the same result.
+    assertEquals(3, Iterables.size(keys));
+    assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3));
+  }
+
   @Test
   public void testMultimapEntriesAndKeysMergeLocalRemove() {
     final String tag = "multimap";
@@ -1080,6 +1251,48 @@ public class WindmillStateInternalsTest {
     assertThat(keys, Matchers.containsInAnyOrder(key2, key3));
   }
 
+  @Test
+  public void testMapViaMultimapEntriesAndKeysMergeLocalRemove() {
+    final String tag = "map";
+    StateTag<MapState<byte[], Integer>> addr =
+        StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of());
+    MapState<byte[], Integer> mapState = 
underTestMapViaMultimap.state(NAMESPACE, addr);
+
+    final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8);
+    final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8);
+    final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8);
+
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> 
entriesFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(entriesFuture);
+    SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> 
keysFuture =
+        SettableFuture.create();
+    when(mockReader.multimapFetchAllFuture(
+            true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of()))
+        .thenReturn(keysFuture);
+
+    ReadableState<Iterable<Map.Entry<byte[], Integer>>> entriesResult =
+        mapState.entries().readLater();
+    ReadableState<Iterable<byte[]>> keysResult = mapState.keys().readLater();
+    waitAndSet(entriesFuture, Arrays.asList(multimapEntry(key1, 1), 
multimapEntry(key2, 2)), 30);
+    waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), 
multimapEntry(key2)), 30);
+
+    mapState.remove(dup(key1));
+    mapState.put(key2, 8);
+    mapState.put(dup(key3), 9);
+
+    Iterable<Map.Entry<byte[], Integer>> entries = entriesResult.read();
+    assertEquals(2, Iterables.size(entries));
+    assertThat(
+        entries,
+        Matchers.containsInAnyOrder(multimapEntryMatcher(key2, 8), 
multimapEntryMatcher(key3, 9)));
+
+    Iterable<byte[]> keys = keysResult.read();
+    assertThat(keys, Matchers.containsInAnyOrder(key2, key3));
+  }
+
   @Test
   public void testMultimapCacheComplete() {
     final String tag = "multimap";
diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java
 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java
index 175c8421ff8..13019116767 100644
--- 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java
+++ 
b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java
@@ -207,7 +207,7 @@ public class DispatchedActiveWorkRefresherTest {
     int stuckCommitDurationMillis = 100;
     Table<ComputationState, ExecutableWork, WindmillStateCache.ForComputation> 
computations =
         HashBasedTable.create();
-    WindmillStateCache stateCache = WindmillStateCache.ofSizeMbs(100);
+    WindmillStateCache stateCache = 
WindmillStateCache.builder().setSizeMb(100).build();
     ByteString key = ByteString.EMPTY;
     for (int i = 0; i < 5; i++) {
       WindmillStateCache.ForComputation perComputationStateCache =
diff --git 
a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java 
b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java
index 942881522cf..df5084ad092 100644
--- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java
+++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java
@@ -377,6 +377,25 @@ public class StateSpecs {
     }
   }
 
+  /**
+   * <b><i>For internal use only; no backwards-compatibility 
guarantees.</i></b>
+   *
+   * <p>Convert a map state spec to a multimap state spec.
+   */
+  @Internal
+  public static <KeyT, ValueT> StateSpec<MultimapState<KeyT, ValueT>> 
convertToMultimapSpecInternal(
+      StateSpec<MapState<KeyT, ValueT>> spec) {
+    if (spec instanceof MapStateSpec) {
+      // Checked above; conversion to a multimap spec depends on the provided spec 
being one of those
+      // created via the factory methods in this class.
+      @SuppressWarnings("unchecked")
+      MapStateSpec<KeyT, ValueT> typedSpec = (MapStateSpec<KeyT, ValueT>) spec;
+      return typedSpec.asMultimapSpec();
+    } else {
+      throw new IllegalArgumentException("Unexpected StateSpec " + spec);
+    }
+  }
+
   /**
    * A specification for a state cell holding a settable value of type {@code 
T}.
    *
@@ -768,6 +787,10 @@ public class StateSpecs {
     public int hashCode() {
       return Objects.hash(getClass(), keyCoder, valueCoder);
     }
+
+    private MultimapStateSpec<K, V> asMultimapSpec() {
+      return new MultimapStateSpec<>(this.keyCoder, this.valueCoder);
+    }
   }
 
   private static class MultimapStateSpec<K, V> implements 
StateSpec<MultimapState<K, V>> {
diff --git 
a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java 
b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
index 89dcafbdf94..fb2321328b3 100644
--- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
+++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java
@@ -2709,19 +2709,26 @@ public class ParDoTest implements Serializable {
                 @StateId(countStateId) CombiningState<Integer, int[], Integer> 
count,
                 OutputReceiver<KV<String, Integer>> r) {
               KV<String, Integer> value = element.getValue();
-              ReadableState<Iterable<Entry<String, Integer>>> entriesView = 
state.entries();
               state.put(value.getKey(), value.getValue());
               count.add(1);
+
+              @Nullable Integer max = state.get("max").read();
+              state.put("max", Math.max(max == null ? 0 : max, 
value.getValue()));
               if (count.read() >= 4) {
-                Iterable<Map.Entry<String, Integer>> iterate = 
state.entries().read();
+                assertEquals(Integer.valueOf(97), state.get("a").read());
+
+                Iterable<Map.Entry<String, Integer>> entriesView = 
state.entries().read();
+                Iterable<String> keysView = state.keys().read();
                 // Make sure that the cached Iterable doesn't change when new 
elements are added,
                 // but that cached ReadableState views of the state do change.
                 state.put("BadKey", -1);
-                assertEquals(3, Iterables.size(iterate));
-                assertEquals(4, Iterables.size(entriesView.read()));
-                assertEquals(4, Iterables.size(state.entries().read()));
+                assertEquals(4, Iterables.size(entriesView));
+                assertEquals(4, Iterables.size(keysView));
+                assertEquals(5, Iterables.size(state.entries().read()));
+                assertEquals(5, Iterables.size(state.keys().read()));
+                assertEquals(Integer.valueOf(97), state.get("max").read());
 
-                for (Map.Entry<String, Integer> entry : iterate) {
+                for (Map.Entry<String, Integer> entry : entriesView) {
                   r.output(KV.of(entry.getKey(), entry.getValue()));
                 }
               }
@@ -2732,11 +2739,14 @@ public class ParDoTest implements Serializable {
           pipeline
               .apply(
                   Create.of(
-                      KV.of("hello", KV.of("a", 97)), KV.of("hello", 
KV.of("b", 42)),
-                      KV.of("hello", KV.of("b", 42)), KV.of("hello", 
KV.of("c", 12))))
+                      KV.of("hello", KV.of("a", 97)),
+                      KV.of("hello", KV.of("b", 42)),
+                      KV.of("hello", KV.of("b", 42)),
+                      KV.of("hello", KV.of("c", 12))))
               .apply(ParDo.of(fn));
 
-      PAssert.that(output).containsInAnyOrder(KV.of("a", 97), KV.of("b", 42), 
KV.of("c", 12));
+      PAssert.that(output)
+          .containsInAnyOrder(KV.of("a", 97), KV.of("b", 42), KV.of("c", 12), 
KV.of("max", 97));
       pipeline.run();
     }
 

Reply via email to