This is an automated email from the ASF dual-hosted git repository. scwhittle pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new c08afeae60d Enable MapState and SetState for dataflow streaming engine pipelines with legacy runner by building on top of MultimapState. (#31453) c08afeae60d is described below commit c08afeae60dfb1a15a0f4c8669085662a847249f Author: Sam Whittle <scwhit...@users.noreply.github.com> AuthorDate: Thu Jul 4 22:22:21 2024 +0200 Enable MapState and SetState for dataflow streaming engine pipelines with legacy runner by building on top of MultimapState. (#31453) --- CHANGES.md | 1 + .../org/apache/beam/runners/core/StateTags.java | 8 + .../beam/runners/dataflow/DataflowRunner.java | 35 +--- .../beam/runners/dataflow/DataflowRunnerTest.java | 59 ------ .../dataflow/worker/StreamingDataflowWorker.java | 11 +- .../worker/windmill/state/AbstractWindmillMap.java | 23 +++ .../worker/windmill/state/CachingStateTable.java | 53 +++-- .../worker/windmill/state/WindmillMap.java | 24 +-- .../windmill/state/WindmillMapViaMultimap.java | 164 +++++++++++++++ .../worker/windmill/state/WindmillMultimap.java | 4 +- .../worker/windmill/state/WindmillSet.java | 36 +--- .../worker/windmill/state/WindmillStateCache.java | 46 +++-- .../windmill/state/WindmillStateInternals.java | 14 +- .../worker/StreamingModeExecutionContextTest.java | 5 +- .../dataflow/worker/WindmillStateTestUtils.java | 2 +- .../dataflow/worker/WorkerCustomSourcesTest.java | 5 +- .../windmill/state/WindmillStateCacheTest.java | 2 +- .../windmill/state/WindmillStateInternalsTest.java | 225 ++++++++++++++++++++- .../refresh/DispatchedActiveWorkRefresherTest.java | 2 +- .../java/org/apache/beam/sdk/state/StateSpecs.java | 23 +++ .../org/apache/beam/sdk/transforms/ParDoTest.java | 28 ++- 21 files changed, 573 insertions(+), 197 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 38fa6e44b73..0a620038f11 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -68,6 +68,7 @@ * Multiple RunInference instances can now share the same model instance by setting the 
model_identifier parameter (Python) ([#31665](https://github.com/apache/beam/issues/31665)). * Removed a 3rd party LGPL dependency from the Go SDK ([#31765](https://github.com/apache/beam/issues/31765)). +* Support for MapState and SetState when using Dataflow Runner v1 with Streaming Engine (Java) ([[#18200](https://github.com/apache/beam/issues/18200)]) ## Breaking Changes diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java index 7ffb10c85c0..6ed7f8525fd 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java @@ -257,6 +257,14 @@ public class StateTags { new StructuredId(setTag.getId()), StateSpecs.convertToMapSpecInternal(setTag.getSpec())); } + public static <KeyT, ValueT> StateTag<MultimapState<KeyT, ValueT>> convertToMultiMapTagInternal( + StateTag<MapState<KeyT, ValueT>> mapTag) { + StateSpec<MapState<KeyT, ValueT>> spec = mapTag.getSpec(); + StateSpec<MultimapState<KeyT, ValueT>> multimapSpec = + StateSpecs.convertToMultimapSpecInternal(spec); + return new SimpleStateTag<>(new StructuredId(mapTag.getId()), multimapSpec); + } + private static class StructuredId implements Serializable { private final StateKind kind; private final String rawId; diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index de566599bf8..708c6341326 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -2564,11 +2564,6 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { || hasExperiment(options, 
"use_portable_job_submission"); } - static boolean useStreamingEngine(DataflowPipelineOptions options) { - return hasExperiment(options, GcpOptions.STREAMING_ENGINE_EXPERIMENT) - || hasExperiment(options, GcpOptions.WINDMILL_SERVICE_EXPERIMENT); - } - static void verifyDoFnSupported( DoFn<?, ?> fn, boolean streaming, DataflowPipelineOptions options) { if (!streaming && DoFnSignatures.usesMultimapState(fn)) { @@ -2583,8 +2578,6 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { "%s does not currently support @RequiresTimeSortedInput in streaming mode.", DataflowRunner.class.getSimpleName())); } - - boolean streamingEngine = useStreamingEngine(options); boolean isUnifiedWorker = useUnifiedWorker(options); if (DoFnSignatures.usesMultimapState(fn) && isUnifiedWorker) { @@ -2593,25 +2586,17 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { "%s does not currently support %s running using streaming on unified worker", DataflowRunner.class.getSimpleName(), MultimapState.class.getSimpleName())); } - if (DoFnSignatures.usesSetState(fn)) { - if (streaming && (isUnifiedWorker || streamingEngine)) { - throw new UnsupportedOperationException( - String.format( - "%s does not currently support %s when using %s", - DataflowRunner.class.getSimpleName(), - SetState.class.getSimpleName(), - isUnifiedWorker ? 
"streaming on unified worker" : "streaming engine")); - } + if (DoFnSignatures.usesSetState(fn) && streaming && isUnifiedWorker) { + throw new UnsupportedOperationException( + String.format( + "%s does not currently support %s when using streaming on unified worker", + DataflowRunner.class.getSimpleName(), SetState.class.getSimpleName())); } - if (DoFnSignatures.usesMapState(fn)) { - if (streaming && (isUnifiedWorker || streamingEngine)) { - throw new UnsupportedOperationException( - String.format( - "%s does not currently support %s when using %s", - DataflowRunner.class.getSimpleName(), - MapState.class.getSimpleName(), - isUnifiedWorker ? "streaming on unified worker" : "streaming engine")); - } + if (DoFnSignatures.usesMapState(fn) && streaming && isUnifiedWorker) { + throw new UnsupportedOperationException( + String.format( + "%s does not currently support %s when using streaming on unified worker", + DataflowRunner.class.getSimpleName(), MapState.class.getSimpleName())); } if (DoFnSignatures.usesBundleFinalizer(fn) && !isUnifiedWorker) { throw new UnsupportedOperationException( diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java index 55bfc44ee62..cf1066e41d2 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java @@ -131,8 +131,6 @@ import org.apache.beam.sdk.runners.AppliedPTransform; import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.runners.TransformHierarchy.Node; -import org.apache.beam.sdk.state.MapState; -import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.state.StateSpec; import 
org.apache.beam.sdk.state.StateSpecs; import org.apache.beam.sdk.state.ValueState; @@ -1880,63 +1878,6 @@ public class DataflowRunnerTest implements Serializable { } } - private void verifyMapStateUnsupported(PipelineOptions options) throws Exception { - Pipeline p = Pipeline.create(options); - p.apply(Create.of(KV.of(13, 42))) - .apply( - ParDo.of( - new DoFn<KV<Integer, Integer>, Void>() { - - @StateId("fizzle") - private final StateSpec<MapState<Void, Void>> voidState = StateSpecs.map(); - - @ProcessElement - public void process() {} - })); - - thrown.expectMessage("MapState"); - thrown.expect(UnsupportedOperationException.class); - p.run(); - } - - @Test - public void testMapStateUnsupportedStreamingEngine() throws Exception { - PipelineOptions options = buildPipelineOptions(); - ExperimentalOptions.addExperiment( - options.as(ExperimentalOptions.class), GcpOptions.STREAMING_ENGINE_EXPERIMENT); - options.as(DataflowPipelineOptions.class).setStreaming(true); - - verifyMapStateUnsupported(options); - } - - private void verifySetStateUnsupported(PipelineOptions options) throws Exception { - Pipeline p = Pipeline.create(options); - p.apply(Create.of(KV.of(13, 42))) - .apply( - ParDo.of( - new DoFn<KV<Integer, Integer>, Void>() { - - @StateId("fizzle") - private final StateSpec<SetState<Void>> voidState = StateSpecs.set(); - - @ProcessElement - public void process() {} - })); - - thrown.expectMessage("SetState"); - thrown.expect(UnsupportedOperationException.class); - p.run(); - } - - @Test - public void testSetStateUnsupportedStreamingEngine() throws Exception { - PipelineOptions options = buildPipelineOptions(); - ExperimentalOptions.addExperiment( - options.as(ExperimentalOptions.class), GcpOptions.STREAMING_ENGINE_EXPERIMENT); - options.as(DataflowPipelineOptions.class).setStreaming(true); - verifySetStateUnsupported(options); - } - /** Records all the composite transforms visited within the Pipeline. 
*/ private static class CompositeTransformRecorder extends PipelineVisitor.Defaults { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 59819db88a0..0e46e7e4687 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -324,7 +324,10 @@ public class StreamingDataflowWorker { BoundedQueueExecutor workExecutor = createWorkUnitExecutor(options); AtomicInteger maxWorkItemCommitBytes = new AtomicInteger(Integer.MAX_VALUE); WindmillStateCache windmillStateCache = - WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()); + WindmillStateCache.builder() + .setSizeMb(options.getWorkerCacheMb()) + .setSupportMapViaMultimap(options.isEnableStreamingEngine()) + .build(); Function<String, ScheduledExecutorService> executorSupplier = threadName -> Executors.newSingleThreadScheduledExecutor( @@ -478,7 +481,11 @@ public class StreamingDataflowWorker { ConcurrentMap<String, StageInfo> stageInfo = new ConcurrentHashMap<>(); AtomicInteger maxWorkItemCommitBytes = new AtomicInteger(maxWorkItemCommitBytesOverrides); BoundedQueueExecutor workExecutor = createWorkUnitExecutor(options); - WindmillStateCache stateCache = WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()); + WindmillStateCache stateCache = + WindmillStateCache.builder() + .setSizeMb(options.getWorkerCacheMb()) + .setSupportMapViaMultimap(options.isEnableStreamingEngine()) + .build(); ComputationConfig.Fetcher configFetcher = options.isEnableStreamingEngine() ? 
StreamingEngineComputationConfigFetcher.forTesting( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/AbstractWindmillMap.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/AbstractWindmillMap.java new file mode 100644 index 00000000000..e144d5cf8c3 --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/AbstractWindmillMap.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.runners.dataflow.worker.windmill.state; + +import org.apache.beam.sdk.state.MapState; + +public abstract class AbstractWindmillMap<K, V> extends SimpleWindmillState + implements MapState<K, V> {} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java index bcaf8bf21a2..c026aac4f96 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/CachingStateTable.java @@ -24,17 +24,9 @@ import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateTable; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.StateTags; +import org.apache.beam.sdk.coders.BooleanCoder; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.state.BagState; -import org.apache.beam.sdk.state.CombiningState; -import org.apache.beam.sdk.state.MapState; -import org.apache.beam.sdk.state.MultimapState; -import org.apache.beam.sdk.state.OrderedListState; -import org.apache.beam.sdk.state.SetState; -import org.apache.beam.sdk.state.State; -import org.apache.beam.sdk.state.StateContext; -import org.apache.beam.sdk.state.ValueState; -import org.apache.beam.sdk.state.WatermarkHoldState; +import org.apache.beam.sdk.state.*; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.CombineWithContext; import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; @@ -50,6 +42,7 @@ final class CachingStateTable extends StateTable { private final Supplier<Closeable> scopedReadStateSupplier; private final @Nullable StateTable derivedStateTable; private 
final boolean isNewKey; + private final boolean mapStateViaMultimapState; private CachingStateTable(Builder builder) { this.stateFamily = builder.stateFamily; @@ -59,6 +52,7 @@ final class CachingStateTable extends StateTable { this.isNewKey = builder.isNewKey; this.scopedReadStateSupplier = builder.scopedReadStateSupplier; this.derivedStateTable = builder.derivedStateTable; + this.mapStateViaMultimapState = builder.mapStateViaMultimapState; if (this.isSystemTable) { Preconditions.checkState(derivedStateTable == null); @@ -103,30 +97,39 @@ final class CachingStateTable extends StateTable { @Override public <T> SetState<T> bindSet(StateTag<SetState<T>> spec, Coder<T> elemCoder) { + StateTag<MapState<T, Boolean>> internalMapAddress = StateTags.convertToMapTagInternal(spec); WindmillSet<T> result = - new WindmillSet<>(namespace, spec, stateFamily, elemCoder, cache, isNewKey); + new WindmillSet<>(bindMap(internalMapAddress, elemCoder, BooleanCoder.of())); result.initializeForWorkItem(reader, scopedReadStateSupplier); return result; } @Override - public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + public <KeyT, ValueT> AbstractWindmillMap<KeyT, ValueT> bindMap( StateTag<MapState<KeyT, ValueT>> spec, Coder<KeyT> keyCoder, Coder<ValueT> valueCoder) { - WindmillMap<KeyT, ValueT> result = - cache - .get(namespace, spec) - .map(mapState -> (WindmillMap<KeyT, ValueT>) mapState) - .orElseGet( - () -> - new WindmillMap<>( - namespace, spec, stateFamily, keyCoder, valueCoder, isNewKey)); - + AbstractWindmillMap<KeyT, ValueT> result; + if (mapStateViaMultimapState) { + StateTag<MultimapState<KeyT, ValueT>> internalMultimapAddress = + StateTags.convertToMultiMapTagInternal(spec); + result = + new WindmillMapViaMultimap<>( + bindMultimap(internalMultimapAddress, keyCoder, valueCoder)); + } else { + result = + cache + .get(namespace, spec) + .map(mapState -> (AbstractWindmillMap<KeyT, ValueT>) mapState) + .orElseGet( + () -> + new WindmillMap<>( + namespace, spec, stateFamily, 
keyCoder, valueCoder, isNewKey)); + } result.initializeForWorkItem(reader, scopedReadStateSupplier); return result; } @Override - public <KeyT, ValueT> MultimapState<KeyT, ValueT> bindMultimap( + public <KeyT, ValueT> WindmillMultimap<KeyT, ValueT> bindMultimap( StateTag<MultimapState<KeyT, ValueT>> spec, Coder<KeyT> keyCoder, Coder<ValueT> valueCoder) { @@ -246,6 +249,7 @@ final class CachingStateTable extends StateTable { private final boolean isNewKey; private boolean isSystemTable; private @Nullable StateTable derivedStateTable; + private boolean mapStateViaMultimapState = false; private Builder( String stateFamily, @@ -268,6 +272,11 @@ final class CachingStateTable extends StateTable { return this; } + Builder withMapStateViaMultimapState() { + this.mapStateViaMultimapState = true; + return this; + } + CachingStateTable build() { return new CachingStateTable(this); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java index 9f027af0a87..aed03f33e6d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMap.java @@ -21,10 +21,7 @@ import static org.apache.beam.runners.dataflow.worker.windmill.state.WindmillSta import java.io.Closeable; import java.io.IOException; -import java.util.AbstractMap; -import java.util.Collections; -import java.util.Map; -import java.util.Set; +import java.util.*; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.function.Function; @@ -40,6 +37,8 @@ import org.apache.beam.sdk.util.ByteStringOutputStream; import org.apache.beam.sdk.util.Weighted; import 
org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; @@ -51,7 +50,7 @@ import org.checkerframework.checker.nullness.qual.UnknownKeyFor; @SuppressWarnings({ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) -public class WindmillMap<K, V> extends SimpleWindmillState implements MapState<K, V> { +public class WindmillMap<K, V> extends AbstractWindmillMap<K, V> { private final StateNamespace namespace; private final StateTag<MapState<K, V>> address; private final ByteString stateKeyPrefix; @@ -327,7 +326,7 @@ public class WindmillMap<K, V> extends SimpleWindmillState implements MapState<K @Override public Iterable<Map.Entry<K, V>> read() { if (complete) { - return Iterables.unmodifiableIterable(cachedValues.entrySet()); + return ImmutableMap.copyOf(cachedValues).entrySet(); } Future<Iterable<Map.Entry<ByteString, V>>> persistedData = getFuture(); try (Closeable scope = scopedReadState()) { @@ -352,20 +351,22 @@ public class WindmillMap<K, V> extends SimpleWindmillState implements MapState<K cachedValues.putIfAbsent(e.getKey(), e.getValue()); }); complete = true; - return Iterables.unmodifiableIterable(cachedValues.entrySet()); + return ImmutableMap.copyOf(cachedValues).entrySet(); } else { + ImmutableMap<K, V> cachedCopy = ImmutableMap.copyOf(cachedValues); + ImmutableSet<K> removalCopy = ImmutableSet.copyOf(localRemovals); // This means that the result might be too large to cache, so don't add it to the // local cache. 
Instead merge the iterables, giving priority to any local additions - // (represented in cachedValued and localRemovals) that may not have been committed + // (represented in cachedCopy and removalCopy) that may not have been committed // yet. return Iterables.unmodifiableIterable( Iterables.concat( - cachedValues.entrySet(), + cachedCopy.entrySet(), Iterables.filter( transformedData, e -> - !cachedValues.containsKey(e.getKey()) - && !localRemovals.contains(e.getKey())))); + !cachedCopy.containsKey(e.getKey()) + && !removalCopy.contains(e.getKey())))); } } catch (InterruptedException | ExecutionException | IOException e) { @@ -428,7 +429,6 @@ public class WindmillMap<K, V> extends SimpleWindmillState implements MapState<K negativeCache.add(key); return defaultValue; } - // TODO: Don't do this if it was already in cache. cachedValues.put(key, persistedValue); return persistedValue; } catch (InterruptedException | ExecutionException | IOException e) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMapViaMultimap.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMapViaMultimap.java new file mode 100644 index 00000000000..0ee508a53ba --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMapViaMultimap.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.windmill.state; + +import java.io.Closeable; +import java.io.IOException; +import java.util.Iterator; +import java.util.Map; +import java.util.function.Function; +import javax.annotation.Nullable; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.ReadableStates; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterators; + +public class WindmillMapViaMultimap<KeyT, ValueT> extends AbstractWindmillMap<KeyT, ValueT> { + final WindmillMultimap<KeyT, ValueT> multimap; + + WindmillMapViaMultimap(WindmillMultimap<KeyT, ValueT> multimap) { + this.multimap = multimap; + } + + @Override + protected Windmill.WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache) + throws IOException { + return multimap.persistDirectly(cache); + } + + @Override + void initializeForWorkItem( + WindmillStateReader reader, Supplier<Closeable> scopedReadStateSupplier) { + super.initializeForWorkItem(reader, scopedReadStateSupplier); + multimap.initializeForWorkItem(reader, scopedReadStateSupplier); + } + + @Override + void cleanupAfterWorkItem() { + super.cleanupAfterWorkItem(); + multimap.cleanupAfterWorkItem(); + } + + @Override + public void put(KeyT key, ValueT value) { + multimap.remove(key); + multimap.put(key, 
value); + } + + @Override + public ReadableState<ValueT> computeIfAbsent( + KeyT key, Function<? super KeyT, ? extends ValueT> mappingFunction) { + // Note that computeIfAbsent comments indicate that the read is lazy but this matches the + // existing eager + // behavior of WindmillMap. + Iterable<ValueT> existingValues = multimap.get(key).read(); + if (Iterables.isEmpty(existingValues)) { + ValueT inserted = mappingFunction.apply(key); + multimap.put(key, inserted); + return ReadableStates.immediate(inserted); + } else { + return ReadableStates.immediate(Iterables.getOnlyElement(existingValues)); + } + } + + @Override + public void remove(KeyT key) { + multimap.remove(key); + } + + private static class SingleValueIterableAdaptor<T> implements ReadableState<T> { + final ReadableState<Iterable<T>> wrapped; + final @Nullable T defaultValue; + + SingleValueIterableAdaptor(ReadableState<Iterable<T>> wrapped, @Nullable T defaultValue) { + this.wrapped = wrapped; + this.defaultValue = defaultValue; + } + + @Override + public T read() { + Iterator<T> iterator = wrapped.read().iterator(); + if (!iterator.hasNext()) { + return null; + } + return Iterators.getOnlyElement(iterator); + } + + @Override + public ReadableState<T> readLater() { + wrapped.readLater(); + return this; + } + } + + @Override + public ReadableState<ValueT> get(KeyT key) { + return getOrDefault(key, null); + } + + @Override + public ReadableState<ValueT> getOrDefault(KeyT key, @Nullable ValueT defaultValue) { + return new SingleValueIterableAdaptor<>(multimap.get(key), defaultValue); + } + + @Override + public ReadableState<Iterable<KeyT>> keys() { + return multimap.keys(); + } + + private static class RemoveKeyAdaptor<K, V> implements ReadableState<Iterable<V>> { + final ReadableState<Iterable<Map.Entry<K, V>>> wrapped; + + RemoveKeyAdaptor(ReadableState<Iterable<Map.Entry<K, V>>> wrapped) { + this.wrapped = wrapped; + } + + @Override + public Iterable<V> read() { + return 
Iterables.transform(wrapped.read(), Map.Entry::getValue); + } + + @Override + public ReadableState<Iterable<V>> readLater() { + wrapped.readLater(); + return this; + } + } + + @Override + public ReadableState<Iterable<ValueT>> values() { + return new RemoveKeyAdaptor<>(multimap.entries()); + } + + @Override + public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> entries() { + return multimap.entries(); + } + + @Override + public ReadableState<Boolean> isEmpty() { + return multimap.isEmpty(); + } + + @Override + public void clear() { + multimap.clear(); + } +} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java index 75f33e69e0b..19c79a497d4 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillMultimap.java @@ -216,8 +216,8 @@ public class WindmillMultimap<K, V> extends SimpleWindmillState implements Multi if (keyState == null || keyState.existence == KeyExistence.KNOWN_NONEXISTENT) { return; } - if (keyState.valuesCached && keyState.valuesSize == 0) { - // no data in windmill, deleting from local cache is sufficient. + if (keyState.valuesCached && keyState.valuesSize == 0 && !keyState.removedLocally) { + // no data in windmill and no need to keep state, deleting from local cache is sufficient. keyStateMap.remove(structuralKey); } else { // there may be data in windmill that need to be removed. 
diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java index 4afb879e722..ee7e6862c7a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillSet.java @@ -20,13 +20,7 @@ package org.apache.beam.runners.dataflow.worker.windmill.state; import java.io.Closeable; import java.io.IOException; import java.util.Optional; -import org.apache.beam.runners.core.StateNamespace; -import org.apache.beam.runners.core.StateTag; -import org.apache.beam.runners.core.StateTags; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.sdk.coders.BooleanCoder; -import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.state.MapState; import org.apache.beam.sdk.state.ReadableState; import org.apache.beam.sdk.state.SetState; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; @@ -35,30 +29,10 @@ import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.UnknownKeyFor; public class WindmillSet<K> extends SimpleWindmillState implements SetState<K> { - private final WindmillMap<K, Boolean> windmillMap; - - WindmillSet( - StateNamespace namespace, - StateTag<SetState<K>> address, - String stateFamily, - Coder<K> keyCoder, - WindmillStateCache.ForKeyAndFamily cache, - boolean isNewKey) { - StateTag<MapState<K, Boolean>> internalMapAddress = StateTags.convertToMapTagInternal(address); - - this.windmillMap = - cache - .get(namespace, internalMapAddress) - .map(map -> (WindmillMap<K, Boolean>) map) - .orElseGet( - () -> - new WindmillMap<>( - namespace, - 
internalMapAddress, - stateFamily, - keyCoder, - BooleanCoder.of(), - isNewKey)); + private final AbstractWindmillMap<K, Boolean> windmillMap; + + WindmillSet(AbstractWindmillMap<K, Boolean> windmillMap) { + this.windmillMap = windmillMap; } @Override @@ -117,11 +91,13 @@ public class WindmillSet<K> extends SimpleWindmillState implements SetState<K> { @Override void initializeForWorkItem( WindmillStateReader reader, Supplier<Closeable> scopedReadStateSupplier) { + super.initializeForWorkItem(reader, scopedReadStateSupplier); windmillMap.initializeForWorkItem(reader, scopedReadStateSupplier); } @Override void cleanupAfterWorkItem() { + super.cleanupAfterWorkItem(); windmillMap.cleanupAfterWorkItem(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java index c6c49134bcb..64eb9dd941b 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCache.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.state; +import com.google.auto.value.AutoBuilder; import java.io.IOException; import java.io.PrintWriter; import java.util.HashMap; @@ -29,9 +30,7 @@ import javax.servlet.http.HttpServletResponse; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.StateTags; -import org.apache.beam.runners.dataflow.worker.StreamingDataflowWorker; -import org.apache.beam.runners.dataflow.worker.Weighers; -import org.apache.beam.runners.dataflow.worker.WindmillComputationKey; +import org.apache.beam.runners.dataflow.worker.*; import 
org.apache.beam.runners.dataflow.worker.status.BaseStatusServlet; import org.apache.beam.runners.dataflow.worker.status.StatusDataProvider; import org.apache.beam.runners.dataflow.worker.streaming.ShardedKey; @@ -76,26 +75,33 @@ public class WindmillStateCache implements StatusDataProvider { // entries inaccessible. They will be evicted through normal cache operation. private final ConcurrentMap<WindmillComputationKey, ForKey> keyIndex; private final long workerCacheBytes; // Copy workerCacheMb and convert to bytes. + private final boolean supportMapViaMultimap; - private WindmillStateCache( - long workerCacheMb, - ConcurrentMap<WindmillComputationKey, ForKey> keyIndex, - Cache<StateId, StateCacheEntry> stateCache) { - this.workerCacheBytes = workerCacheMb * MEGABYTES; - this.stateCache = stateCache; - this.keyIndex = keyIndex; - } - - public static WindmillStateCache ofSizeMbs(long workerCacheMb) { - return new WindmillStateCache( - workerCacheMb, - new MapMaker().weakValues().concurrencyLevel(STATE_CACHE_CONCURRENCY_LEVEL).makeMap(), + WindmillStateCache(long sizeMb, boolean supportMapViaMultimap) { + this.workerCacheBytes = sizeMb * MEGABYTES; + this.stateCache = CacheBuilder.newBuilder() - .maximumWeight(workerCacheMb * MEGABYTES) + .maximumWeight(workerCacheBytes) .recordStats() .weigher(Weighers.weightedKeysAndValues()) .concurrencyLevel(STATE_CACHE_CONCURRENCY_LEVEL) - .build()); + .build(); + this.keyIndex = + new MapMaker().weakValues().concurrencyLevel(STATE_CACHE_CONCURRENCY_LEVEL).makeMap(); + this.supportMapViaMultimap = supportMapViaMultimap; + } + + @AutoBuilder(ofClass = WindmillStateCache.class) + public interface Builder { + Builder setSizeMb(long sizeMb); + + Builder setSupportMapViaMultimap(boolean supportMapViaMultimap); + + WindmillStateCache build(); + } + + public static Builder builder() { + return new AutoBuilder_WindmillStateCache_Builder().setSupportMapViaMultimap(false); } private EntryStats calculateEntryStats() { @@ -399,6 +405,10 @@ 
public class WindmillStateCache implements StatusDataProvider { return stateFamily; } + public boolean supportMapStateViaMultimapState() { + return supportMapViaMultimap; + } + public <T extends State> Optional<T> get(StateNamespace namespace, StateTag<T> address) { @SuppressWarnings("nullness") // the mapping function for localCache.computeIfAbsent (i.e stateCache.getIfPresent) is diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java index c900228e86b..f757db991fa 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java @@ -66,13 +66,13 @@ public class WindmillStateInternals<K> implements StateInternals { this.key = key; this.cache = cache; this.scopedReadStateSupplier = scopedReadStateSupplier; - this.workItemDerivedState = - CachingStateTable.builder(stateFamily, reader, cache, isNewKey, scopedReadStateSupplier) - .build(); - this.workItemState = - CachingStateTable.builder(stateFamily, reader, cache, isNewKey, scopedReadStateSupplier) - .withDerivedState(workItemDerivedState) - .build(); + CachingStateTable.Builder builder = + CachingStateTable.builder(stateFamily, reader, cache, isNewKey, scopedReadStateSupplier); + if (cache.supportMapStateViaMultimapState()) { + builder = builder.withMapStateViaMultimapState(); + } + this.workItemDerivedState = builder.build(); + this.workItemState = builder.withDerivedState(workItemDerivedState).build(); } @Override diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 2193f20f3fe..6c46bda5acf 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -112,7 +112,10 @@ public class StreamingModeExecutionContextTest { COMPUTATION_ID, new ReaderCache(Duration.standardMinutes(1), Executors.newCachedThreadPool()), stateNameMap, - WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()).forComputation("comp"), + WindmillStateCache.builder() + .setSizeMb(options.getWorkerCacheMb()) + .build() + .forComputation("comp"), StreamingStepMetricsContainer.createRegistry(), new DataflowExecutionStateTracker( ExecutionStateSampler.newForTest(), diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java index 17da531d452..8708b9f502d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateTestUtils.java @@ -66,8 +66,8 @@ public class WindmillStateTestUtils { boolean accessible = f.isAccessible(); try { - f.setAccessible(true); path.add(thisClazz.getName() + "#" + f.getName()); + f.setAccessible(true); assertNoReference(f.get(obj), clazz, path, visited); } finally { path.remove(path.size() - 1); diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index 9f97c9835dd..5d8ebd53400 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -964,7 +964,10 @@ public class WorkerCustomSourcesTest { COMPUTATION_ID, new ReaderCache(Duration.standardMinutes(1), Runnable::run), /*stateNameMap=*/ ImmutableMap.of(), - WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()).forComputation(COMPUTATION_ID), + WindmillStateCache.builder() + .setSizeMb(options.getWorkerCacheMb()) + .build() + .forComputation(COMPUTATION_ID), StreamingStepMetricsContainer.createRegistry(), new DataflowExecutionStateTracker( ExecutionStateSampler.newForTest(), diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java index 446a34f73de..ce8da106b0c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateCacheTest.java @@ -148,7 +148,7 @@ public class WindmillStateCacheTest { @Before public void setUp() { options = PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); - cache = WindmillStateCache.ofSizeMbs(400); + cache = WindmillStateCache.builder().setSizeMb(400).build(); assertEquals(0, cache.getWeight()); } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java index a53240d6453..33e47623cd0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternalsTest.java @@ -20,11 +20,7 @@ package org.apache.beam.runners.dataflow.worker.windmill.state; import static org.apache.beam.runners.dataflow.worker.DataflowMatchers.ByteStringMatcher.byteStringEq; import static org.apache.beam.sdk.testing.SystemNanoTimeSleeper.sleepMillis; import static org.hamcrest.MatcherAssert.assertThat; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; +import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Matchers.eq; @@ -130,7 +126,9 @@ public class WindmillStateInternalsTest { @Mock private WindmillStateReader mockReader; private WindmillStateInternals<String> underTest; private WindmillStateInternals<String> underTestNewKey; + private WindmillStateInternals<String> underTestMapViaMultimap; private WindmillStateCache cache; + private WindmillStateCache cacheViaMultimap; @Mock private Supplier<Closeable> readStateSupplier; private static ByteString key(StateNamespace namespace, String addrId) { @@ -206,7 +204,12 @@ public class WindmillStateInternalsTest { public void setUp() { MockitoAnnotations.initMocks(this); options = 
PipelineOptionsFactory.as(DataflowWorkerHarnessOptions.class); - cache = WindmillStateCache.ofSizeMbs(options.getWorkerCacheMb()); + cache = WindmillStateCache.builder().setSizeMb(options.getWorkerCacheMb()).build(); + cacheViaMultimap = + WindmillStateCache.builder() + .setSizeMb(options.getWorkerCacheMb()) + .setSupportMapViaMultimap(true) + .build(); resetUnderTest(); } @@ -242,6 +245,21 @@ public class WindmillStateInternalsTest { workToken) .forFamily(STATE_FAMILY), readStateSupplier); + underTestMapViaMultimap = + new WindmillStateInternals<String>( + "dummyNewKey", + STATE_FAMILY, + mockReader, + false, + cacheViaMultimap + .forComputation("comp") + .forKey( + WindmillComputationKey.create( + "comp", ByteString.copyFrom("dummyNewKey", Charsets.UTF_8), 123), + 17L, + workToken) + .forFamily(STATE_FAMILY), + readStateSupplier); } @After @@ -249,6 +267,7 @@ public class WindmillStateInternalsTest { // Make sure no WindmillStateReader (a per-WorkItem object) escapes into the cache // (a global object). 
WindmillStateTestUtils.assertNoReference(cache, WindmillStateReader.class); + WindmillStateTestUtils.assertNoReference(cacheViaMultimap, WindmillStateReader.class); } private <T> void waitAndSet(final SettableFuture<T> future, final T value, final long millis) { @@ -741,6 +760,38 @@ public class WindmillStateInternalsTest { assertThat(result.read(), Matchers.containsInAnyOrder(1, 2, 3)); } + @Test + public void testMapViaMultimapGet() { + final String tag = "map"; + StateTag<MapState<byte[], Integer>> addr = + StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MapState<byte[], Integer> mapViaMultiMapState = underTestMapViaMultimap.state(NAMESPACE, addr); + + final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8); + final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8); + SettableFuture<Iterable<Integer>> future1 = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key1, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(future1); + SettableFuture<Iterable<Integer>> future2 = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key2, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(future2); + + ReadableState<Integer> result1 = mapViaMultiMapState.get(dup(key1)).readLater(); + ReadableState<Integer> result2 = mapViaMultiMapState.get(dup(key2)).readLater(); + waitAndSet(future1, Collections.singletonList(1), 30); + waitAndSet(future2, Collections.emptyList(), 1); + assertEquals(Integer.valueOf(1), result1.read()); + assertNull(result2.read()); + } + @Test public void testMultimapPutAndGet() { final String tag = "multimap"; @@ -761,6 +812,41 @@ public class WindmillStateInternalsTest { ReadableState<Iterable<Integer>> result = multimapState.get(dup(key)).readLater(); waitAndSet(future, Arrays.asList(1, 2, 3), 30); assertThat(result.read(), Matchers.containsInAnyOrder(1, 1, 2, 
3)); + + multimapState.remove(key); + multimapState.put(key, 4); + multimapState.remove(key); + multimapState.put(key, 5); + assertThat(result.read(), Matchers.containsInAnyOrder(5)); + multimapState.clear(); + assertThat(multimapState.get(key).read(), Matchers.emptyIterable()); + } + + @Test + public void testMapViaMultimapPutAndGet() { + final String tag = "map"; + StateTag<MapState<byte[], Integer>> addr = + StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MapState<byte[], Integer> mapViaMultiMapState = underTestMapViaMultimap.state(NAMESPACE, addr); + + final byte[] key = "key".getBytes(StandardCharsets.UTF_8); + SettableFuture<Iterable<Integer>> future = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(future); + + mapViaMultiMapState.put(key, 1); + ReadableState<Integer> result = mapViaMultiMapState.get(dup(key)).readLater(); + waitAndSet(future, Collections.singletonList(2), 30); + assertEquals(Integer.valueOf(1), result.read()); + + mapViaMultiMapState.put(key, 3); + assertEquals(Integer.valueOf(3), mapViaMultiMapState.get(key).read()); + mapViaMultiMapState.clear(); + assertNull(mapViaMultiMapState.get(key).read()); } @Test @@ -791,6 +877,33 @@ public class WindmillStateInternalsTest { assertThat(result2.read(), Matchers.emptyIterable()); } + @Test + public void testMapViaMultimapRemoveAndGet() { + final String tag = "map"; + StateTag<MapState<byte[], Integer>> addr = + StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MapState<byte[], Integer> mapViaMultiMapState = underTestMapViaMultimap.state(NAMESPACE, addr); + + final byte[] key = "key".getBytes(StandardCharsets.UTF_8); + SettableFuture<Iterable<Integer>> future = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + 
VarIntCoder.of())) + .thenReturn(future); + + ReadableState<Integer> result1 = mapViaMultiMapState.get(key).readLater(); + ReadableState<Integer> result2 = mapViaMultiMapState.get(dup(key)).readLater(); + waitAndSet(future, Collections.singletonList(1), 30); + + assertEquals(Integer.valueOf(1), result1.read()); + + mapViaMultiMapState.remove(key); + assertNull(mapViaMultiMapState.get(dup(key)).read()); + assertNull(result2.read()); + } + @Test public void testMultimapRemoveThenPut() { final String tag = "multimap"; @@ -1030,6 +1143,64 @@ public class WindmillStateInternalsTest { assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3)); } + @Test + public void testMapViaMultimapEntriesAndKeysMergeLocalAddRemoveClear() { + final String tag = "map"; + StateTag<MapState<byte[], Integer>> addr = + StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MapState<byte[], Integer> mapState = underTestMapViaMultimap.state(NAMESPACE, addr); + + final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8); + final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8); + final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8); + final byte[] key4 = "key4".getBytes(StandardCharsets.UTF_8); + + SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(entriesFuture); + SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> keysFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(keysFuture); + + ReadableState<Iterable<Map.Entry<byte[], Integer>>> entriesResult = + mapState.entries().readLater(); + ReadableState<Iterable<byte[]>> keysResult = mapState.keys().readLater(); + waitAndSet(entriesFuture, Arrays.asList(multimapEntry(key1, 3), multimapEntry(key2, 4)), 30); + 
waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30); + + mapState.put(key1, 7); + mapState.put(dup(key3), 8); + mapState.put(key4, 1); + mapState.remove(key4); + + Iterable<Map.Entry<byte[], Integer>> entries = entriesResult.read(); + assertEquals(3, Iterables.size(entries)); + assertThat( + entries, + Matchers.containsInAnyOrder( + multimapEntryMatcher(key1, 7), + multimapEntryMatcher(key2, 4), + multimapEntryMatcher(key3, 8))); + + Iterable<byte[]> keys = keysResult.read(); + assertEquals(3, Iterables.size(keys)); + assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3)); + assertFalse(mapState.isEmpty().read()); + + mapState.clear(); + assertTrue(mapState.isEmpty().read()); + assertTrue(Iterables.isEmpty(mapState.keys().read())); + assertTrue(Iterables.isEmpty(mapState.entries().read())); + + // Previously read iterable should still have the same result. + assertEquals(3, Iterables.size(keys)); + assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3)); + } + @Test public void testMultimapEntriesAndKeysMergeLocalRemove() { final String tag = "multimap"; @@ -1080,6 +1251,48 @@ public class WindmillStateInternalsTest { assertThat(keys, Matchers.containsInAnyOrder(key2, key3)); } + @Test + public void testMapViaMultimapEntriesAndKeysMergeLocalRemove() { + final String tag = "map"; + StateTag<MapState<byte[], Integer>> addr = + StateTags.map(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MapState<byte[], Integer> mapState = underTestMapViaMultimap.state(NAMESPACE, addr); + + final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8); + final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8); + final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8); + + SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> entriesFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(entriesFuture); + 
SettableFuture<Iterable<Map.Entry<ByteString, Iterable<Integer>>>> keysFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(keysFuture); + + ReadableState<Iterable<Map.Entry<byte[], Integer>>> entriesResult = + mapState.entries().readLater(); + ReadableState<Iterable<byte[]>> keysResult = mapState.keys().readLater(); + waitAndSet(entriesFuture, Arrays.asList(multimapEntry(key1, 1), multimapEntry(key2, 2)), 30); + waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30); + + mapState.remove(dup(key1)); + mapState.put(key2, 8); + mapState.put(dup(key3), 9); + + Iterable<Map.Entry<byte[], Integer>> entries = entriesResult.read(); + assertEquals(2, Iterables.size(entries)); + assertThat( + entries, + Matchers.containsInAnyOrder(multimapEntryMatcher(key2, 8), multimapEntryMatcher(key3, 9))); + + Iterable<byte[]> keys = keysResult.read(); + assertThat(keys, Matchers.containsInAnyOrder(key2, key3)); + } + @Test public void testMultimapCacheComplete() { final String tag = "multimap"; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java index 175c8421ff8..13019116767 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/DispatchedActiveWorkRefresherTest.java @@ -207,7 +207,7 @@ public class DispatchedActiveWorkRefresherTest { int stuckCommitDurationMillis = 100; Table<ComputationState, ExecutableWork, WindmillStateCache.ForComputation> 
computations = HashBasedTable.create(); - WindmillStateCache stateCache = WindmillStateCache.ofSizeMbs(100); + WindmillStateCache stateCache = WindmillStateCache.builder().setSizeMb(100).build(); ByteString key = ByteString.EMPTY; for (int i = 0; i < 5; i++) { WindmillStateCache.ForComputation perComputationStateCache = diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java index 942881522cf..df5084ad092 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/StateSpecs.java @@ -377,6 +377,25 @@ public class StateSpecs { } } + /** + * <b><i>For internal use only; no backwards-compatibility guarantees.</i></b> + * + * <p>Convert a map state spec to a multimap state spec. + */ + @Internal + public static <KeyT, ValueT> StateSpec<MultimapState<KeyT, ValueT>> convertToMultimapSpecInternal( + StateSpec<MapState<KeyT, ValueT>> spec) { + if (spec instanceof MapStateSpec) { + // Checked above; conversion to a map spec depends on the provided spec being one of those + // created via the factory methods in this class. + @SuppressWarnings("unchecked") + MapStateSpec<KeyT, ValueT> typedSpec = (MapStateSpec<KeyT, ValueT>) spec; + return typedSpec.asMultimapSpec(); + } else { + throw new IllegalArgumentException("Unexpected StateSpec " + spec); + } + } + /** * A specification for a state cell holding a settable value of type {@code T}.
* @@ -768,6 +787,10 @@ public class StateSpecs { public int hashCode() { return Objects.hash(getClass(), keyCoder, valueCoder); } + + private MultimapStateSpec<K, V> asMultimapSpec() { + return new MultimapStateSpec<>(this.keyCoder, this.valueCoder); + } } private static class MultimapStateSpec<K, V> implements StateSpec<MultimapState<K, V>> { diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 89dcafbdf94..fb2321328b3 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -2709,19 +2709,26 @@ public class ParDoTest implements Serializable { @StateId(countStateId) CombiningState<Integer, int[], Integer> count, OutputReceiver<KV<String, Integer>> r) { KV<String, Integer> value = element.getValue(); - ReadableState<Iterable<Entry<String, Integer>>> entriesView = state.entries(); state.put(value.getKey(), value.getValue()); count.add(1); + + @Nullable Integer max = state.get("max").read(); + state.put("max", Math.max(max == null ? 0 : max, value.getValue())); if (count.read() >= 4) { - Iterable<Map.Entry<String, Integer>> iterate = state.entries().read(); + assertEquals(Integer.valueOf(97), state.get("a").read()); + + Iterable<Map.Entry<String, Integer>> entriesView = state.entries().read(); + Iterable<String> keysView = state.keys().read(); // Make sure that the cached Iterable doesn't change when new elements are added, // but that cached ReadableState views of the state do change. 
state.put("BadKey", -1); - assertEquals(3, Iterables.size(iterate)); - assertEquals(4, Iterables.size(entriesView.read())); - assertEquals(4, Iterables.size(state.entries().read())); + assertEquals(4, Iterables.size(entriesView)); + assertEquals(4, Iterables.size(keysView)); + assertEquals(5, Iterables.size(state.entries().read())); + assertEquals(5, Iterables.size(state.keys().read())); + assertEquals(Integer.valueOf(97), state.get("max").read()); - for (Map.Entry<String, Integer> entry : iterate) { + for (Map.Entry<String, Integer> entry : entriesView) { r.output(KV.of(entry.getKey(), entry.getValue())); } } @@ -2732,11 +2739,14 @@ public class ParDoTest implements Serializable { pipeline .apply( Create.of( - KV.of("hello", KV.of("a", 97)), KV.of("hello", KV.of("b", 42)), - KV.of("hello", KV.of("b", 42)), KV.of("hello", KV.of("c", 12)))) + KV.of("hello", KV.of("a", 97)), + KV.of("hello", KV.of("b", 42)), + KV.of("hello", KV.of("b", 42)), + KV.of("hello", KV.of("c", 12)))) .apply(ParDo.of(fn)); - PAssert.that(output).containsInAnyOrder(KV.of("a", 97), KV.of("b", 42), KV.of("c", 12)); + PAssert.that(output) + .containsInAnyOrder(KV.of("a", 97), KV.of("b", 42), KV.of("c", 12), KV.of("max", 97)); pipeline.run(); }