Add SetState and MapState
Project: http://git-wip-us.apache.org/repos/asf/beam/repo Commit: http://git-wip-us.apache.org/repos/asf/beam/commit/a0702f5b Tree: http://git-wip-us.apache.org/repos/asf/beam/tree/a0702f5b Diff: http://git-wip-us.apache.org/repos/asf/beam/diff/a0702f5b Branch: refs/heads/master Commit: a0702f5bed3c7269e90b4702266945aa34dd1aea Parents: 0f48321 Author: JingsongLi <lzljs3620...@aliyun.com> Authored: Tue Feb 14 14:52:05 2017 +0800 Committer: Kenneth Knowles <k...@google.com> Committed: Tue Feb 14 11:06:29 2017 -0800 ---------------------------------------------------------------------- .../translation/utils/ApexStateInternals.java | 18 ++ .../runners/core/InMemoryStateInternals.java | 205 ++++++++++++++ .../apache/beam/runners/core/StateMerging.java | 44 +++ .../org/apache/beam/runners/core/StateTag.java | 8 + .../org/apache/beam/runners/core/StateTags.java | 30 ++ .../core/InMemoryStateInternalsTest.java | 280 +++++++++++++++++-- .../apache/beam/runners/core/StateTagTest.java | 33 +++ .../CopyOnAccessInMemoryStateInternals.java | 46 +++ .../CopyOnAccessInMemoryStateInternalsTest.java | 58 ++++ .../wrappers/streaming/FlinkStateInternals.java | 18 ++ .../apache/beam/sdk/util/state/MapState.java | 93 ++++++ .../apache/beam/sdk/util/state/SetState.java | 71 +++++ .../apache/beam/sdk/util/state/StateBinder.java | 6 + .../apache/beam/sdk/util/state/StateSpecs.java | 89 ++++++ .../apache/beam/sdk/transforms/ParDoTest.java | 94 +++++++ 15 files changed, 1063 insertions(+), 30 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java index 34d993f..7634366 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/utils/ApexStateInternals.java @@ -45,7 +45,9 @@ import org.apache.beam.sdk.transforms.windowing.OutputTimeFn; import org.apache.beam.sdk.util.CombineFnUtil; import org.apache.beam.sdk.util.state.AccumulatorCombiningState; import org.apache.beam.sdk.util.state.BagState; +import org.apache.beam.sdk.util.state.MapState; import org.apache.beam.sdk.util.state.ReadableState; +import org.apache.beam.sdk.util.state.SetState; import org.apache.beam.sdk.util.state.State; import org.apache.beam.sdk.util.state.StateContext; import org.apache.beam.sdk.util.state.StateContexts; @@ -121,6 +123,22 @@ public class ApexStateInternals<K> implements StateInternals<K>, Serializable { } @Override + public <T> SetState<T> bindSet( + StateTag<? super K, SetState<T>> address, + Coder<T> elemCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", SetState.class.getSimpleName())); + } + + @Override + public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + StateTag<? super K, MapState<KeyT, ValueT>> spec, + Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", MapState.class.getSimpleName())); + } + + @Override public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindCombiningValue( StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java index 6a181f3..b4b2b38 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/InMemoryStateInternals.java @@ -17,10 +17,16 @@ */ package org.apache.beam.runners.core; +import static com.google.common.base.Preconditions.checkNotNull; + import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Objects; +import java.util.Set; import javax.annotation.Nullable; import org.apache.beam.runners.core.StateTag.StateBinder; import org.apache.beam.sdk.annotations.Experimental; @@ -34,7 +40,9 @@ import org.apache.beam.sdk.transforms.windowing.OutputTimeFn; import org.apache.beam.sdk.util.CombineFnUtil; import org.apache.beam.sdk.util.state.AccumulatorCombiningState; import org.apache.beam.sdk.util.state.BagState; +import org.apache.beam.sdk.util.state.MapState; import org.apache.beam.sdk.util.state.ReadableState; +import org.apache.beam.sdk.util.state.SetState; import org.apache.beam.sdk.util.state.State; import org.apache.beam.sdk.util.state.StateContext; import org.apache.beam.sdk.util.state.StateContexts; @@ -128,6 +136,18 @@ public class InMemoryStateInternals<K> implements StateInternals<K> { } @Override + public <T> SetState<T> bindSet(StateTag<? super K, SetState<T>> spec, Coder<T> elemCoder) { + return new InMemorySet<>(); + } + + @Override + public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + StateTag<? super K, MapState<KeyT, ValueT>> spec, + Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { + return new InMemoryMap<>(); + } + + @Override public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindCombiningValue( StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, @@ -435,4 +455,189 @@ public class InMemoryStateInternals<K> implements StateInternals<K> { return that; } } + + /** + * An {@link InMemoryState} implementation of {@link SetState}. + */ + public static final class InMemorySet<T> implements SetState<T>, InMemoryState<InMemorySet<T>> { + private Set<T> contents = new HashSet<>(); + + @Override + public void clear() { + contents = new HashSet<>(); + } + + @Override + public boolean contains(T t) { + return contents.contains(t); + } + + @Override + public boolean addIfAbsent(T t) { + return contents.add(t); + } + + @Override + public void remove(T t) { + contents.remove(t); + } + + @Override + public SetState<T> readLater(Iterable<T> elements) { + return this; + } + + @Override + public boolean containsAny(Iterable<T> elements) { + elements = checkNotNull(elements); + for (T t : elements) { + if (contents.contains(t)) { + return true; + } + } + return false; + } + + @Override + public boolean containsAll(Iterable<T> elements) { + elements = checkNotNull(elements); + for (T t : elements) { + if (!contents.contains(t)) { + return false; + } + } + return true; + } + + @Override + public InMemorySet<T> readLater() { + return this; + } + + @Override + public Iterable<T> read() { + return contents; + } + + @Override + public void add(T input) { + contents.add(input); + } + + @Override + public boolean isCleared() { + return contents.isEmpty(); + } + + @Override + public ReadableState<Boolean> isEmpty() { + return new ReadableState<Boolean>() { + @Override + public ReadableState<Boolean> readLater() { + return this; + } + + @Override + public Boolean read() { + return contents.isEmpty(); + } + }; + } + + @Override + public InMemorySet<T> copy() { + InMemorySet<T> that = new InMemorySet<>(); + that.contents.addAll(this.contents); + return that; + } + } + + /** + * An {@link InMemoryState} implementation of {@link MapState}. + */ + public static final class InMemoryMap<K, V> implements + MapState<K, V>, InMemoryState<InMemoryMap<K, V>> { + private Map<K, V> contents = new HashMap<>(); + + @Override + public void clear() { + contents = new HashMap<>(); + } + + @Override + public V get(K key) { + return contents.get(key); + } + + @Override + public void put(K key, V value) { + contents.put(key, value); + } + + @Override + public V putIfAbsent(K key, V value) { + V v = contents.get(key); + if (v == null) { + v = contents.put(key, value); + } + + return v; + } + + @Override + public void remove(K key) { + contents.remove(key); + } + + @Override + public Iterable<V> get(Iterable<K> keys) { + List<V> values = new ArrayList<>(); + for (K k : keys) { + values.add(contents.get(k)); + } + return values; + } + + @Override + public MapState<K, V> getLater(K k) { + return this; + } + + @Override + public MapState<K, V> getLater(Iterable<K> keys) { + return this; + } + + @Override + public Iterable<K> keys() { + return contents.keySet(); + } + + @Override + public Iterable<V> values() { + return contents.values(); + } + + @Override + public MapState<K, V> iterateLater() { + return this; + } + + @Override + public Iterable<Map.Entry<K, V>> iterate() { + return contents.entrySet(); + } + + @Override + public boolean isCleared() { + return contents.isEmpty(); + } + + @Override + public InMemoryMap<K, V> copy() { + InMemoryMap<K, V> that = new InMemoryMap<>(); + that.contents.putAll(this.contents); + return that; + } + } } http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateMerging.java ---------------------------------------------------------------------- diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateMerging.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateMerging.java index c533f83..e98d098 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateMerging.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateMerging.java @@ -28,6 +28,7 @@ import org.apache.beam.sdk.util.state.AccumulatorCombiningState; import org.apache.beam.sdk.util.state.BagState; import org.apache.beam.sdk.util.state.CombiningState; import org.apache.beam.sdk.util.state.ReadableState; +import org.apache.beam.sdk.util.state.SetState; import org.apache.beam.sdk.util.state.State; import org.apache.beam.sdk.util.state.WatermarkHoldState; import org.joda.time.Instant; @@ -112,6 +113,49 @@ public class StateMerging { } /** + * Merge all set state in {@code address} across all windows under merge. + */ + public static <K, T, W extends BoundedWindow> void mergeSets( + MergingStateAccessor<K, W> context, StateTag<? super K, SetState<T>> address) { + mergeSets(context.accessInEachMergingWindow(address).values(), context.access(address)); + } + + /** + * Merge all set state in {@code sources} (which may include {@code result}) into {@code result}. + */ + public static <T, W extends BoundedWindow> void mergeSets( + Collection<SetState<T>> sources, SetState<T> result) { + if (sources.isEmpty()) { + // Nothing to merge. + return; + } + // Prefetch everything except what's already in result. + List<ReadableState<Iterable<T>>> futures = new ArrayList<>(sources.size()); + for (SetState<T> source : sources) { + if (!source.equals(result)) { + prefetchRead(source); + futures.add(source); + } + } + if (futures.isEmpty()) { + // Result already holds all the values. + return; + } + // Transfer from sources to result. + for (ReadableState<Iterable<T>> future : futures) { + for (T element : future.read()) { + result.add(element); + } + } + // Clear sources except for result. + for (SetState<T> source : sources) { + if (!source.equals(result)) { + source.clear(); + } + } + } + + /** * Prefetch all combining value state for {@code address} across all merging windows in {@code * context}. */ http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java ---------------------------------------------------------------------- diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java index a3d703f..802aede 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTag.java @@ -30,6 +30,8 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.OutputTimeFn; import org.apache.beam.sdk.util.state.AccumulatorCombiningState; import org.apache.beam.sdk.util.state.BagState; +import org.apache.beam.sdk.util.state.MapState; +import org.apache.beam.sdk.util.state.SetState; import org.apache.beam.sdk.util.state.State; import org.apache.beam.sdk.util.state.StateSpec; import org.apache.beam.sdk.util.state.ValueState; @@ -86,6 +88,12 @@ public interface StateTag<K, StateT extends State> extends Serializable { <T> BagState<T> bindBag(StateTag<? super K, BagState<T>> spec, Coder<T> elemCoder); + <T> SetState<T> bindSet(StateTag<? super K, SetState<T>> spec, Coder<T> elemCoder); + + <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + StateTag<? super K, MapState<KeyT, ValueT>> spec, + Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder); + <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindCombiningValue( StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> spec, Coder<AccumT> accumCoder, http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java ---------------------------------------------------------------------- diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java index cf7c236..1c70dff 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/StateTags.java @@ -32,6 +32,8 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.OutputTimeFn; import org.apache.beam.sdk.util.state.AccumulatorCombiningState; import org.apache.beam.sdk.util.state.BagState; +import org.apache.beam.sdk.util.state.MapState; +import org.apache.beam.sdk.util.state.SetState; import org.apache.beam.sdk.util.state.State; import org.apache.beam.sdk.util.state.StateBinder; import org.apache.beam.sdk.util.state.StateSpec; @@ -68,6 +70,19 @@ public class StateTags { } @Override + public <T> SetState<T> bindSet( + String id, StateSpec<? super K, SetState<T>> spec, Coder<T> elemCoder) { + return binder.bindSet(tagForSpec(id, spec), elemCoder); + } + + @Override + public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + String id, StateSpec<? super K, MapState<KeyT, ValueT>> spec, + Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { + return binder.bindMap(tagForSpec(id, spec), mapKeyCoder, mapValueCoder); + } + + @Override public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindCombiningValue( String id, @@ -200,6 +215,21 @@ public class StateTags { } /** + * Create a state spec that supporting for {@link java.util.Set} like access patterns. + */ + public static <T> StateTag<Object, SetState<T>> set(String id, Coder<T> elemCoder) { + return new SimpleStateTag<>(new StructuredId(id), StateSpecs.set(elemCoder)); + } + + /** + * Create a state spec that supporting for {@link java.util.Map} like access patterns. + */ + public static <K, V> StateTag<Object, MapState<K, V>> map( + String id, Coder<K> keyCoder, Coder<V> valueCoder) { + return new SimpleStateTag<>(new StructuredId(id), StateSpecs.map(keyCoder, valueCoder)); + } + + /** * Create a state tag for holding the watermark. */ public static <W extends BoundedWindow> StateTag<Object, WatermarkHoldState<W>> http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java ---------------------------------------------------------------------- diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java index 8ea9abc..1da946f 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/InMemoryStateInternalsTest.java @@ -17,11 +17,18 @@ */ package org.apache.beam.runners.core; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; import java.util.Arrays; +import java.util.Collections; +import java.util.Map; +import java.util.Objects; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.transforms.Sum; @@ -31,7 +38,9 @@ import org.apache.beam.sdk.transforms.windowing.OutputTimeFns; import org.apache.beam.sdk.util.state.AccumulatorCombiningState; import org.apache.beam.sdk.util.state.BagState; import org.apache.beam.sdk.util.state.CombiningState; +import org.apache.beam.sdk.util.state.MapState; import org.apache.beam.sdk.util.state.ReadableState; +import org.apache.beam.sdk.util.state.SetState; import org.apache.beam.sdk.util.state.ValueState; import org.apache.beam.sdk.util.state.WatermarkHoldState; import org.hamcrest.Matchers; @@ -57,6 +66,10 @@ public class InMemoryStateInternalsTest { "sumInteger", VarIntCoder.of(), Sum.ofIntegers()); private static final StateTag<Object, BagState<String>> STRING_BAG_ADDR = StateTags.bag("stringBag", StringUtf8Coder.of()); + private static final StateTag<Object, SetState<String>> STRING_SET_ADDR = + StateTags.set("stringSet", StringUtf8Coder.of()); + private static final StateTag<Object, MapState<String, Integer>> STRING_MAP_ADDR = + StateTags.map("stringMap", StringUtf8Coder.of(), VarIntCoder.of()); private static final StateTag<Object, WatermarkHoldState<BoundedWindow>> WATERMARK_EARLIEST_ADDR = StateTags.watermarkStateInternal("watermark", OutputTimeFns.outputAtEarliestInputTimestamp()); @@ -80,9 +93,9 @@ public class InMemoryStateInternalsTest { assertThat(value.read(), Matchers.nullValue()); value.write("hello"); - assertThat(value.read(), Matchers.equalTo("hello")); + assertThat(value.read(), equalTo("hello")); value.write("world"); - assertThat(value.read(), Matchers.equalTo("world")); + assertThat(value.read(), equalTo("world")); value.clear(); assertThat(value.read(), Matchers.nullValue()); @@ -94,8 +107,8 @@ public class InMemoryStateInternalsTest { BagState<String> value = underTest.state(NAMESPACE_1, STRING_BAG_ADDR); // State instances are cached, but depend on the namespace. - assertEquals(value, underTest.state(NAMESPACE_1, STRING_BAG_ADDR)); - assertFalse(value.equals(underTest.state(NAMESPACE_2, STRING_BAG_ADDR))); + assertThat(value, equalTo(underTest.state(NAMESPACE_1, STRING_BAG_ADDR))); + assertThat(value, not(equalTo(underTest.state(NAMESPACE_2, STRING_BAG_ADDR)))); assertThat(value.read(), Matchers.emptyIterable()); value.add("hello"); @@ -157,6 +170,213 @@ public class InMemoryStateInternalsTest { } @Test + public void testSet() throws Exception { + SetState<String> value = underTest.state(NAMESPACE_1, STRING_SET_ADDR); + + // State instances are cached, but depend on the namespace. + assertThat(value, equalTo(underTest.state(NAMESPACE_1, STRING_SET_ADDR))); + assertThat(value, not(equalTo(underTest.state(NAMESPACE_2, STRING_SET_ADDR)))); + + // empty + assertThat(value.read(), Matchers.emptyIterable()); + assertFalse(value.contains("A")); + assertFalse(value.containsAny(Collections.singletonList("A"))); + + // add + value.add("A"); + value.add("B"); + value.add("A"); + assertFalse(value.addIfAbsent("B")); + assertThat(value.read(), Matchers.containsInAnyOrder("A", "B")); + + // remove + value.remove("A"); + assertThat(value.read(), Matchers.containsInAnyOrder("B")); + value.remove("C"); + assertThat(value.read(), Matchers.containsInAnyOrder("B")); + + // contains + assertFalse(value.contains("A")); + assertTrue(value.contains("B")); + value.add("C"); + value.add("D"); + + // containsAny + assertTrue(value.containsAny(Arrays.asList("A", "C"))); + assertFalse(value.containsAny(Arrays.asList("A", "E"))); + + // containsAll + assertTrue(value.containsAll(Arrays.asList("B", "C"))); + assertFalse(value.containsAll(Arrays.asList("A", "B"))); + + // readLater + assertThat(value.readLater().read(), Matchers.containsInAnyOrder("B", "C", "D")); + SetState<String> later = value.readLater(Arrays.asList("A", "C", "D")); + assertTrue(later.containsAll(Arrays.asList("C", "D"))); + assertFalse(later.contains("A")); + + // clear + value.clear(); + assertThat(value.read(), Matchers.emptyIterable()); + assertThat(underTest.state(NAMESPACE_1, STRING_SET_ADDR), Matchers.sameInstance(value)); + + } + + @Test + public void testSetIsEmpty() throws Exception { + SetState<String> value = underTest.state(NAMESPACE_1, STRING_SET_ADDR); + + assertThat(value.isEmpty().read(), Matchers.is(true)); + ReadableState<Boolean> readFuture = value.isEmpty(); + value.add("hello"); + assertThat(readFuture.read(), Matchers.is(false)); + + value.clear(); + assertThat(readFuture.read(), Matchers.is(true)); + } + + @Test + public void testMergeSetIntoSource() throws Exception { + SetState<String> set1 = underTest.state(NAMESPACE_1, STRING_SET_ADDR); + SetState<String> set2 = underTest.state(NAMESPACE_2, STRING_SET_ADDR); + + set1.add("Hello"); + set2.add("Hello"); + set2.add("World"); + set1.add("!"); + + StateMerging.mergeSets(Arrays.asList(set1, set2), set1); + + // Reading the merged set gets both the contents + assertThat(set1.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); + assertThat(set2.read(), Matchers.emptyIterable()); + } + + @Test + public void testMergeSetIntoNewNamespace() throws Exception { + SetState<String> set1 = underTest.state(NAMESPACE_1, STRING_SET_ADDR); + SetState<String> set2 = underTest.state(NAMESPACE_2, STRING_SET_ADDR); + SetState<String> set3 = underTest.state(NAMESPACE_3, STRING_SET_ADDR); + + set1.add("Hello"); + set2.add("Hello"); + set2.add("World"); + set1.add("!"); + + StateMerging.mergeSets(Arrays.asList(set1, set2, set3), set3); + + // Reading the merged set gets both the contents + assertThat(set3.read(), Matchers.containsInAnyOrder("Hello", "World", "!")); + assertThat(set1.read(), Matchers.emptyIterable()); + assertThat(set2.read(), Matchers.emptyIterable()); + } + + // for testMap + private static class MapEntry<K, V> implements Map.Entry<K, V> { + private K key; + private V value; + + private MapEntry(K key, V value) { + this.key = key; + this.value = value; + } + + static <K, V> Map.Entry<K, V> of(K k, V v) { + return new MapEntry<>(k, v); + } + + public final K getKey() { + return key; + } + public final V getValue() { + return value; + } + + public final String toString() { + return key + "=" + value; + } + + public final int hashCode() { + return Objects.hashCode(key) ^ Objects.hashCode(value); + } + + public final V setValue(V newValue) { + V oldValue = value; + value = newValue; + return oldValue; + } + + public final boolean equals(Object o) { + if (o == this) { + return true; + } + if (o instanceof Map.Entry) { + Map.Entry<?, ?> e = (Map.Entry<?, ?>) o; + if (Objects.equals(key, e.getKey()) + && Objects.equals(value, e.getValue())) { + return true; + } + } + return false; + } + } + + @Test + public void testMap() throws Exception { + MapState<String, Integer> value = underTest.state(NAMESPACE_1, STRING_MAP_ADDR); + + // State instances are cached, but depend on the namespace. + assertThat(value, equalTo(underTest.state(NAMESPACE_1, STRING_MAP_ADDR))); + assertThat(value, not(equalTo(underTest.state(NAMESPACE_2, STRING_MAP_ADDR)))); + + // put + assertThat(value.iterate(), Matchers.emptyIterable()); + value.put("A", 1); + value.put("B", 2); + value.put("A", 11); + assertThat(value.putIfAbsent("B", 22), equalTo(2)); + assertThat(value.iterate(), Matchers.containsInAnyOrder(MapEntry.of("A", 11), + MapEntry.of("B", 2))); + + // remove + value.remove("A"); + assertThat(value.iterate(), Matchers.containsInAnyOrder(MapEntry.of("B", 2))); + value.remove("C"); + assertThat(value.iterate(), Matchers.containsInAnyOrder(MapEntry.of("B", 2))); + + // get + assertNull(value.get("A")); + assertThat(value.get("B"), equalTo(2)); + value.put("C", 3); + value.put("D", 4); + assertThat(value.get("C"), equalTo(3)); + assertThat(value.get(Collections.singletonList("D")), Matchers.containsInAnyOrder(4)); + assertThat(value.get(Arrays.asList("B", "C")), Matchers.containsInAnyOrder(2, 3)); + + // iterate + value.put("E", 5); + value.remove("C"); + assertThat(value.keys(), Matchers.containsInAnyOrder("B", "D", "E")); + assertThat(value.values(), Matchers.containsInAnyOrder(2, 4, 5)); + assertThat(value.iterate(), Matchers.containsInAnyOrder( + MapEntry.of("B", 2), MapEntry.of("D", 4), MapEntry.of("E", 5))); + + // readLater + assertThat(value.getLater("B").get("B"), equalTo(2)); + assertNull(value.getLater("A").get("A")); + MapState<String, Integer> later = value.getLater(Arrays.asList("C", "D")); + assertNull(later.get("C")); + assertThat(later.get("D"), equalTo(4)); + assertThat(value.iterateLater().iterate(), Matchers.containsInAnyOrder( + MapEntry.of("B", 2), MapEntry.of("D", 4), MapEntry.of("E", 5))); + + // clear + value.clear(); + assertThat(value.iterate(), Matchers.emptyIterable()); + assertThat(underTest.state(NAMESPACE_1, STRING_MAP_ADDR), Matchers.sameInstance(value)); + } + + @Test public void testCombiningValue() throws Exception { CombiningState<Integer, Integer> value = underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR); @@ -164,15 +384,15 @@ public class InMemoryStateInternalsTest { assertEquals(value, underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR)); assertFalse(value.equals(underTest.state(NAMESPACE_2, SUM_INTEGER_ADDR))); - assertThat(value.read(), Matchers.equalTo(0)); + assertThat(value.read(), equalTo(0)); value.add(2); - assertThat(value.read(), Matchers.equalTo(2)); + assertThat(value.read(), equalTo(2)); value.add(3); - assertThat(value.read(), Matchers.equalTo(5)); + assertThat(value.read(), equalTo(5)); value.clear(); - assertThat(value.read(), Matchers.equalTo(0)); + assertThat(value.read(), equalTo(0)); assertThat(underTest.state(NAMESPACE_1, SUM_INTEGER_ADDR), Matchers.sameInstance(value)); } @@ -200,14 +420,14 @@ public class InMemoryStateInternalsTest { value2.add(10); value1.add(6); - assertThat(value1.read(), Matchers.equalTo(11)); - assertThat(value2.read(), Matchers.equalTo(10)); + assertThat(value1.read(), equalTo(11)); + assertThat(value2.read(), equalTo(10)); // Merging clears the old values and updates the result value. StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value1); - assertThat(value1.read(), Matchers.equalTo(21)); - assertThat(value2.read(), Matchers.equalTo(0)); + assertThat(value1.read(), equalTo(21)); + assertThat(value2.read(), equalTo(0)); } @Test @@ -226,9 +446,9 @@ public class InMemoryStateInternalsTest { StateMerging.mergeCombiningValues(Arrays.asList(value1, value2), value3); // Merging clears the old values and updates the result value. - assertThat(value1.read(), Matchers.equalTo(0)); - assertThat(value2.read(), Matchers.equalTo(0)); - assertThat(value3.read(), Matchers.equalTo(21)); + assertThat(value1.read(), equalTo(0)); + assertThat(value2.read(), equalTo(0)); + assertThat(value3.read(), equalTo(21)); } @Test @@ -242,16 +462,16 @@ public class InMemoryStateInternalsTest { assertThat(value.read(), Matchers.nullValue()); value.add(new Instant(2000)); - assertThat(value.read(), Matchers.equalTo(new Instant(2000))); + assertThat(value.read(), equalTo(new Instant(2000))); value.add(new Instant(3000)); - assertThat(value.read(), Matchers.equalTo(new Instant(2000))); + assertThat(value.read(), equalTo(new Instant(2000))); value.add(new Instant(1000)); - assertThat(value.read(), Matchers.equalTo(new Instant(1000))); + assertThat(value.read(), equalTo(new Instant(1000))); value.clear(); - assertThat(value.read(), Matchers.equalTo(null)); + assertThat(value.read(), equalTo(null)); assertThat(underTest.state(NAMESPACE_1, WATERMARK_EARLIEST_ADDR), Matchers.sameInstance(value)); } @@ -266,16 +486,16 @@ public class InMemoryStateInternalsTest { assertThat(value.read(), Matchers.nullValue()); value.add(new Instant(2000)); - assertThat(value.read(), Matchers.equalTo(new Instant(2000))); + assertThat(value.read(), equalTo(new Instant(2000))); value.add(new Instant(3000)); - assertThat(value.read(), Matchers.equalTo(new Instant(3000))); + assertThat(value.read(), equalTo(new Instant(3000))); value.add(new Instant(1000)); - assertThat(value.read(), Matchers.equalTo(new Instant(3000))); + assertThat(value.read(), equalTo(new Instant(3000))); value.clear(); - assertThat(value.read(), Matchers.equalTo(null)); + assertThat(value.read(), equalTo(null)); assertThat(underTest.state(NAMESPACE_1, WATERMARK_LATEST_ADDR), Matchers.sameInstance(value)); } @@ -289,10 +509,10 @@ public class InMemoryStateInternalsTest { assertThat(value.read(), Matchers.nullValue()); value.add(new Instant(2000)); - assertThat(value.read(), Matchers.equalTo(new Instant(2000))); + assertThat(value.read(), equalTo(new Instant(2000))); value.clear(); - assertThat(value.read(), Matchers.equalTo(null)); + assertThat(value.read(), equalTo(null)); assertThat(underTest.state(NAMESPACE_1, WATERMARK_EOW_ADDR), Matchers.sameInstance(value)); } @@ -325,8 +545,8 @@ public class InMemoryStateInternalsTest { // Merging clears the old values and updates the merged value. StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value1, WINDOW_1); - assertThat(value1.read(), Matchers.equalTo(new Instant(2000))); - assertThat(value2.read(), Matchers.equalTo(null)); + assertThat(value1.read(), equalTo(new Instant(2000))); + assertThat(value2.read(), equalTo(null)); } @Test @@ -347,8 +567,8 @@ public class InMemoryStateInternalsTest { StateMerging.mergeWatermarks(Arrays.asList(value1, value2), value3, WINDOW_1); // Merging clears the old values and updates the result value. - assertThat(value3.read(), Matchers.equalTo(new Instant(5000))); - assertThat(value1.read(), Matchers.equalTo(null)); - assertThat(value2.read(), Matchers.equalTo(null)); + assertThat(value3.read(), equalTo(new Instant(5000))); + assertThat(value1.read(), equalTo(null)); + assertThat(value2.read(), equalTo(null)); } } http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateTagTest.java ---------------------------------------------------------------------- diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateTagTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateTagTest.java index 9a04628..0584643 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateTagTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateTagTest.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertNotEquals; import org.apache.beam.sdk.coders.BigEndianIntegerCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Max; @@ -63,6 +64,38 @@ public class StateTagTest { } @Test + public void testSetEquality() { + StateTag<?, ?> fooVarInt1 = StateTags.set("foo", VarIntCoder.of()); + StateTag<?, ?> fooVarInt2 = StateTags.set("foo", VarIntCoder.of()); + StateTag<?, ?> fooBigEndian = StateTags.set("foo", BigEndianIntegerCoder.of()); + StateTag<?, ?> barVarInt = StateTags.set("bar", VarIntCoder.of()); + + assertEquals(fooVarInt1, fooVarInt2); + assertNotEquals(fooVarInt1, fooBigEndian); + assertNotEquals(fooVarInt1, barVarInt); + } + + @Test + public void testMapEquality() { + StateTag<?, ?> fooStringVarInt1 = + StateTags.map("foo", StringUtf8Coder.of(), VarIntCoder.of()); + StateTag<?, ?> fooStringVarInt2 = + StateTags.map("foo", StringUtf8Coder.of(), VarIntCoder.of()); + StateTag<?, ?> fooStringBigEndian = + StateTags.map("foo", StringUtf8Coder.of(), BigEndianIntegerCoder.of()); + StateTag<?, ?> fooVarIntBigEndian = + StateTags.map("foo", VarIntCoder.of(), BigEndianIntegerCoder.of()); + StateTag<?, ?> barStringVarInt = + StateTags.map("bar", StringUtf8Coder.of(), VarIntCoder.of()); + + assertEquals(fooStringVarInt1, fooStringVarInt2); + assertNotEquals(fooStringVarInt1, fooStringBigEndian); + assertNotEquals(fooStringBigEndian, fooVarIntBigEndian); + assertNotEquals(fooStringVarInt1, fooVarIntBigEndian); + assertNotEquals(fooStringVarInt1, barStringVarInt); + } + + @Test public void testWatermarkBagEquality() { StateTag<?, ?> foo1 = StateTags.watermarkStateInternal( "foo", OutputTimeFns.outputAtEarliestInputTimestamp()); http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java index 47c0251..ff5c23c 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternals.java @@ -27,6 +27,8 @@ import java.util.Map; import javax.annotation.Nullable; import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryBag; import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryCombiningValue; +import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryMap; +import org.apache.beam.runners.core.InMemoryStateInternals.InMemorySet; import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryState; import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryStateBinder; import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryValue; @@ -45,6 +47,8 @@ import org.apache.beam.sdk.transforms.windowing.OutputTimeFn; import org.apache.beam.sdk.util.CombineFnUtil; import org.apache.beam.sdk.util.state.AccumulatorCombiningState; import org.apache.beam.sdk.util.state.BagState; +import org.apache.beam.sdk.util.state.MapState; +import org.apache.beam.sdk.util.state.SetState; import org.apache.beam.sdk.util.state.State; import org.apache.beam.sdk.util.state.StateContext; import org.apache.beam.sdk.util.state.StateContexts; @@ -334,6 +338,35 @@ public class CopyOnAccessInMemoryStateInternals<K> implements StateInternals<K> } @Override + public <T> SetState<T> bindSet( + StateTag<? super K, SetState<T>> address, Coder<T> elemCoder) { + if (containedInUnderlying(namespace, address)) { + @SuppressWarnings("unchecked") + InMemoryState<? extends SetState<T>> existingState = + (InMemoryState<? extends SetState<T>>) + underlying.get().get(namespace, address, c); + return existingState.copy(); + } else { + return new InMemorySet<>(); + } + } + + @Override + public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + StateTag<? super K, MapState<KeyT, ValueT>> address, + Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { + if (containedInUnderlying(namespace, address)) { + @SuppressWarnings("unchecked") + InMemoryState<? extends MapState<KeyT, ValueT>> existingState = + (InMemoryState<? extends MapState<KeyT, ValueT>>) + underlying.get().get(namespace, address, c); + return existingState.copy(); + } else { + return new InMemoryMap<>(); + } + } + + @Override public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue( StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, @@ -430,6 +463,19 @@ public class CopyOnAccessInMemoryStateInternals<K> implements StateInternals<K> } @Override + public <T> SetState<T> bindSet( + StateTag<? super K, SetState<T>> address, Coder<T> elemCoder) { + return underlying.get(namespace, address, c); + } + + @Override + public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + StateTag<? super K, MapState<KeyT, ValueT>> address, + Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { + return underlying.get(namespace, address, c); + } + + @Override public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindKeyedCombiningValue( StateTag<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> address, http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternalsTest.java ---------------------------------------------------------------------- diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternalsTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternalsTest.java index c8eb66e..c7409bb 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternalsTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/CopyOnAccessInMemoryStateInternalsTest.java @@ -23,6 +23,7 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.theInstance; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; @@ -47,6 +48,8 @@ import org.apache.beam.sdk.transforms.windowing.OutputTimeFns; import org.apache.beam.sdk.util.state.AccumulatorCombiningState; import org.apache.beam.sdk.util.state.BagState; import org.apache.beam.sdk.util.state.CombiningState; +import org.apache.beam.sdk.util.state.MapState; +import org.apache.beam.sdk.util.state.SetState; import org.apache.beam.sdk.util.state.ValueState; import org.apache.beam.sdk.util.state.WatermarkHoldState; import org.joda.time.Instant; @@ -164,6 +167,61 @@ public class CopyOnAccessInMemoryStateInternalsTest { } @Test + public void testSetStateWithUnderlying() { + CopyOnAccessInMemoryStateInternals<String> underlying = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + + StateNamespace namespace = new StateNamespaceForTest("foo"); + StateTag<Object, SetState<Integer>> valueTag = StateTags.set("foo", VarIntCoder.of()); + SetState<Integer> underlyingValue = underlying.state(namespace, valueTag); + assertThat(underlyingValue.read(), emptyIterable()); + + underlyingValue.add(1); + assertThat(underlyingValue.read(), containsInAnyOrder(1)); + + CopyOnAccessInMemoryStateInternals<String> internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying); + SetState<Integer> copyOnAccessState = internals.state(namespace, valueTag); + assertThat(copyOnAccessState.read(), containsInAnyOrder(1)); + + copyOnAccessState.add(4); + assertThat(copyOnAccessState.read(), containsInAnyOrder(4, 1)); + assertThat(underlyingValue.read(), containsInAnyOrder(1)); + + SetState<Integer> reReadUnderlyingValue = underlying.state(namespace, valueTag); + assertThat(underlyingValue.read(), equalTo(reReadUnderlyingValue.read())); + } + + @Test + public void testMapStateWithUnderlying() { + CopyOnAccessInMemoryStateInternals<String> underlying = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); + + StateNamespace namespace = new StateNamespaceForTest("foo"); + StateTag<Object, MapState<String, Integer>> valueTag = + StateTags.map("foo", StringUtf8Coder.of(), VarIntCoder.of()); + MapState<String, Integer> underlyingValue = underlying.state(namespace, valueTag); + assertThat(underlyingValue.iterate(), emptyIterable()); + + underlyingValue.put("hello", 1); + assertThat(underlyingValue.get("hello"), equalTo(1)); + + CopyOnAccessInMemoryStateInternals<String> internals = + CopyOnAccessInMemoryStateInternals.withUnderlying(key, underlying); + MapState<String, Integer> copyOnAccessState = internals.state(namespace, valueTag); + assertThat(copyOnAccessState.get("hello"), equalTo(1)); + + copyOnAccessState.put("world", 4); + assertThat(copyOnAccessState.get("hello"), equalTo(1)); + assertThat(copyOnAccessState.get("world"), equalTo(4)); + assertThat(underlyingValue.get("hello"), equalTo(1)); + assertNull(underlyingValue.get("world")); + + MapState<String, Integer> reReadUnderlyingValue = underlying.state(namespace, valueTag); + assertThat(underlyingValue.iterate(), equalTo(reReadUnderlyingValue.iterate())); + } + + @Test public void testAccumulatorCombiningStateWithUnderlying() throws CannotProvideCoderException { CopyOnAccessInMemoryStateInternals<String> underlying = CopyOnAccessInMemoryStateInternals.withUnderlying(key, null); http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkStateInternals.java ---------------------------------------------------------------------- diff --git a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkStateInternals.java b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkStateInternals.java index eaededb..4183067 100644 --- a/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkStateInternals.java +++ b/runners/flink/runner/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/FlinkStateInternals.java @@ -37,7 +37,9 @@ import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.util.CombineContextFactory; import org.apache.beam.sdk.util.state.AccumulatorCombiningState; import org.apache.beam.sdk.util.state.BagState; +import org.apache.beam.sdk.util.state.MapState; import org.apache.beam.sdk.util.state.ReadableState; +import org.apache.beam.sdk.util.state.SetState; import org.apache.beam.sdk.util.state.State; import org.apache.beam.sdk.util.state.StateContext; import org.apache.beam.sdk.util.state.StateContexts; @@ -125,6 +127,22 @@ public class FlinkStateInternals<K> implements StateInternals<K> { } @Override + public <T> SetState<T> bindSet( + StateTag<? super K, SetState<T>> address, + Coder<T> elemCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", SetState.class.getSimpleName())); + } + + @Override + public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + StateTag<? super K, MapState<KeyT, ValueT>> spec, + Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder) { + throw new UnsupportedOperationException( + String.format("%s is not supported", MapState.class.getSimpleName())); + } + + @Override public <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindCombiningValue( http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/MapState.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/MapState.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/MapState.java new file mode 100644 index 0000000..85d99d6 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/MapState.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.util.state; + +import java.util.Map; + +/** + * An object that maps keys to values. + * A map cannot contain duplicate keys; + * each key can map to at most one value. + * + * @param <K> the type of keys maintained by this map + * @param <V> the type of mapped values + */ +public interface MapState<K, V> extends State { + + /** + * Returns the value to which the specified key is mapped in the state. + */ + V get(K key); + + /** + * Associates the specified value with the specified key in this state. + */ + void put(K key, V value); + + /** + * If the specified key is not already associated with a value (or is mapped + * to {@code null}) associates it with the given value and returns + * {@code null}, else returns the current value. + */ + V putIfAbsent(K key, V value); + + /** + * Removes the mapping for a key from this map if it is present. + */ + void remove(K key); + + /** + * A bulk get. + * @param keys the keys to search for + * @return a iterable view of values, maybe some values is null. + * The order of values corresponds to the order of the keys. + */ + Iterable<V> get(Iterable<K> keys); + + /** + * Indicate that specified key will be read later. + */ + MapState<K, V> getLater(K k); + + /** + * Indicate that specified batch keys will be read later. + */ + MapState<K, V> getLater(Iterable<K> keys); + + /** + * Returns a iterable view of the keys contained in this map. + */ + Iterable<K> keys(); + + /** + * Returns a iterable view of the values contained in this map. + */ + Iterable<V> values(); + + /** + * Indicate that all key-values will be read later. + */ + MapState<K, V> iterateLater(); + + /** + * Returns a iterable view of all key-values. + */ + Iterable<Map.Entry<K, V>> iterate(); + +} + http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/SetState.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/SetState.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/SetState.java new file mode 100644 index 0000000..93058b2 --- /dev/null +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/SetState.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.util.state; + +/** + * State containing no duplicate elements. + * Items can be added to the set and the contents read out. + * + * @param <T> The type of elements in the set. + */ +public interface SetState<T> extends CombiningState<T, Iterable<T>> { + /** + * Returns true if this set contains the specified element. + */ + boolean contains(T t); + + /** + * Add a value to the buffer if it is not already present. + * If this set already contains the element, the call leaves the set + * unchanged and returns false. + */ + boolean addIfAbsent(T t); + + /** + * Removes the specified element from this set if it is present. + */ + void remove(T t); + + /** + * Indicate that elements will be read later. + * @param elements to be read later + * @return this for convenient chaining + */ + SetState<T> readLater(Iterable<T> elements); + + /** + * <p>Checks if SetState contains any given elements.</p> + * + * @param elements the elements to search for + * @return the {@code true} if any of the elements are found, + * {@code false} if no match + */ + boolean containsAny(Iterable<T> elements); + + /** + * <p>Checks if SetState contains all given elements.</p> + * + * @param elements the elements to find + * @return true if the SetState contains all elements, + * false if not + */ + boolean containsAll(Iterable<T> elements); + + @Override + SetState<T> readLater(); +} http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateBinder.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateBinder.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateBinder.java index 0521e15..fbfb475 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateBinder.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateBinder.java @@ -33,6 +33,12 @@ public interface StateBinder<K> { <T> BagState<T> bindBag(String id, StateSpec<? super K, BagState<T>> spec, Coder<T> elemCoder); + <T> SetState<T> bindSet(String id, StateSpec<? super K, SetState<T>> spec, Coder<T> elemCoder); + + <KeyT, ValueT> MapState<KeyT, ValueT> bindMap( + String id, StateSpec<? super K, MapState<KeyT, ValueT>> spec, + Coder<KeyT> mapKeyCoder, Coder<ValueT> mapValueCoder); + <InputT, AccumT, OutputT> AccumulatorCombiningState<InputT, AccumT, OutputT> bindCombiningValue( String id, StateSpec<? super K, AccumulatorCombiningState<InputT, AccumT, OutputT>> spec, http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java index 08c3a12..8912993 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/state/StateSpecs.java @@ -125,6 +125,21 @@ public class StateSpecs { return new BagStateSpec<T>(elemCoder); } + /** + * Create a state spec that supporting for {@link java.util.Set} like access patterns. + */ + public static <T> StateSpec<Object, SetState<T>> set(Coder<T> elemCoder) { + return new SetStateSpec<>(elemCoder); + } + + /** + * Create a state spec that supporting for {@link java.util.Map} like access patterns. + */ + public static <K, V> StateSpec<Object, MapState<K, V>> map(Coder<K> keyCoder, + Coder<V> valueCoder) { + return new MapStateSpec<>(keyCoder, valueCoder); + } + /** Create a state spec for holding the watermark. */ public static <W extends BoundedWindow> StateSpec<Object, WatermarkHoldState<W>> watermarkStateInternal( @@ -346,6 +361,80 @@ public class StateSpecs { } } + private static class MapStateSpec<K, V> implements StateSpec<Object, MapState<K, V>> { + + private final Coder<K> keyCoder; + private final Coder<V> valueCoder; + + private MapStateSpec(Coder<K> keyCoder, Coder<V> valueCoder) { + this.keyCoder = keyCoder; + this.valueCoder = valueCoder; + } + + @Override + public MapState<K, V> bind(String id, StateBinder<?> visitor) { + return visitor.bindMap(id, this, keyCoder, valueCoder); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + + if (!(obj instanceof MapStateSpec)) { + return false; + } + + MapStateSpec<?, ?> that = (MapStateSpec<?, ?>) obj; + return Objects.equals(this.keyCoder, that.keyCoder) + && Objects.equals(this.valueCoder, that.valueCoder); + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), keyCoder, valueCoder); + } + } + + /** + * A specification for a state cell supporting for set-like access patterns. + * + * <p>Includes the coder for the element type {@code T}</p> + */ + private static class SetStateSpec<T> implements StateSpec<Object, SetState<T>> { + + private final Coder<T> elemCoder; + + private SetStateSpec(Coder<T> elemCoder) { + this.elemCoder = elemCoder; + } + + @Override + public SetState<T> bind(String id, StateBinder<?> visitor) { + return visitor.bindSet(id, this, elemCoder); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + + if (!(obj instanceof SetStateSpec)) { + return false; + } + + SetStateSpec<?> that = (SetStateSpec<?>) obj; + return Objects.equals(this.elemCoder, that.elemCoder); + } + + @Override + public int hashCode() { + return Objects.hash(getClass(), elemCoder); + } + } + /** * A specification for a state cell tracking a combined watermark hold. * http://git-wip-us.apache.org/repos/asf/beam/blob/a0702f5b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java ---------------------------------------------------------------------- diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 75c39cc..f40bbe1 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -40,6 +40,7 @@ import com.fasterxml.jackson.annotation.JsonCreator; import com.google.common.base.MoreObjects; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; +import com.google.common.collect.Sets; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; @@ -48,6 +49,8 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; +import java.util.Set; import org.apache.beam.sdk.Pipeline.PipelineExecutionException; import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.CoderException; @@ -60,6 +63,8 @@ import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.RunnableOnService; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.testing.TestStream; +import org.apache.beam.sdk.testing.UsesMapState; +import org.apache.beam.sdk.testing.UsesSetState; import org.apache.beam.sdk.testing.UsesStatefulParDo; import org.apache.beam.sdk.testing.UsesTestStream; import org.apache.beam.sdk.testing.UsesTimersInParDo; @@ -83,6 +88,8 @@ import org.apache.beam.sdk.util.TimerSpecs; import org.apache.beam.sdk.util.common.ElementByteSizeObserver; import org.apache.beam.sdk.util.state.AccumulatorCombiningState; import org.apache.beam.sdk.util.state.BagState; +import org.apache.beam.sdk.util.state.MapState; +import org.apache.beam.sdk.util.state.SetState; import org.apache.beam.sdk.util.state.StateSpec; import org.apache.beam.sdk.util.state.StateSpecs; import org.apache.beam.sdk.util.state.ValueState; @@ -1681,6 +1688,93 @@ public class ParDoTest implements Serializable { } @Test + @Category({RunnableOnService.class, UsesStatefulParDo.class, UsesSetState.class}) + public void testSetState() { + final String stateId = "foo"; + final String countStateId = "count"; + + DoFn<KV<String, Integer>, Set<Integer>> fn = + new DoFn<KV<String, Integer>, Set<Integer>>() { + + @StateId(stateId) + private final StateSpec<Object, SetState<Integer>> setState = + StateSpecs.set(VarIntCoder.of()); + @StateId(countStateId) + private final StateSpec<Object, AccumulatorCombiningState<Integer, int[], Integer>> + countState = StateSpecs.combiningValueFromInputInternal(VarIntCoder.of(), + Sum.ofIntegers()); + + @ProcessElement + public void processElement( + ProcessContext c, + @StateId(stateId) SetState<Integer> state, + @StateId(countStateId) AccumulatorCombiningState<Integer, int[], Integer> + count) { + state.add(c.element().getValue()); + count.add(1); + if (count.read() >= 4) { + Set<Integer> set = Sets.newHashSet(state.read()); + c.output(set); + } + } + }; + + PCollection<Set<Integer>> output = + pipeline.apply( + Create.of( + KV.of("hello", 97), KV.of("hello", 42), KV.of("hello", 42), KV.of("hello", 12))) + .apply(ParDo.of(fn)); + + PAssert.that(output).containsInAnyOrder(Sets.newHashSet(97, 42, 12)); + pipeline.run(); + } + + @Test + @Category({RunnableOnService.class, UsesStatefulParDo.class, UsesMapState.class}) + public void testMapState() { + final String stateId = "foo"; + final String countStateId = "count"; + + DoFn<KV<String, KV<String, Integer>>, KV<String, Integer>> fn = + new DoFn<KV<String, KV<String, Integer>>, KV<String, Integer>>() { + + @StateId(stateId) + private final StateSpec<Object, MapState<String, Integer>> mapState = + StateSpecs.map(StringUtf8Coder.of(), VarIntCoder.of()); + @StateId(countStateId) + private final StateSpec<Object, AccumulatorCombiningState<Integer, int[], Integer>> + countState = StateSpecs.combiningValueFromInputInternal(VarIntCoder.of(), + Sum.ofIntegers()); + + @ProcessElement + public void processElement( + ProcessContext c, @StateId(stateId) MapState<String, Integer> state, + @StateId(countStateId) AccumulatorCombiningState<Integer, int[], Integer> + count) { + KV<String, Integer> value = c.element().getValue(); + state.put(value.getKey(), value.getValue()); + count.add(1); + if (count.read() >= 4) { + Iterable<Map.Entry<String, Integer>> iterate = state.iterate(); + for (Map.Entry<String, Integer> entry : iterate) { + c.output(KV.of(entry.getKey(), entry.getValue())); + } + } + } + }; + + PCollection<KV<String, Integer>> output = + pipeline.apply( + Create.of( + KV.of("hello", KV.of("a", 97)), KV.of("hello", KV.of("b", 42)), + KV.of("hello", KV.of("b", 42)), KV.of("hello", KV.of("c", 12)))) + .apply(ParDo.of(fn)); + + PAssert.that(output).containsInAnyOrder(KV.of("a", 97), KV.of("b", 42), KV.of("c", 12)); + pipeline.run(); + } + + @Test @Category({RunnableOnService.class, UsesStatefulParDo.class}) public void testCombiningState() { final String stateId = "foo";