This is an automated email from the ASF dual-hosted git repository. altay pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 4c6a827 Add MapState and SetState support new 20108fd Merge pull request #15238 from kileys/beam-12588-multimapstate 4c6a827 is described below commit 4c6a8271679ec003e195baa30f24541caa0bf9f0 Author: kileys <kiley...@google.com> AuthorDate: Tue Jun 29 04:03:26 2021 +0000 Add MapState and SetState support --- .../dataflow/BatchStatefulParDoOverrides.java | 4 +- .../dataflow/DataflowPipelineTranslator.java | 3 +- .../beam/runners/dataflow/DataflowRunner.java | 22 +- .../beam/runners/dataflow/DataflowRunnerTest.java | 17 + .../java/org/apache/beam/sdk/state/MapState.java | 12 - .../java/org/apache/beam/sdk/state/SetState.java | 5 +- .../beam/fn/harness/state/FnApiStateAccessor.java | 239 ++++++++- .../beam/fn/harness/state/MultimapUserState.java | 292 +++++++++++ .../fn/harness/state/MultimapUserStateTest.java | 557 +++++++++++++++++++++ 9 files changed, 1125 insertions(+), 26 deletions(-) diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java index eb16ea2..229fdf6 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/BatchStatefulParDoOverrides.java @@ -161,7 +161,7 @@ public class BatchStatefulParDoOverrides { verifyFnIsStateful(fn); DataflowPipelineOptions options = input.getPipeline().getOptions().as(DataflowPipelineOptions.class); - DataflowRunner.verifyDoFnSupported(fn, false, DataflowRunner.useStreamingEngine(options)); + DataflowRunner.verifyDoFnSupported(fn, false, options); DataflowRunner.verifyStateSupportForWindowingStrategy(input.getWindowingStrategy()); PTransform< @@ -189,7 +189,7 @@ public class BatchStatefulParDoOverrides { 
verifyFnIsStateful(fn); DataflowPipelineOptions options = input.getPipeline().getOptions().as(DataflowPipelineOptions.class); - DataflowRunner.verifyDoFnSupported(fn, false, DataflowRunner.useStreamingEngine(options)); + DataflowRunner.verifyDoFnSupported(fn, false, options); DataflowRunner.verifyStateSupportForWindowingStrategy(input.getWindowingStrategy()); PTransform< diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index 8229a32..2bf9975 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -1247,8 +1247,7 @@ public class DataflowPipelineTranslator { boolean isStateful = DoFnSignatures.isStateful(fn); if (isStateful) { DataflowPipelineOptions options = context.getPipelineOptions(); - DataflowRunner.verifyDoFnSupported( - fn, options.isStreaming(), DataflowRunner.useStreamingEngine(options)); + DataflowRunner.verifyDoFnSupported(fn, options.isStreaming(), options); DataflowRunner.verifyStateSupportForWindowingStrategy(windowingStrategy); } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index c1220b0..9cb8735 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -2293,27 +2293,35 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { || hasExperiment(options, GcpOptions.WINDMILL_SERVICE_EXPERIMENT); } - static void 
verifyDoFnSupported(DoFn<?, ?> fn, boolean streaming, boolean streamingEngine) { + static void verifyDoFnSupported( + DoFn<?, ?> fn, boolean streaming, DataflowPipelineOptions options) { if (streaming && DoFnSignatures.requiresTimeSortedInput(fn)) { throw new UnsupportedOperationException( String.format( "%s does not currently support @RequiresTimeSortedInput in streaming mode.", DataflowRunner.class.getSimpleName())); } + + boolean streamingEngine = useStreamingEngine(options); + boolean isUnifiedWorker = useUnifiedWorker(options); if (DoFnSignatures.usesSetState(fn)) { - if (streaming && streamingEngine) { + if (streaming && (isUnifiedWorker || streamingEngine)) { throw new UnsupportedOperationException( String.format( - "%s does not currently support %s when using streaming engine", - DataflowRunner.class.getSimpleName(), SetState.class.getSimpleName())); + "%s does not currently support %s when using %s", + DataflowRunner.class.getSimpleName(), + SetState.class.getSimpleName(), + isUnifiedWorker ? "streaming on unified worker" : "streaming engine")); } } if (DoFnSignatures.usesMapState(fn)) { - if (streaming && streamingEngine) { + if (streaming && (isUnifiedWorker || streamingEngine)) { throw new UnsupportedOperationException( String.format( - "%s does not currently support %s when using streaming engine", - DataflowRunner.class.getSimpleName(), MapState.class.getSimpleName())); + "%s does not currently support %s when using %s", + DataflowRunner.class.getSimpleName(), + MapState.class.getSimpleName(), + isUnifiedWorker ? 
"streaming on unified worker" : "streaming engine")); } } } diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java index a212fd1..e7df0ba 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/DataflowRunnerTest.java @@ -1539,6 +1539,15 @@ public class DataflowRunnerTest implements Serializable { verifyMapStateUnsupported(options); } + @Test + public void testMapStateUnsupportedStreamingUnifiedRunner() throws Exception { + PipelineOptions options = buildPipelineOptions(); + ExperimentalOptions.addExperiment(options.as(ExperimentalOptions.class), "use_unified_worker"); + options.as(DataflowPipelineOptions.class).setStreaming(true); + + verifyMapStateUnsupported(options); + } + private void verifySetStateUnsupported(PipelineOptions options) throws Exception { Pipeline p = Pipeline.create(options); p.apply(Create.of(KV.of(13, 42))) @@ -1566,6 +1575,14 @@ public class DataflowRunnerTest implements Serializable { verifySetStateUnsupported(options); } + @Test + public void testSetStateUnsupportedStreamingUnifiedWorker() throws Exception { + PipelineOptions options = buildPipelineOptions(); + ExperimentalOptions.addExperiment(options.as(ExperimentalOptions.class), "use_unified_worker"); + options.as(DataflowPipelineOptions.class).setStreaming(true); + verifySetStateUnsupported(options); + } + /** Records all the composite transforms visited within the Pipeline. 
*/ private static class CompositeTransformRecorder extends PipelineVisitor.Defaults { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/MapState.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/MapState.java index 6c05ba8..bbbe6cd 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/MapState.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/MapState.java @@ -56,12 +56,6 @@ public interface MapState<K, V> extends State { * <p>Changes will not be reflected in the results returned by previous calls to {@link * ReadableState#read} on the results of any of the reading methods ({@link #get}, {@link #keys}, * {@link #values}, and {@link #entries}). - * - * <p>Since the condition is not evaluated until {@link ReadableState#read} is called, a call to - * {@link #putIfAbsent} followed by a call to {@link #remove} followed by a read on the - * putIfAbsent return will result in the item being written to the map. Similarly, if there are - * multiple calls to {@link #putIfAbsent} for the same key, precedence will be given to the first - * one on which read is called. */ default ReadableState<V> putIfAbsent(K key, V value) { return computeIfAbsent(key, k -> value); @@ -79,12 +73,6 @@ public interface MapState<K, V> extends State { * <p>Changes will not be reflected in the results returned by previous calls to {@link * ReadableState#read} on the results of any of the reading methods ({@link #get}, {@link #keys}, * {@link #values}, and {@link #entries}). - * - * <p>Since the condition is not evaluated until {@link ReadableState#read} is called, a call to - * {@link #putIfAbsent} followed by a call to {@link #remove} followed by a read on the - * putIfAbsent return will result in the item being written to the map. Similarly, if there are - * multiple calls to {@link #putIfAbsent} for the same key, precedence will be given to the first - * one on which read is called. */ ReadableState<V> computeIfAbsent(K key, Function<? 
super K, ? extends V> mappingFunction); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/SetState.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/SetState.java index 2ca7226..b4b7bf7 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/state/SetState.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/state/SetState.java @@ -30,7 +30,10 @@ import org.apache.beam.sdk.annotations.Experimental.Kind; */ @Experimental(Kind.STATE) public interface SetState<T> extends GroupingState<T, Iterable<T>> { - /** Returns true if this set contains the specified element. */ + /** + * Returns a {@link ReadableState} whose {@link #read} method will return true if this set + * contains the specified element at the point when that {@link #read} call returns. + */ ReadableState<Boolean> contains(T t); /** diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java index 5a931c5..517be05 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java @@ -31,6 +31,7 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey; import org.apache.beam.runners.core.SideInputReader; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.VoidCoder; import org.apache.beam.sdk.function.ThrowingRunnable; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.state.BagState; @@ -38,6 +39,7 @@ import org.apache.beam.sdk.state.CombiningState; import org.apache.beam.sdk.state.MapState; import org.apache.beam.sdk.state.OrderedListState; import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.ReadableStates; import org.apache.beam.sdk.state.SetState; import 
org.apache.beam.sdk.state.StateBinder; import org.apache.beam.sdk.state.StateContext; @@ -54,6 +56,7 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.checkerframework.checker.nullness.qual.Nullable; @@ -324,7 +327,87 @@ public class FnApiStateAccessor<K> implements SideInputReader, StateBinder { @Override public <T> SetState<T> bindSet(String id, StateSpec<SetState<T>> spec, Coder<T> elemCoder) { - throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API."); + return (SetState<T>) + stateKeyObjectCache.computeIfAbsent( + createMultimapUserStateKey(id), + new Function<StateKey, Object>() { + @Override + public Object apply(StateKey key) { + return new SetState<T>() { + private final MultimapUserState<T, Void> impl = + createMultimapUserState(id, elemCoder, VoidCoder.of()); + + @Override + public void clear() { + impl.clear(); + } + + @Override + public ReadableState<Boolean> contains(T t) { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + return !Iterables.isEmpty(impl.get(t)); + } + + @Override + public ReadableState<Boolean> readLater() { + // TODO: Support prefetching. + return this; + } + }; + } + + @Override + public ReadableState<Boolean> addIfAbsent(T t) { + boolean isEmpty = Iterables.isEmpty(impl.get(t)); + if (isEmpty) { + impl.put(t, null); + } + // TODO: Support prefetching. 
+ return ReadableStates.immediate(isEmpty); + } + + @Override + public void remove(T t) { + impl.remove(t); + } + + @Override + public void add(T value) { + impl.remove(value); + impl.put(value, null); + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + return Iterables.isEmpty(impl.keys()); + } + + @Override + public ReadableState<Boolean> readLater() { + // TODO: Support prefetching. + return this; + } + }; + } + + @Override + public Iterable<T> read() { + return impl.keys(); + } + + @Override + public SetState<T> readLater() { + // TODO: Support prefetching. + return this; + } + }; + } + }); } @Override @@ -333,7 +416,133 @@ public class FnApiStateAccessor<K> implements SideInputReader, StateBinder { StateSpec<MapState<KeyT, ValueT>> spec, Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { - throw new UnsupportedOperationException("TODO: Add support for a map state to the Fn API."); + return (MapState<KeyT, ValueT>) + stateKeyObjectCache.computeIfAbsent( + createMultimapUserStateKey(id), + new Function<StateKey, Object>() { + @Override + public Object apply(StateKey key) { + return new MapState<KeyT, ValueT>() { + private final MultimapUserState<KeyT, ValueT> impl = + createMultimapUserState(id, mapKeyCoder, mapValueCoder); + + @Override + public void clear() { + impl.clear(); + } + + @Override + public void put(KeyT key, ValueT value) { + impl.remove(key); + impl.put(key, value); + } + + @Override + public ReadableState<ValueT> computeIfAbsent( + KeyT key, Function<? super KeyT, ? 
extends ValueT> mappingFunction) { + Iterable<ValueT> values = impl.get(key); + if (Iterables.isEmpty(values)) { + impl.put(key, mappingFunction.apply(key)); + } + return ReadableStates.immediate(Iterables.getOnlyElement(values, null)); + } + + @Override + public void remove(KeyT key) { + impl.remove(key); + } + + @Override + public ReadableState<ValueT> get(KeyT key) { + return getOrDefault(key, null); + } + + @Override + public ReadableState<ValueT> getOrDefault( + KeyT key, @Nullable ValueT defaultValue) { + return new ReadableState<ValueT>() { + @Override + public @Nullable ValueT read() { + Iterable<ValueT> values = impl.get(key); + return Iterables.getOnlyElement(values, defaultValue); + } + + @Override + public ReadableState<ValueT> readLater() { + // TODO: Support prefetching. + return this; + } + }; + } + + @Override + public ReadableState<Iterable<KeyT>> keys() { + return new ReadableState<Iterable<KeyT>>() { + @Override + public Iterable<KeyT> read() { + return impl.keys(); + } + + @Override + public ReadableState<Iterable<KeyT>> readLater() { + // TODO: Support prefetching. + return this; + } + }; + } + + @Override + public ReadableState<Iterable<ValueT>> values() { + return new ReadableState<Iterable<ValueT>>() { + @Override + public Iterable<ValueT> read() { + return Iterables.transform(entries().read(), e -> e.getValue()); + } + + @Override + public ReadableState<Iterable<ValueT>> readLater() { + // TODO: Support prefetching. + return this; + } + }; + } + + @Override + public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> entries() { + return new ReadableState<Iterable<Map.Entry<KeyT, ValueT>>>() { + @Override + public Iterable<Map.Entry<KeyT, ValueT>> read() { + Iterable<KeyT> keys = keys().read(); + return Iterables.transform( + keys, key -> Maps.immutableEntry(key, get(key).read())); + } + + @Override + public ReadableState<Iterable<Map.Entry<KeyT, ValueT>>> readLater() { + // TODO: Support prefetching. 
+ return this; + } + }; + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public Boolean read() { + return Iterables.isEmpty(keys().read()); + } + + @Override + public ReadableState<Boolean> readLater() { + // TODO: Support prefetching. + return this; + } + }; + } + }; + } + }); } @Override @@ -481,6 +690,22 @@ public class FnApiStateAccessor<K> implements SideInputReader, StateBinder { throw new UnsupportedOperationException("WatermarkHoldState is unsupported by the Fn API."); } + private <KeyT, ValueT> MultimapUserState<KeyT, ValueT> createMultimapUserState( + String stateId, Coder<KeyT> keyCoder, Coder<ValueT> valueCoder) { + MultimapUserState<KeyT, ValueT> rval = + new MultimapUserState( + beamFnStateClient, + processBundleInstructionId.get(), + ptransformId, + stateId, + encodedCurrentWindowSupplier.get(), + encodedCurrentKeySupplier.get(), + keyCoder, + valueCoder); + stateFinalizers.add(rval::asyncClose); + return rval; + } + private <T> BagUserState<T> createBagUserState(String stateId, Coder<T> valueCoder) { BagUserState<T> rval = new BagUserState<>( @@ -506,6 +731,16 @@ public class FnApiStateAccessor<K> implements SideInputReader, StateBinder { return builder.build(); } + private StateKey createMultimapUserStateKey(String stateId) { + StateKey.Builder builder = StateKey.newBuilder(); + builder + .getMultimapKeysUserStateBuilder() + .setWindow(encodedCurrentWindowSupplier.get()) + .setTransformId(ptransformId) + .setUserStateId(stateId); + return builder.build(); + } + public void finalizeState() { // Persist all dirty state cells try { diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java new file mode 100644 index 0000000..49efa35 --- /dev/null +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/MultimapUserState.java @@ -0,0 +1,292 @@ 
+/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness.state; + +import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkState; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateAppendRequest; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * An implementation of a multimap user state that utilizes the Beam Fn State API to fetch, clear + * and persist values. + * + * <p>Calling {@link #asyncClose()} schedules any required persistence changes. This object should + * no longer be used after it is closed. + * + * <p>TODO: Move to an async persist model where persistence is signalled based upon cache memory + * pressure and its need to flush. + * + * <p>TODO: Support block level caching and prefetch. + */ +public class MultimapUserState<K, V> { + + private final BeamFnStateClient beamFnStateClient; + private final Coder<K> mapKeyCoder; + private final Coder<V> valueCoder; + private final String stateId; + private final StateRequest keysStateRequest; + private final StateRequest userStateRequest; + + private boolean isClosed; + private boolean isCleared; + // Pending updates to persistent storage + private HashSet<K> pendingRemoves = Sets.newHashSet(); + private HashMap<K, List<V>> pendingAdds = Maps.newHashMap(); + // Map keys with no values in persistent storage + private HashSet<K> negativeCache = Sets.newHashSet(); + // Values retrieved from persistent storage + private Multimap<K, V> persistedValues = ArrayListMultimap.create(); + private @Nullable Iterable<K> persistedKeys = null; + + public MultimapUserState( + BeamFnStateClient beamFnStateClient, + String instructionId, + String pTransformId, + String stateId, + ByteString encodedWindow, + ByteString encodedKey, + Coder<K> mapKeyCoder, + Coder<V> valueCoder) { + this.beamFnStateClient = beamFnStateClient; + this.mapKeyCoder = mapKeyCoder; + this.valueCoder = valueCoder; + this.stateId = stateId; + + StateRequest.Builder keysStateRequestBuilder = StateRequest.newBuilder(); + keysStateRequestBuilder + 
.setInstructionId(instructionId) + .getStateKeyBuilder() + .getMultimapKeysUserStateBuilder() + .setTransformId(pTransformId) + .setUserStateId(stateId) + .setKey(encodedKey) + .setWindow(encodedWindow); + keysStateRequest = keysStateRequestBuilder.build(); + + StateRequest.Builder userStateRequestBuilder = StateRequest.newBuilder(); + userStateRequestBuilder + .setInstructionId(instructionId) + .getStateKeyBuilder() + .getMultimapUserStateBuilder() + .setTransformId(pTransformId) + .setUserStateId(stateId) + .setWindow(encodedWindow) + .setKey(encodedKey); + userStateRequest = userStateRequestBuilder.build(); + } + + public void clear() { + checkState( + !isClosed, + "Multimap user state is no longer usable because it is closed for %s", + keysStateRequest.getStateKey()); + + isCleared = true; + persistedValues = ArrayListMultimap.create(); + persistedKeys = null; + pendingRemoves = Sets.newHashSet(); + pendingAdds = Maps.newHashMap(); + negativeCache = Sets.newHashSet(); + } + + /* + * Returns an iterable of the values associated with key in this multimap, if any. + * If there are no values, this returns an empty collection, not null. + */ + public Iterable<V> get(K key) { + checkState( + !isClosed, + "Multimap user state is no longer usable because it is closed for %s", + keysStateRequest.getStateKey()); + + List<V> pendingAddValues = pendingAdds.getOrDefault(key, Collections.emptyList()); + Collection<V> pendingValues = + Collections.unmodifiableCollection(pendingAddValues.subList(0, pendingAddValues.size())); + if (isCleared || pendingRemoves.contains(key)) { + return pendingValues; + } + + Iterable<V> persistedValues = getPersistedValues(key); + return Iterables.concat(persistedValues, pendingValues); + } + + @SuppressWarnings({ + "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-12687) + }) + /* + * Returns an iterable containing all distinct keys in this multimap. 
+ */ + public Iterable<K> keys() { + checkState( + !isClosed, + "Multimap user state is no longer usable because it is closed for %s", + keysStateRequest.getStateKey()); + if (isCleared) { + return Collections.unmodifiableCollection(Lists.newArrayList(pendingAdds.keySet())); + } + + Set<K> keys = Sets.newHashSet(getPersistedKeys()); + keys.removeAll(pendingRemoves); + keys.addAll(pendingAdds.keySet()); + return Collections.unmodifiableCollection(keys); + } + + /* + * Store a key-value pair in the multimap. + * Allows duplicate key-value pairs. + */ + public void put(K key, V value) { + checkState( + !isClosed, + "Multimap user state is no longer usable because it is closed for %s", + keysStateRequest.getStateKey()); + pendingAdds.putIfAbsent(key, new ArrayList<>()); + pendingAdds.get(key).add(value); + } + + /* + * Removes all values for this key in the multimap. + */ + public void remove(K key) { + checkState( + !isClosed, + "Multimap user state is no longer usable because it is closed for %s", + keysStateRequest.getStateKey()); + pendingAdds.remove(key); + if (!isCleared) { + pendingRemoves.add(key); + } + } + + @SuppressWarnings({ + "nullness" // TODO(https://issues.apache.org/jira/browse/BEAM-12687) + }) + // Update data in persistent store + public void asyncClose() throws Exception { + checkState( + !isClosed, + "Multimap user state is no longer usable because it is closed for %s", + keysStateRequest.getStateKey()); + isClosed = true; + // Nothing to persist + if (!isCleared && pendingRemoves.isEmpty() && pendingAdds.isEmpty()) { + return; + } + + // Clear currently persisted key-values + if (isCleared) { + beamFnStateClient.handle( + keysStateRequest.toBuilder().setClear(StateClearRequest.getDefaultInstance()), + new CompletableFuture<>()); + } else if (!pendingRemoves.isEmpty()) { + for (K key : pendingRemoves) { + beamFnStateClient.handle( + createUserStateRequest(key) + .toBuilder() + .setClear(StateClearRequest.getDefaultInstance()), + new 
CompletableFuture<>()); + } + } + + // Persist pending key-values + if (!pendingAdds.isEmpty()) { + for (Map.Entry<K, List<V>> entry : pendingAdds.entrySet()) { + beamFnStateClient.handle( + createUserStateRequest(entry.getKey()) + .toBuilder() + .setAppend(StateAppendRequest.newBuilder().setData(encodeValues(entry.getValue()))), + new CompletableFuture<>()); + } + } + } + + private ByteString encodeValues(Iterable<V> values) { + try { + ByteString.Output output = ByteString.newOutput(); + for (V value : values) { + valueCoder.encode(value, output); + } + return output.toByteString(); + } catch (IOException e) { + throw new IllegalStateException( + String.format("Failed to encode values for multimap user state id %s.", stateId), e); + } + } + + private StateRequest createUserStateRequest(K key) { + try { + ByteString.Output output = ByteString.newOutput(); + mapKeyCoder.encode(key, output); + StateRequest.Builder request = userStateRequest.toBuilder(); + request.getStateKeyBuilder().getMultimapUserStateBuilder().setMapKey(output.toByteString()); + return request.build(); + } catch (IOException e) { + throw new IllegalStateException( + String.format("Failed to encode key for multimap user state id %s.", stateId), e); + } + } + + private Iterable<V> getPersistedValues(K key) { + if (negativeCache.contains(key)) { + return Collections.emptyList(); + } + + if (persistedValues.get(key).isEmpty()) { + Iterable<V> values = + StateFetchingIterators.readAllAndDecodeStartingFrom( + beamFnStateClient, createUserStateRequest(key), valueCoder); + if (Iterables.isEmpty(values)) { + negativeCache.add(key); + } + persistedValues.putAll(key, values); + } + return Iterables.unmodifiableIterable(persistedValues.get(key)); + } + + private Iterable<K> getPersistedKeys() { + checkState(!isCleared); + if (persistedKeys == null) { + Iterable<K> keys = + StateFetchingIterators.readAllAndDecodeStartingFrom( + beamFnStateClient, keysStateRequest, mapKeyCoder); + persistedKeys = 
Iterables.unmodifiableIterable(keys); + } + return persistedKeys; + } +} diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java new file mode 100644 index 0000000..23f1be4 --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/MultimapUserStateTest.java @@ -0,0 +1,557 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.fn.harness.state; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey; +import org.apache.beam.sdk.coders.NullableCoder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MultimapUserStateTest { + + private final String pTransformId = "pTransformId"; + private final String stateId = "stateId"; + private final String encodedKey = "encodedKey"; + private final String encodedWindow = "encodedWindow"; + + @Test + public void testNoPersistedValues() throws Exception { + FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(Collections.emptyMap()); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + assertThat(userState.keys(), is(emptyIterable())); + } + + @Test + public void testGet() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapKeyStateKey(), + encode("A1"), + createMultimapValueStateKey("A1"), + encode("V1", "V2"))); + MultimapUserState<String, String> 
userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + + Iterable<String> initValues = userState.get("A1"); + userState.put("A1", "V3"); + assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues, String.class)); + assertArrayEquals( + new String[] {"V1", "V2", "V3"}, Iterables.toArray(userState.get("A1"), String.class)); + assertArrayEquals(new String[] {}, Iterables.toArray(userState.get("A2"), String.class)); + userState.asyncClose(); + assertThrows(IllegalStateException.class, () -> userState.get("A1")); + } + + @Test + public void testClear() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapKeyStateKey(), + encode("A1"), + createMultimapValueStateKey("A1"), + encode("V1", "V2"))); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + + Iterable<String> initValues = userState.get("A1"); + userState.clear(); + assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues, String.class)); + assertThat(userState.get("A1"), is(emptyIterable())); + assertThat(userState.keys(), is(emptyIterable())); + + userState.put("A1", "V1"); + userState.clear(); + assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues, String.class)); + assertThat(userState.get("A1"), is(emptyIterable())); + assertThat(userState.keys(), is(emptyIterable())); + + userState.asyncClose(); + assertThrows(IllegalStateException.class, () -> userState.clear()); + } + + @Test + public void testKeys() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapKeyStateKey(), + encode("A1"), + 
createMultimapValueStateKey("A1"), + encode("V1", "V2"))); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + + userState.put("A2", "V1"); + Iterable<String> initKeys = userState.keys(); + userState.put("A3", "V1"); + userState.put("A1", "V3"); + assertArrayEquals(new String[] {"A1", "A2"}, Iterables.toArray(initKeys, String.class)); + assertArrayEquals( + new String[] {"A1", "A2", "A3"}, Iterables.toArray(userState.keys(), String.class)); + + userState.clear(); + assertArrayEquals(new String[] {"A1", "A2"}, Iterables.toArray(initKeys, String.class)); + assertArrayEquals(new String[] {}, Iterables.toArray(userState.keys(), String.class)); + userState.asyncClose(); + assertThrows(IllegalStateException.class, () -> userState.keys()); + } + + @Test + public void testPut() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapKeyStateKey(), + encode("A1"), + createMultimapValueStateKey("A1"), + encode("V1", "V2"))); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + + Iterable<String> initValues = userState.get("A1"); + userState.put("A1", "V3"); + assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues, String.class)); + assertArrayEquals( + new String[] {"V1", "V2", "V3"}, Iterables.toArray(userState.get("A1"), String.class)); + userState.asyncClose(); + assertThrows(IllegalStateException.class, () -> userState.put("A1", "V2")); + } + + @Test + public void testPutAfterRemove() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapKeyStateKey(), + encode("A0"), + 
createMultimapValueStateKey("A0"), + encode("V1"))); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + userState.remove("A0"); + userState.put("A0", "V2"); + assertArrayEquals(new String[] {"V2"}, Iterables.toArray(userState.get("A0"), String.class)); + userState.asyncClose(); + Map<StateKey, ByteString> data = fakeClient.getData(); + assertEquals(encode("V2"), data.get(createMultimapValueStateKey("A0"))); + } + + @Test + public void testPutAfterClear() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapKeyStateKey(), + encode("A0"), + createMultimapValueStateKey("A0"), + encode("V1"))); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + userState.clear(); + userState.put("A0", "V2"); + assertArrayEquals(new String[] {"V2"}, Iterables.toArray(userState.get("A0"), String.class)); + } + + @Test + public void testRemoveBeforeClear() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapKeyStateKey(), + encode("A0"), + createMultimapValueStateKey("A0"), + encode("V1"))); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + userState.remove("A0"); + userState.clear(); + userState.asyncClose(); + // Clear takes precedence over specific key remove + assertThat(fakeClient.getCallCount(), is(1)); + } + + @Test + public void testPutBeforeClear() throws Exception { + FakeBeamFnStateClient fakeClient = new 
FakeBeamFnStateClient(Collections.emptyMap()); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + userState.put("A0", "V0"); + userState.put("A1", "V1"); + Iterable<String> values = userState.get("A1"); // fakeClient call = 1 + userState.clear(); // fakeClient call = 2 + assertArrayEquals(new String[] {"V1"}, Iterables.toArray(values, String.class)); + userState.asyncClose(); + // Clear takes precedence over puts + assertThat(fakeClient.getCallCount(), is(2)); + } + + @Test + public void testPutBeforeRemove() throws Exception { + FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(Collections.emptyMap()); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + userState.put("A0", "V0"); + userState.put("A1", "V1"); + Iterable<String> values = userState.get("A1"); // fakeClient call = 1 + userState.remove("A0"); // fakeClient call = 2 + userState.remove("A1"); // fakeClient call = 3 + assertArrayEquals(new String[] {"V1"}, Iterables.toArray(values, String.class)); + userState.asyncClose(); + assertThat(fakeClient.getCallCount(), is(3)); + assertNull(fakeClient.getData().get(createMultimapValueStateKey("A0"))); + assertNull(fakeClient.getData().get(createMultimapValueStateKey("A1"))); + } + + @Test + public void testRemove() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapKeyStateKey(), + encode("A1"), + createMultimapValueStateKey("A1"), + encode("V1", "V2"))); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + 
encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + + Iterable<String> initValues = userState.get("A1"); + userState.put("A1", "V3"); + + userState.remove("A1"); + assertArrayEquals(new String[] {"V1", "V2"}, Iterables.toArray(initValues, String.class)); + assertThat(userState.keys(), is(emptyIterable())); + userState.asyncClose(); + assertThrows(IllegalStateException.class, () -> userState.remove("A1")); + } + + @Test + public void testImmutableKeys() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapKeyStateKey(), + encode("A1"), + createMultimapValueStateKey("A1"), + encode("V1", "V2"))); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + Iterable<String> keys = userState.keys(); + assertThrows( + UnsupportedOperationException.class, () -> Iterables.removeAll(keys, Arrays.asList("A1"))); + } + + @Test + public void testImmutableValues() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapKeyStateKey(), + encode("A1"), + createMultimapValueStateKey("A1"), + encode("V1", "V2"))); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + Iterable<String> values = userState.get("A1"); + assertThrows( + UnsupportedOperationException.class, + () -> Iterables.removeAll(values, Arrays.asList("V1"))); + } + + @Test + public void testClearAsyncClose() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapKeyStateKey(), + encode("A1"), + createMultimapValueStateKey("A1"), + encode("V1", "V2"))); + 
MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + userState.clear(); + userState.asyncClose(); + Map<StateKey, ByteString> data = fakeClient.getData(); + assertEquals(1, data.size()); + assertNull(data.get(createMultimapKeyStateKey())); + } + + @Test + public void testNoopAsyncClose() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapKeyStateKey(), + encode("A1"), + createMultimapValueStateKey("A1"), + encode("V1", "V2"))); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + userState.asyncClose(); + assertThrows(IllegalStateException.class, () -> userState.keys()); + assertEquals(0, fakeClient.getCallCount()); + } + + @Test + public void testAsyncClose() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapKeyStateKey(), + encode("A0", "A1"), + createMultimapValueStateKey("A0"), + encode("V1"), + createMultimapValueStateKey("A1"), + encode("V1", "V2"))); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + userState.remove("A0"); + userState.put("A1", "V3"); + userState.put("A2", "V1"); + userState.put("A3", "V1"); + userState.remove("A3"); + userState.asyncClose(); + Map<StateKey, ByteString> data = fakeClient.getData(); + assertNull(data.get(createMultimapValueStateKey("A0"))); + assertEquals(encode("V1", "V2", "V3"), data.get(createMultimapValueStateKey("A1"))); + assertEquals(encode("V1"), 
data.get(createMultimapValueStateKey("A2"))); + } + + @Test + public void testNullKeysAndValues() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + ImmutableMap.of( + createMultimapKeyStateKey(), + encode("A1"), + createMultimapValueStateKey("A1"), + encode("V1", "V2"))); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + NullableCoder.of(StringUtf8Coder.of()), + NullableCoder.of(StringUtf8Coder.of())); + userState.put(null, null); + userState.put(null, null); + userState.put(null, "V1"); + assertArrayEquals( + new String[] {null, null, "V1"}, Iterables.toArray(userState.get(null), String.class)); + } + + @Test + public void testNegativeCache() throws Exception { + FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(Collections.emptyMap()); + MultimapUserState<String, String> userState = + new MultimapUserState<>( + fakeClient, + "instructionId", + pTransformId, + stateId, + encode(encodedWindow), + encode(encodedKey), + StringUtf8Coder.of(), + StringUtf8Coder.of()); + userState.get("A1"); + userState.get("A1"); + assertThat(fakeClient.getCallCount(), is(1)); + } + + private StateKey createMultimapKeyStateKey() throws IOException { + return StateKey.newBuilder() + .setMultimapKeysUserState( + StateKey.MultimapKeysUserState.newBuilder() + .setWindow(encode(encodedWindow)) + .setKey(encode(encodedKey)) + .setTransformId(pTransformId) + .setUserStateId(stateId)) + .build(); + } + + private StateKey createMultimapValueStateKey(String key) throws IOException { + return StateKey.newBuilder() + .setMultimapUserState( + StateKey.MultimapUserState.newBuilder() + .setTransformId(pTransformId) + .setUserStateId(stateId) + .setWindow(encode(encodedWindow)) + .setKey(encode(encodedKey)) + .setMapKey(encode(key))) + .build(); + } + + private ByteString encode(String... 
values) throws IOException { + ByteString.Output out = ByteString.newOutput(); + for (String value : values) { + StringUtf8Coder.of().encode(value, out); + } + return out.toByteString(); + } +}