This is an automated email from the ASF dual-hosted git repository. reuvenlax pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 488c10c Merge pull request #12864: [BEAM-10650] Windmill implementation for TimestampOrderedState 488c10c is described below commit 488c10c23312ae14000e292efe835d273e67883a Author: reuvenlax <re...@google.com> AuthorDate: Sat Oct 17 09:43:19 2020 -0700 Merge pull request #12864: [BEAM-10650] Windmill implementation for TimestampOrderedState --- .../beam/runners/dataflow/DataflowRunner.java | 7 - .../dataflow/worker/WindmillStateInternals.java | 588 +++++++++++++++++++-- .../dataflow/worker/WindmillStateReader.java | 397 +++++++++----- .../worker/WindmillStateInternalsTest.java | 346 ++++++++++++ .../dataflow/worker/WindmillStateReaderTest.java | 262 +++++++++ .../worker/windmill/src/main/proto/windmill.proto | 10 +- .../org/apache/beam/sdk/transforms/ParDoTest.java | 39 +- 7 files changed, 1473 insertions(+), 176 deletions(-) diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index 67fcddf..115f4a0 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -121,7 +121,6 @@ import org.apache.beam.sdk.runners.PTransformOverrideFactory; import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.runners.TransformHierarchy.Node; import org.apache.beam.sdk.state.MapState; -import org.apache.beam.sdk.state.OrderedListState; import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.transforms.Combine; import org.apache.beam.sdk.transforms.Combine.CombineFn; @@ -2046,12 +2045,6 @@ public class DataflowRunner extends PipelineRunner<DataflowPipelineJob> { "%s does not currently support %s", DataflowRunner.class.getSimpleName(), 
MapState.class.getSimpleName())); } - if (DoFnSignatures.usesOrderedListState(fn)) { - throw new UnsupportedOperationException( - String.format( - "%s does not currently support %s", - DataflowRunner.class.getSimpleName(), OrderedListState.class.getSimpleName())); - } if (streaming && DoFnSignatures.requiresTimeSortedInput(fn)) { throw new UnsupportedOperationException( String.format( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java index 2ad11f4..73e3c72 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java @@ -17,18 +17,25 @@ */ package org.apache.beam.runners.dataflow.worker; +import com.google.auto.value.AutoValue; import java.io.Closeable; import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; import java.io.OutputStreamWriter; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.Comparator; import java.util.Iterator; import java.util.List; +import java.util.Map; import java.util.Random; +import java.util.SortedSet; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.function.BiConsumer; import javax.annotation.concurrent.NotThreadSafe; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespace; @@ -38,8 +45,22 @@ import org.apache.beam.runners.core.StateTag.StateBinder; import org.apache.beam.runners.core.StateTags; import org.apache.beam.runners.dataflow.worker.WindmillStateCache.ForKey; import 
org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.SortedListEntry; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.SortedListRange; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagSortedListDeleteRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagSortedListInsertRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagSortedListUpdateRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.Coder.Context; +import org.apache.beam.sdk.coders.CoderException; +import org.apache.beam.sdk.coders.CustomCoder; +import org.apache.beam.sdk.coders.InstantCoder; +import org.apache.beam.sdk.coders.MapCoder; +import org.apache.beam.sdk.coders.NullableCoder; +import org.apache.beam.sdk.coders.SetCoder; +import org.apache.beam.sdk.coders.StructuredCoder; +import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.CombiningState; import org.apache.beam.sdk.state.MapState; @@ -60,11 +81,21 @@ import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Optional; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Supplier; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.BoundType; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import 
org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterators; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Range; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.RangeSet; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.TreeRangeSet; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Futures; import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; import org.joda.time.Instant; /** Implementation of {@link StateInternals} using Windmill to manage the underlying data. */ @@ -86,22 +117,28 @@ class WindmillStateInternals<K> implements StateInternals { private final String stateFamily; private final WindmillStateReader reader; private final WindmillStateCache.ForKey cache; + private final boolean isSystemTable; boolean isNewKey; private final Supplier<Closeable> scopedReadStateSupplier; + private final StateTable derivedStateTable; public CachingStateTable( @Nullable K key, String stateFamily, WindmillStateReader reader, WindmillStateCache.ForKey cache, + boolean isSystemTable, boolean isNewKey, - Supplier<Closeable> scopedReadStateSupplier) { + Supplier<Closeable> scopedReadStateSupplier, + StateTable derivedStateTable) { this.key = key; this.stateFamily = stateFamily; this.reader = reader; this.cache = cache; + this.isSystemTable = isSystemTable; this.isNewKey = isNewKey; this.scopedReadStateSupplier = scopedReadStateSupplier; + this.derivedStateTable = derivedStateTable != null ? 
derivedStateTable : this; } @Override @@ -112,6 +149,9 @@ class WindmillStateInternals<K> implements StateInternals { return new StateBinder() { @Override public <T> BagState<T> bindBag(StateTag<BagState<T>> address, Coder<T> elemCoder) { + if (isSystemTable) { + address = StateTags.makeSystemTagInternal(address); + } WindmillBag<T> result = (WindmillBag<T>) cache.get(namespace, address); if (result == null) { result = new WindmillBag<>(namespace, address, stateFamily, elemCoder, isNewKey); @@ -138,9 +178,14 @@ class WindmillStateInternals<K> implements StateInternals { @Override public <T> OrderedListState<T> bindOrderedList( StateTag<OrderedListState<T>> spec, Coder<T> elemCoder) { + if (isSystemTable) { + spec = StateTags.makeSystemTagInternal(spec); + } WindmillOrderedList<T> result = (WindmillOrderedList<T>) cache.get(namespace, spec); if (result == null) { - result = new WindmillOrderedList<>(namespace, spec, stateFamily, elemCoder, isNewKey); + result = + new WindmillOrderedList<>( + derivedStateTable, namespace, spec, stateFamily, elemCoder, isNewKey); } result.initializeForWorkItem(reader, scopedReadStateSupplier); return result; @@ -149,6 +194,9 @@ class WindmillStateInternals<K> implements StateInternals { @Override public WatermarkHoldState bindWatermark( StateTag<WatermarkHoldState> address, TimestampCombiner timestampCombiner) { + if (isSystemTable) { + address = StateTags.makeSystemTagInternal(address); + } WindmillWatermarkHold result = (WindmillWatermarkHold) cache.get(namespace, address); if (result == null) { result = @@ -164,8 +212,11 @@ class WindmillStateInternals<K> implements StateInternals { StateTag<CombiningState<InputT, AccumT, OutputT>> address, Coder<AccumT> accumCoder, CombineFn<InputT, AccumT, OutputT> combineFn) { + if (isSystemTable) { + address = StateTags.makeSystemTagInternal(address); + } WindmillCombiningState<InputT, AccumT, OutputT> result = - new WindmillCombiningState<InputT, AccumT, OutputT>( + new 
WindmillCombiningState<>( namespace, address, stateFamily, accumCoder, combineFn, cache, isNewKey); result.initializeForWorkItem(reader, scopedReadStateSupplier); return result; @@ -177,11 +228,17 @@ class WindmillStateInternals<K> implements StateInternals { StateTag<CombiningState<InputT, AccumT, OutputT>> address, Coder<AccumT> accumCoder, CombineFnWithContext<InputT, AccumT, OutputT> combineFn) { + if (isSystemTable) { + address = StateTags.makeSystemTagInternal(address); + } return bindCombiningValue(address, accumCoder, CombineFnUtil.bindContext(combineFn, c)); } @Override public <T> ValueState<T> bindValue(StateTag<ValueState<T>> address, Coder<T> coder) { + if (isSystemTable) { + address = StateTags.makeSystemTagInternal(address); + } WindmillValue<T> result = (WindmillValue<T>) cache.get(namespace, address); if (result == null) { result = new WindmillValue<>(namespace, address, stateFamily, coder, isNewKey); @@ -196,6 +253,7 @@ class WindmillStateInternals<K> implements StateInternals { private WindmillStateCache.ForKey cache; Supplier<Closeable> scopedReadStateSupplier; private StateTable workItemState; + private StateTable workItemDerivedState; public WindmillStateInternals( @Nullable K key, @@ -207,16 +265,23 @@ class WindmillStateInternals<K> implements StateInternals { this.key = key; this.cache = cache; this.scopedReadStateSupplier = scopedReadStateSupplier; + this.workItemDerivedState = + new CachingStateTable<>( + key, stateFamily, reader, cache, true, isNewKey, scopedReadStateSupplier, null); this.workItemState = - new CachingStateTable<K>( - key, stateFamily, reader, cache, isNewKey, scopedReadStateSupplier); + new CachingStateTable<>( + key, + stateFamily, + reader, + cache, + false, + isNewKey, + scopedReadStateSupplier, + workItemDerivedState); } - public void persist(final Windmill.WorkItemCommitRequest.Builder commitBuilder) { - List<Future<WorkItemCommitRequest>> commitsToMerge = new ArrayList<>(); - - // Call persist on each first, which 
may schedule some futures for reading. - for (State location : workItemState.values()) { + private void persist(List<Future<WorkItemCommitRequest>> commitsToMerge, StateTable stateTable) { + for (State location : stateTable.values()) { if (!(location instanceof WindmillState)) { throw new IllegalStateException( String.format( @@ -235,12 +300,20 @@ class WindmillStateInternals<K> implements StateInternals { // Clear any references to the underlying reader to prevent space leaks. // The next work unit to use these cached State objects will reset the // reader to a current reader in case those values are modified. - for (State location : workItemState.values()) { + for (State location : stateTable.values()) { ((WindmillState) location).cleanupAfterWorkItem(); } // Clear out the map of already retrieved state instances. - workItemState.clear(); + stateTable.clear(); + } + + public void persist(final Windmill.WorkItemCommitRequest.Builder commitBuilder) { + List<Future<WorkItemCommitRequest>> commitsToMerge = new ArrayList<>(); + + // Call persist on each first, which may schedule some futures for reading. + persist(commitsToMerge, workItemState); + persist(commitsToMerge, workItemDerivedState); try (Closeable scope = scopedReadStateSupplier.get()) { for (Future<WorkItemCommitRequest> commitFuture : commitsToMerge) { @@ -470,16 +543,305 @@ class WindmillStateInternals<K> implements StateInternals { } } - private static class WindmillOrderedList<T> extends SimpleWindmillState - implements OrderedListState<T> { + // Coder for closed-open ranges. + private static class RangeCoder<T extends Comparable> extends StructuredCoder<Range<T>> { + private Coder<T> boundCoder; + + RangeCoder(Coder<T> boundCoder) { + this.boundCoder = NullableCoder.of(boundCoder); + } + + @Override + public List<? 
extends Coder<?>> getCoderArguments() { + return Lists.newArrayList(boundCoder); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + boundCoder.verifyDeterministic(); + ; + } + + @Override + public void encode(Range<T> value, OutputStream outStream) throws CoderException, IOException { + Preconditions.checkState( + value.lowerBoundType().equals(BoundType.CLOSED), "unexpected range " + value); + Preconditions.checkState( + value.upperBoundType().equals(BoundType.OPEN), "unexpected range " + value); + boundCoder.encode(value.hasLowerBound() ? value.lowerEndpoint() : null, outStream); + boundCoder.encode(value.hasUpperBound() ? value.upperEndpoint() : null, outStream); + } + + @Override + public Range<T> decode(InputStream inStream) throws CoderException, IOException { + @Nullable T lower = boundCoder.decode(inStream); + @Nullable T upper = boundCoder.decode(inStream); + if (lower == null) { + return upper != null ? Range.lessThan(upper) : Range.all(); + } else if (upper == null) { + return Range.atLeast(lower); + } else { + return Range.closedOpen(lower, upper); + } + } + } + + private static class RangeSetCoder<T extends Comparable> extends CustomCoder<RangeSet<T>> { + private SetCoder<Range<T>> rangesCoder; + + RangeSetCoder(Coder<T> boundCoder) { + this.rangesCoder = SetCoder.of(new RangeCoder<>(boundCoder)); + } + + @Override + public void encode(RangeSet<T> value, OutputStream outStream) throws IOException { + rangesCoder.encode(value.asRanges(), outStream); + } + + @Override + public RangeSet<T> decode(InputStream inStream) throws CoderException, IOException { + return TreeRangeSet.create(rangesCoder.decode(inStream)); + } + } + + /** + * Tracker for the ids used in an ordered list. + * + * <p>Windmill accepts an int64 id for each timestamped-element in the list. Unique elements are + * identified by the pair of timestamp and id. 
This means that tow unique elements e1, e2 must + * have different (ts1, id1), (ts2, id2) pairs. To accomplish this we bucket time into five-minute + * buckets, and store a free list of ids available for each bucket. + * + * <p>When a timestamp range is deleted, we remove id tracking for elements in that range. In + * order to handle the case where a range is deleted piecemeal, we track sub-range deletions for + * each range. For example: + * + * <p>12:00 - 12:05 ids 12:05 - 12:10 ids + * + * <p>delete 12:00-12:06 + * + * <p>12:00 - 12:05 *removed* 12:05 - 12:10 ids subranges deleted 12:05-12:06 + * + * <p>delete 12:06 - 12:07 + * + * <p>12:05 - 12:10 ids subranges deleted 12:05-12:07 + * + * <p>delete 12:07 - 12:10 + * + * <p>12:05 - 12:10 *removed* + */ + static final class IdTracker { + static final String IDS_AVAILABLE_STR = "IdsAvailable"; + static final String DELETIONS_STR = "Deletions"; + + static final long MIN_ID = Long.MIN_VALUE; + static final long MAX_ID = Long.MAX_VALUE; + + // We track ids on five-minute boundaries. + private static final Duration RESOLUTION = Duration.standardMinutes(5); + static final MapCoder<Range<Instant>, RangeSet<Long>> IDS_AVAILABLE_CODER = + MapCoder.of(new RangeCoder<>(InstantCoder.of()), new RangeSetCoder<>(VarLongCoder.of())); + static final MapCoder<Range<Instant>, RangeSet<Instant>> SUBRANGE_DELETIONS_CODER = + MapCoder.of(new RangeCoder<>(InstantCoder.of()), new RangeSetCoder<>(InstantCoder.of())); + private final StateTag<ValueState<Map<Range<Instant>, RangeSet<Long>>>> idsAvailableTag; + // A map from five-minute ranges to the set of ids available in that interval. + final ValueState<Map<Range<Instant>, RangeSet<Long>>> idsAvailableValue; + private final StateTag<ValueState<Map<Range<Instant>, RangeSet<Instant>>>> subRangeDeletionsTag; + // If a timestamp-range in the map has been partially cleared, the cleared intervals are stored + // here. 
+ final ValueState<Map<Range<Instant>, RangeSet<Instant>>> subRangeDeletionsValue; + + IdTracker( + StateTable stateTable, + StateNamespace namespace, + StateTag<?> spec, + String stateFamily, + boolean complete) { + this.idsAvailableTag = + StateTags.makeSystemTagInternal( + StateTags.value(spec.getId() + IDS_AVAILABLE_STR, IDS_AVAILABLE_CODER)); + this.idsAvailableValue = + stateTable.get(namespace, idsAvailableTag, StateContexts.nullContext()); + this.subRangeDeletionsTag = + StateTags.makeSystemTagInternal( + StateTags.value(spec.getId() + DELETIONS_STR, SUBRANGE_DELETIONS_CODER)); + this.subRangeDeletionsValue = + stateTable.get(namespace, subRangeDeletionsTag, StateContexts.nullContext()); + } + + static <ValueT extends Comparable<? super ValueT>> + Map<Range<Instant>, RangeSet<ValueT>> newSortedRangeMap(Class<ValueT> valueClass) { + return Maps.newTreeMap( + Comparator.<Range<Instant>, Instant>comparing(Range::lowerEndpoint) + .thenComparing(Range::upperEndpoint)); + } + + private Range<Instant> getTrackedRange(Instant ts) { + Instant snapped = + new Instant(ts.getMillis() - ts.plus(RESOLUTION).getMillis() % RESOLUTION.getMillis()); + return Range.closedOpen(snapped, snapped.plus(RESOLUTION)); + } + + @SuppressWarnings("FutureReturnValueIgnored") + void readLater() { + idsAvailableValue.readLater(); + subRangeDeletionsValue.readLater(); + } + + Map<Range<Instant>, RangeSet<Long>> readIdsAvailable() { + Map<Range<Instant>, RangeSet<Long>> idsAvailable = idsAvailableValue.read(); + return idsAvailable != null ? idsAvailable : newSortedRangeMap(Long.class); + } + + Map<Range<Instant>, RangeSet<Instant>> readSubRangeDeletions() { + Map<Range<Instant>, RangeSet<Instant>> subRangeDeletions = subRangeDeletionsValue.read(); + return subRangeDeletions != null ? 
subRangeDeletions : newSortedRangeMap(Instant.class); + } + + void clear() throws ExecutionException, InterruptedException { + idsAvailableValue.clear(); + subRangeDeletionsValue.clear(); + } + + <T> void add( + SortedSet<TimestampedValueWithId<T>> elements, BiConsumer<TimestampedValue<T>, Long> output) + throws ExecutionException, InterruptedException { + Range<Long> currentIdRange = null; + long currentId = 0; + + Range<Instant> currentTsRange = null; + RangeSet<Instant> currentTsRangeDeletions = null; + + Map<Range<Instant>, RangeSet<Long>> idsAvailable = readIdsAvailable(); + Map<Range<Instant>, RangeSet<Instant>> subRangeDeletions = readSubRangeDeletions(); + + RangeSet<Long> availableIdsForTsRange = null; + Iterator<Range<Long>> idRangeIter = null; + RangeSet<Long> idsUsed = TreeRangeSet.create(); + for (TimestampedValueWithId<T> pendingAdd : elements) { + // Since elements are in increasing ts order, often we'll be able to reuse the previous + // iteration's range. + if (currentTsRange == null + || !currentTsRange.contains(pendingAdd.getValue().getTimestamp())) { + if (availableIdsForTsRange != null) { + // We're moving onto a new ts range. Remove all used ids + availableIdsForTsRange.removeAll(idsUsed); + idsUsed = TreeRangeSet.create(); + } + + // Lookup the range for the current timestamp. + currentTsRange = getTrackedRange(pendingAdd.getValue().getTimestamp()); + // Lookup available ids for this timestamp range. If nothing there, we default to all ids + // available. + availableIdsForTsRange = + idsAvailable.computeIfAbsent( + currentTsRange, + r -> TreeRangeSet.create(ImmutableList.of(Range.closedOpen(MIN_ID, MAX_ID)))); + idRangeIter = availableIdsForTsRange.asRanges().iterator(); + currentIdRange = null; + currentTsRangeDeletions = subRangeDeletions.get(currentTsRange); + } + + if (currentIdRange == null || currentId >= currentIdRange.upperEndpoint()) { + // Move to the next range of free ids, and start assigning ranges from there. 
+ currentIdRange = idRangeIter.next(); + currentId = currentIdRange.lowerEndpoint(); + } + + if (currentTsRangeDeletions != null) { + currentTsRangeDeletions.remove( + Range.closedOpen( + pendingAdd.getValue().getTimestamp(), + pendingAdd.getValue().getTimestamp().plus(1))); + } + idsUsed.add(Range.closedOpen(currentId, currentId + 1)); + output.accept(pendingAdd.getValue(), currentId++); + } + if (availableIdsForTsRange != null) { + availableIdsForTsRange.removeAll(idsUsed); + } + idsAvailableValue.write(idsAvailable); + subRangeDeletionsValue.write(subRangeDeletions); + } + + // Remove a timestamp range. Returns ids freed up. + void remove(Range<Instant> tsRange) throws ExecutionException, InterruptedException { + Map<Range<Instant>, RangeSet<Long>> idsAvailable = readIdsAvailable(); + Map<Range<Instant>, RangeSet<Instant>> subRangeDeletions = readSubRangeDeletions(); + + for (Range<Instant> current = getTrackedRange(tsRange.lowerEndpoint()); + current.lowerEndpoint().isBefore(tsRange.upperEndpoint()); + current = getTrackedRange(current.lowerEndpoint().plus(RESOLUTION))) { + // TODO(reuvenlax): shouldn't need to iterate over all ranges. + boolean rangeCleared; + if (!tsRange.encloses(current)) { + // This can happen if the beginning or the end of tsRange doesn't fall on a RESOLUTION + // boundary. Since we + // are deleting a portion of a tracked range, track what we are deleting. + RangeSet<Instant> rangeDeletions = + subRangeDeletions.computeIfAbsent(current, r -> TreeRangeSet.create()); + rangeDeletions.add(tsRange.intersection(current)); + // If we ended up deleting the whole range, than we can simply remove it from the tracking + // map. + rangeCleared = rangeDeletions.encloses(current); + } else { + rangeCleared = true; + } + if (rangeCleared) { + // Remove the range from both maps. 
+ idsAvailable.remove(current); + subRangeDeletions.remove(current); + } + } + idsAvailableValue.write(idsAvailable); + subRangeDeletionsValue.write(subRangeDeletions); + } + } + + @AutoValue + abstract static class TimestampedValueWithId<T> { + private static final Comparator<TimestampedValueWithId<?>> COMPARATOR = + Comparator.<TimestampedValueWithId<?>, Instant>comparing(v -> v.getValue().getTimestamp()) + .thenComparingLong(TimestampedValueWithId::getId); + + abstract TimestampedValue<T> getValue(); + + abstract long getId(); + + static <T> TimestampedValueWithId<T> of(TimestampedValue<T> value, long id) { + return new AutoValue_WindmillStateInternals_TimestampedValueWithId<>(value, id); + } + static <T> TimestampedValueWithId<T> bound(Instant ts) { + return of(TimestampedValue.of(null, ts), Long.MIN_VALUE); + } + } + + static class WindmillOrderedList<T> extends SimpleWindmillState implements OrderedListState<T> { private final StateNamespace namespace; private final StateTag<OrderedListState<T>> spec; private final ByteString stateKey; private final String stateFamily; private final Coder<T> elemCoder; + private boolean complete; + private boolean cleared = false; + // We need to sort based on timestamp, but we need objects with the same timestamp to be treated + // as unique. We can't use a MultiSet as we can't construct a comparator that uniquely + // identifies objects, + // so we construct a unique in-memory long ids for each element. + private SortedSet<TimestampedValueWithId<T>> pendingAdds = + Sets.newTreeSet(TimestampedValueWithId.COMPARATOR); + + private RangeSet<Instant> pendingDeletes = TreeRangeSet.create(); + private IdTracker idTracker; + + // The default proto values for SortedListRange correspond to the minimum and maximum + // timestamps. 
+ static final long MIN_TS_MICROS = SortedListRange.getDefaultInstance().getStart(); + static final long MAX_TS_MICROS = SortedListRange.getDefaultInstance().getLimit(); private WindmillOrderedList( + StateTable derivedStateTable, StateNamespace namespace, StateTag<OrderedListState<T>> spec, String stateFamily, @@ -487,64 +849,226 @@ class WindmillStateInternals<K> implements StateInternals { boolean isNewKey) { this.namespace = namespace; this.spec = spec; + this.stateKey = encodeKey(namespace, spec); this.stateFamily = stateFamily; this.elemCoder = elemCoder; + this.complete = isNewKey; + this.idTracker = new IdTracker(derivedStateTable, namespace, spec, stateFamily, complete); } @Override public Iterable<TimestampedValue<T>> read() { - throw new UnsupportedOperationException( - String.format("%s is not supported", OrderedListState.class.getSimpleName())); + return readRange(null, null); + } + + private SortedSet<TimestampedValueWithId<T>> getPendingAddRange( + @Nullable Instant minTimestamp, @Nullable Instant limitTimestamp) { + SortedSet<TimestampedValueWithId<T>> pendingInRange = pendingAdds; + if (minTimestamp != null && limitTimestamp != null) { + pendingInRange = + pendingInRange.subSet( + TimestampedValueWithId.bound(minTimestamp), + TimestampedValueWithId.bound(limitTimestamp)); + } else if (minTimestamp == null && limitTimestamp != null) { + pendingInRange = pendingInRange.headSet(TimestampedValueWithId.bound(limitTimestamp)); + } else if (limitTimestamp == null && minTimestamp != null) { + pendingInRange = pendingInRange.tailSet(TimestampedValueWithId.bound(minTimestamp)); + } + return pendingInRange; } @Override - public Iterable<TimestampedValue<T>> readRange(Instant minTimestamp, Instant limitTimestamp) { - throw new UnsupportedOperationException( - String.format("%s is not supported", OrderedListState.class.getSimpleName())); + public Iterable<TimestampedValue<T>> readRange( + @Nullable Instant minTimestamp, @Nullable Instant limitTimestamp) { + 
idTracker.readLater(); + + final Future<Iterable<TimestampedValue<T>>> future = getFuture(minTimestamp, limitTimestamp); + try (Closeable scope = scopedReadState()) { + SortedSet<TimestampedValueWithId<T>> pendingInRange = + getPendingAddRange(minTimestamp, limitTimestamp); + + // Transform the return iterator so it has the same type as pendingAdds. We need to ensure + // that the ids don't overlap with any in pendingAdds, so begin with pendingAdds.size(). + Iterable<TimestampedValueWithId<T>> data = + new Iterable<TimestampedValueWithId<T>>() { + private Iterable<TimestampedValue<T>> iterable = future.get(); + + @Override + public Iterator<TimestampedValueWithId<T>> iterator() { + return new Iterator<TimestampedValueWithId<T>>() { + private Iterator<TimestampedValue<T>> iter = iterable.iterator(); + private long currentId = pendingAdds.size(); + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public TimestampedValueWithId<T> next() { + return TimestampedValueWithId.of(iter.next(), currentId++); + } + }; + } + }; + + Iterable<TimestampedValueWithId<T>> includingAdds = + Iterables.mergeSorted( + ImmutableList.of(data, pendingInRange), TimestampedValueWithId.COMPARATOR); + Iterable<TimestampedValue<T>> fullIterable = + Iterables.filter( + Iterables.transform(includingAdds, TimestampedValueWithId::getValue), + tv -> !pendingDeletes.contains(tv.getTimestamp())); + // TODO(reuvenlax): If we have a known bounded amount of data, cache known ranges. 
+ return fullIterable; + } catch (InterruptedException | ExecutionException | IOException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throw new RuntimeException("Unable to read state", e); + } } @Override public void clear() { - throw new UnsupportedOperationException( - String.format("%s is not supported", OrderedListState.class.getSimpleName())); + cleared = true; + complete = true; + pendingAdds.clear(); + pendingDeletes.clear(); + try { + idTracker.clear(); + } catch (ExecutionException | InterruptedException e) { + throw new RuntimeException(e); + } } @Override public void clearRange(Instant minTimestamp, Instant limitTimestamp) { - throw new UnsupportedOperationException( - String.format("%s is not supported", OrderedListState.class.getSimpleName())); + getPendingAddRange(minTimestamp, limitTimestamp).clear(); + pendingDeletes.add(Range.closedOpen(minTimestamp, limitTimestamp)); } @Override public void add(TimestampedValue<T> value) { - throw new UnsupportedOperationException( - String.format("%s is not supported", OrderedListState.class.getSimpleName())); + // We use the current size of the container as the in-memory id. This works because + // pendingAdds is completely + // cleared when it is processed (otherwise we could end up with duplicate elements in the same + // container). These + // are not the ids that will be sent to windmill. + pendingAdds.add(TimestampedValueWithId.of(value, pendingAdds.size())); + // Leave pendingDeletes alone. Since we can have multiple values with the same timestamp, we + // may still need + // overlapping deletes to remove previous entries at this timestamp. 
} @Override public ReadableState<Boolean> isEmpty() { - throw new UnsupportedOperationException( - String.format("%s is not supported", OrderedListState.class.getSimpleName())); + return new ReadableState<Boolean>() { + @Override + public ReadableState<Boolean> readLater() { + WindmillOrderedList.this.readLater(); + return this; + } + + @Override + public Boolean read() { + return Iterables.isEmpty(WindmillOrderedList.this.read()); + } + }; } @Override public OrderedListState<T> readLater() { - throw new UnsupportedOperationException( - String.format("%s is not supported", OrderedListState.class.getSimpleName())); + return readRangeLater(null, null); } @Override - public OrderedListState<T> readRangeLater(Instant minTimestamp, Instant limitTimestamp) { - throw new UnsupportedOperationException( - String.format("%s is not supported", OrderedListState.class.getSimpleName())); + @SuppressWarnings("FutureReturnValueIgnored") + public OrderedListState<T> readRangeLater( + @Nullable Instant minTimestamp, @Nullable Instant limitTimestamp) { + idTracker.readLater(); + getFuture(minTimestamp, limitTimestamp); + return this; } @Override public WorkItemCommitRequest persistDirectly(ForKey cache) throws IOException { WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder(); + TagSortedListUpdateRequest.Builder updatesBuilder = + commitBuilder.addSortedListUpdatesBuilder().setStateFamily(stateFamily).setTag(stateKey); + try { + if (cleared) { + // Default range. + updatesBuilder.addDeletesBuilder().build(); + cleared = false; + } + + if (!pendingAdds.isEmpty()) { + // TODO(reuvenlax): Once we start caching data, we should remove this line. We have it + // here now + // because once we persist + // added data we forget about it from the cache, so the object is no longer complete. 
+ complete = false; + + TagSortedListInsertRequest.Builder insertBuilder = updatesBuilder.addInsertsBuilder(); + idTracker.add( + pendingAdds, + (elem, id) -> { + try { + ByteString.Output elementStream = ByteString.newOutput(); + elemCoder.encode(elem.getValue(), elementStream, Context.OUTER); + insertBuilder.addEntries( + SortedListEntry.newBuilder() + .setValue(elementStream.toByteString()) + .setSortKey( + WindmillTimeUtils.harnessToWindmillTimestamp(elem.getTimestamp())) + .setId(id)); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + pendingAdds.clear(); + insertBuilder.build(); + } + + if (!pendingDeletes.isEmpty()) { + for (Range<Instant> range : pendingDeletes.asRanges()) { + TagSortedListDeleteRequest.Builder deletesBuilder = updatesBuilder.addDeletesBuilder(); + deletesBuilder.setRange( + SortedListRange.newBuilder() + .setStart(WindmillTimeUtils.harnessToWindmillTimestamp(range.lowerEndpoint())) + .setLimit(WindmillTimeUtils.harnessToWindmillTimestamp(range.upperEndpoint()))); + deletesBuilder.build(); + idTracker.remove(range); + } + pendingDeletes.clear(); + } + } catch (ExecutionException | InterruptedException e) { + throw new RuntimeException(e); + } return commitBuilder.buildPartial(); } + + private Future<Iterable<TimestampedValue<T>>> getFuture( + @Nullable Instant minTimestamp, @Nullable Instant limitTimestamp) { + long startSortKey = + minTimestamp != null + ? WindmillTimeUtils.harnessToWindmillTimestamp(minTimestamp) + : MIN_TS_MICROS; + long limitSortKey = + limitTimestamp != null + ? WindmillTimeUtils.harnessToWindmillTimestamp(limitTimestamp) + : MAX_TS_MICROS; + + if (complete) { + // Right now we don't cache any data, so complete means an empty list. + // TODO(reuvenlax): change this once we start caching data. 
+ return Futures.immediateFuture(Collections.emptyList()); + } + return reader.orderedListFuture( + Range.closedOpen(startSortKey, limitSortKey), stateKey, stateFamily, elemCoder); + } } private static class WindmillBag<T> extends SimpleWindmillState implements BagState<T> { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java index 3c131c6..10ecc6f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java @@ -17,11 +17,15 @@ */ package org.apache.beam.runners.dataflow.worker; +import com.google.api.client.util.Lists; +import com.google.auto.value.AutoValue; +import com.google.common.collect.Iterables; import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; import java.io.IOException; import java.io.InputStream; import java.util.ArrayList; import java.util.Collections; +import java.util.Comparator; import java.util.HashSet; import java.util.Iterator; import java.util.List; @@ -33,23 +37,28 @@ import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import javax.annotation.Nonnull; +import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.SortedListEntry; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.SortedListRange; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagBag; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagSortedListFetchRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagValue; import 
org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.Weighted; +import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Function; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Objects; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.AbstractIterator; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ForwardingList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Range; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.ForwardingFuture; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Futures; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.SettableFuture; -import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Instant; /** @@ -68,6 +77,12 @@ class WindmillStateReader { public static final long MAX_BAG_BYTES = 8L << 20; // 8MB /** + * Ideal maximum bytes in a TagSortedList response. However, Windmill will always return at least + * one value if possible irrespective of this limit. + */ + public static final long MAX_ORDERED_LIST_BYTES = 8L << 20; // 8MB + + /** * Ideal maximum bytes in a KeyedGetDataResponse. However, Windmill will always return at least * one value if possible irrespective of this limit. */ @@ -77,70 +92,66 @@ class WindmillStateReader { * When combined with a key and computationId, represents the unique address for state managed by * Windmill. 
*/ - private static class StateTag { - private enum Kind { + @AutoValue + abstract static class StateTag<RequestPositionT> { + enum Kind { VALUE, BAG, - WATERMARK; + WATERMARK, + ORDERED_LIST } - private final Kind kind; - private final ByteString tag; - private final String stateFamily; + abstract Kind getKind(); + + abstract ByteString getTag(); + + abstract String getStateFamily(); /** - * For {@link Kind#BAG} kinds: A previous 'continuation_position' returned by Windmill to signal - * the resulting bag was incomplete. Sending that position will request the next page of values. - * Null for first request. + * For {@link Kind#BAG, Kind#ORDERED_LIST} kinds: A previous 'continuation_position' returned by + * Windmill to signal the resulting bag was incomplete. Sending that position will request the + * next page of values. Null for first request. * * <p>Null for other kinds. */ - private final @Nullable Long requestPosition; + @Nullable + abstract RequestPositionT getRequestPosition(); - private StateTag( - Kind kind, ByteString tag, String stateFamily, @Nullable Long requestPosition) { - this.kind = kind; - this.tag = tag; - this.stateFamily = Preconditions.checkNotNull(stateFamily); - this.requestPosition = requestPosition; + /** For {@link Kind#ORDERED_LIST} kinds: the range to fetch or delete. 
*/ + @Nullable + abstract Range<Long> getSortedListRange(); + + static <RequestPositionT> StateTag<RequestPositionT> of( + Kind kind, ByteString tag, String stateFamily, @Nullable RequestPositionT requestPosition) { + return new AutoValue_WindmillStateReader_StateTag.Builder<RequestPositionT>() + .setKind(kind) + .setTag(tag) + .setStateFamily(stateFamily) + .setRequestPosition(requestPosition) + .build(); } - private StateTag(Kind kind, ByteString tag, String stateFamily) { - this(kind, tag, stateFamily, null); + static <RequestPositionT> StateTag<RequestPositionT> of( + Kind kind, ByteString tag, String stateFamily) { + return of(kind, tag, stateFamily, null); } - @Override - public boolean equals(@Nullable Object obj) { - if (this == obj) { - return true; - } + abstract Builder<RequestPositionT> toBuilder(); - if (!(obj instanceof StateTag)) { - return false; - } + @AutoValue.Builder + abstract static class Builder<RequestPositionT> { + abstract Builder<RequestPositionT> setKind(Kind kind); - StateTag that = (StateTag) obj; - return Objects.equal(this.kind, that.kind) - && Objects.equal(this.tag, that.tag) - && Objects.equal(this.stateFamily, that.stateFamily) - && Objects.equal(this.requestPosition, that.requestPosition); - } + abstract Builder<RequestPositionT> setTag(ByteString tag); - @Override - public int hashCode() { - return Objects.hashCode(kind, tag, stateFamily, requestPosition); - } + abstract Builder<RequestPositionT> setStateFamily(String stateFamily); - @Override - public String toString() { - return "Tag(" - + kind - + "," - + tag.toStringUtf8() - + "," - + stateFamily - + (requestPosition == null ? 
"" : ("," + requestPosition.toString())) - + ")"; + abstract Builder<RequestPositionT> setRequestPosition( + @Nullable RequestPositionT requestPosition); + + abstract Builder<RequestPositionT> setSortedListRange(@Nullable Range<Long> sortedListRange); + + abstract StateTag<RequestPositionT> build(); } } @@ -148,13 +159,13 @@ class WindmillStateReader { * An in-memory collection of deserialized values and an optional continuation position to pass to * Windmill when fetching the next page of values. */ - private static class ValuesAndContPosition<T> { + private static class ValuesAndContPosition<T, ContinuationT> { private final List<T> values; /** Position to pass to next request for next page of values. Null if done. */ - private final @Nullable Long continuationPosition; + private final @Nullable ContinuationT continuationPosition; - public ValuesAndContPosition(List<T> values, @Nullable Long continuationPosition) { + public ValuesAndContPosition(List<T> values, @Nullable ContinuationT continuationPosition) { this.values = values; this.continuationPosition = continuationPosition; } @@ -218,13 +229,15 @@ class WindmillStateReader { } } - @VisibleForTesting ConcurrentLinkedQueue<StateTag> pendingLookups = new ConcurrentLinkedQueue<>(); - private ConcurrentHashMap<StateTag, CoderAndFuture<?, ?>> waiting = new ConcurrentHashMap<>(); + @VisibleForTesting + ConcurrentLinkedQueue<StateTag<?>> pendingLookups = new ConcurrentLinkedQueue<>(); + + private ConcurrentHashMap<StateTag<?>, CoderAndFuture<?, ?>> waiting = new ConcurrentHashMap<>(); private <ElemT, FutureT> Future<FutureT> stateFuture( - StateTag stateTag, @Nullable Coder<ElemT> coder) { + StateTag<?> stateTag, @Nullable Coder<ElemT> coder) { CoderAndFuture<ElemT, FutureT> coderAndFuture = - new CoderAndFuture<>(coder, SettableFuture.<FutureT>create()); + new CoderAndFuture<>(coder, SettableFuture.create()); CoderAndFuture<?, ?> existingCoderAndFutureWildcard = waiting.putIfAbsent(stateTag, coderAndFuture); if 
(existingCoderAndFutureWildcard == null) { @@ -242,7 +255,7 @@ class WindmillStateReader { } private <ElemT, FutureT> CoderAndFuture<ElemT, FutureT> getWaiting( - StateTag stateTag, boolean shouldRemove) { + StateTag<?> stateTag, boolean shouldRemove) { CoderAndFuture<?, ?> coderAndFutureWildcard; if (shouldRemove) { coderAndFutureWildcard = waiting.remove(stateTag); @@ -259,29 +272,41 @@ class WindmillStateReader { } public Future<Instant> watermarkFuture(ByteString encodedTag, String stateFamily) { - return stateFuture(new StateTag(StateTag.Kind.WATERMARK, encodedTag, stateFamily), null); + return stateFuture(StateTag.of(StateTag.Kind.WATERMARK, encodedTag, stateFamily), null); } public <T> Future<T> valueFuture(ByteString encodedTag, String stateFamily, Coder<T> coder) { - return stateFuture(new StateTag(StateTag.Kind.VALUE, encodedTag, stateFamily), coder); + return stateFuture(StateTag.of(StateTag.Kind.VALUE, encodedTag, stateFamily), coder); } public <T> Future<Iterable<T>> bagFuture( ByteString encodedTag, String stateFamily, Coder<T> elemCoder) { // First request has no continuation position. - StateTag stateTag = new StateTag(StateTag.Kind.BAG, encodedTag, stateFamily); + StateTag<Long> stateTag = StateTag.of(StateTag.Kind.BAG, encodedTag, stateFamily); // Convert the ValuesAndContPosition<T> to Iterable<T>. - return valuesToPagingIterableFuture( - stateTag, elemCoder, this.<T, ValuesAndContPosition<T>>stateFuture(stateTag, elemCoder)); + return valuesToPagingIterableFuture(stateTag, elemCoder, this.stateFuture(stateTag, elemCoder)); + } + + public <T> Future<Iterable<TimestampedValue<T>>> orderedListFuture( + Range<Long> range, ByteString encodedTag, String stateFamily, Coder<T> elemCoder) { + // First request has no continuation position. 
+ StateTag<ByteString> stateTag = + StateTag.<ByteString>of(StateTag.Kind.ORDERED_LIST, encodedTag, stateFamily) + .toBuilder() + .setSortedListRange(Preconditions.checkNotNull(range)) + .build(); + return Preconditions.checkNotNull( + valuesToPagingIterableFuture(stateTag, elemCoder, this.stateFuture(stateTag, elemCoder))); } /** - * Internal request to fetch the next 'page' of values in a TagBag. Return null if no continuation - * position is in {@code contStateTag}, which signals there are no more pages. + * Internal request to fetch the next 'page' of values. Return null if no continuation position is + * in {@code contStateTag}, which signals there are no more pages. */ - private @Nullable <T> Future<ValuesAndContPosition<T>> continuationBagFuture( - StateTag contStateTag, Coder<T> elemCoder) { - if (contStateTag.requestPosition == null) { + private @Nullable <ElemT, ContinuationT, ResultT> + Future<ValuesAndContPosition<ResultT, ContinuationT>> continuationFuture( + StateTag<ContinuationT> contStateTag, Coder<ElemT> elemCoder) { + if (contStateTag.getRequestPosition() == null) { // We're done. return null; } @@ -338,18 +363,19 @@ class WindmillStateReader { } /** Function to extract an {@link Iterable} from the continuation-supporting page read future. */ - private static class ToIterableFunction<T> - implements Function<ValuesAndContPosition<T>, Iterable<T>> { + private static class ToIterableFunction<ElemT, ContinuationT, ResultT> + implements Function<ValuesAndContPosition<ResultT, ContinuationT>, Iterable<ResultT>> { /** * Reader to request continuation pages from, or {@literal null} if no continuation pages * required. 
*/ private @Nullable WindmillStateReader reader; - private final StateTag stateTag; - private final Coder<T> elemCoder; + private final StateTag<ContinuationT> stateTag; + private final Coder<ElemT> elemCoder; - public ToIterableFunction(WindmillStateReader reader, StateTag stateTag, Coder<T> elemCoder) { + public ToIterableFunction( + WindmillStateReader reader, StateTag<ContinuationT> stateTag, Coder<ElemT> elemCoder) { this.reader = reader; this.stateTag = stateTag; this.elemCoder = elemCoder; @@ -359,7 +385,8 @@ class WindmillStateReader { value = "NP_METHOD_PARAMETER_TIGHTENS_ANNOTATION", justification = "https://github.com/google/guava/issues/920") @Override - public Iterable<T> apply(@Nonnull ValuesAndContPosition<T> valuesAndContPosition) { + public Iterable<ResultT> apply( + @Nonnull ValuesAndContPosition<ResultT, ContinuationT> valuesAndContPosition) { if (valuesAndContPosition.continuationPosition == null) { // Number of values is small enough Windmill sent us the entire bag in one response. reader = null; @@ -367,12 +394,16 @@ class WindmillStateReader { } else { // Return an iterable which knows how to come back for more. StateTag contStateTag = - new StateTag( - stateTag.kind, - stateTag.tag, - stateTag.stateFamily, + StateTag.of( + stateTag.getKind(), + stateTag.getTag(), + stateTag.getStateFamily(), valuesAndContPosition.continuationPosition); - return new BagPagingIterable<>( + if (stateTag.getSortedListRange() != null) { + contStateTag = + contStateTag.toBuilder().setSortedListRange(stateTag.getSortedListRange()).build(); + } + return new PagingIterable<ElemT, ContinuationT, ResultT>( reader, valuesAndContPosition.values, contStateTag, elemCoder); } } @@ -382,18 +413,20 @@ class WindmillStateReader { * Return future which transforms a {@code ValuesAndContPosition<T>} result into the initial * Iterable<T> result expected from the external caller. 
*/ - private <T> Future<Iterable<T>> valuesToPagingIterableFuture( - final StateTag stateTag, - final Coder<T> elemCoder, - final Future<ValuesAndContPosition<T>> future) { - return Futures.lazyTransform(future, new ToIterableFunction<T>(this, stateTag, elemCoder)); + private <ElemT, ResultT, ContinuationT> Future<Iterable<ResultT>> valuesToPagingIterableFuture( + final StateTag<ContinuationT> stateTag, + final Coder<ElemT> elemCoder, + final Future<ValuesAndContPosition<ResultT, ContinuationT>> future) { + Function<ValuesAndContPosition<ResultT, ContinuationT>, Iterable<ResultT>> toIterable = + new ToIterableFunction<>(this, stateTag, elemCoder); + return Futures.lazyTransform(future, toIterable); } public void startBatchAndBlock() { // First, drain work out of the pending lookups into a set. These will be the items we fetch. - HashSet<StateTag> toFetch = new HashSet<>(); + HashSet<StateTag<?>> toFetch = Sets.newHashSet(); while (!pendingLookups.isEmpty()) { - StateTag stateTag = pendingLookups.poll(); + StateTag<?> stateTag = pendingLookups.poll(); if (stateTag == null) { break; } @@ -411,7 +444,6 @@ class WindmillStateReader { Windmill.KeyedGetDataRequest request = createRequest(toFetch); Windmill.KeyedGetDataResponse response = server.getStateData(computation, request); - if (response == null) { throw new RuntimeException("Windmill unexpectedly returned null for request " + request); } @@ -423,47 +455,72 @@ class WindmillStateReader { return bytesRead; } - private Windmill.KeyedGetDataRequest createRequest(Iterable<StateTag> toFetch) { + private Windmill.KeyedGetDataRequest createRequest(Iterable<StateTag<?>> toFetch) { Windmill.KeyedGetDataRequest.Builder keyedDataBuilder = Windmill.KeyedGetDataRequest.newBuilder() .setKey(key) .setShardingKey(shardingKey) .setWorkToken(workToken); - for (StateTag stateTag : toFetch) { - switch (stateTag.kind) { + List<StateTag<?>> orderedListsToFetch = Lists.newArrayList(); + for (StateTag<?> stateTag : toFetch) { + switch 
(stateTag.getKind()) { case BAG: TagBag.Builder bag = keyedDataBuilder .addBagsToFetchBuilder() - .setTag(stateTag.tag) - .setStateFamily(stateTag.stateFamily) + .setTag(stateTag.getTag()) + .setStateFamily(stateTag.getStateFamily()) .setFetchMaxBytes(MAX_BAG_BYTES); - if (stateTag.requestPosition != null) { + if (stateTag.getRequestPosition() != null) { // We're asking for the next page. - bag.setRequestPosition(stateTag.requestPosition); + bag.setRequestPosition((Long) stateTag.getRequestPosition()); } break; + case ORDERED_LIST: + orderedListsToFetch.add(stateTag); + break; + case WATERMARK: keyedDataBuilder .addWatermarkHoldsToFetchBuilder() - .setTag(stateTag.tag) - .setStateFamily(stateTag.stateFamily); + .setTag(stateTag.getTag()) + .setStateFamily(stateTag.getStateFamily()); break; case VALUE: keyedDataBuilder .addValuesToFetchBuilder() - .setTag(stateTag.tag) - .setStateFamily(stateTag.stateFamily); + .setTag(stateTag.getTag()) + .setStateFamily(stateTag.getStateFamily()); break; default: - throw new RuntimeException("Unknown kind of tag requested: " + stateTag.kind); + throw new RuntimeException("Unknown kind of tag requested: " + stateTag.getKind()); + } + } + orderedListsToFetch.sort( + Comparator.<StateTag<?>>comparingLong(s -> s.getSortedListRange().lowerEndpoint()) + .thenComparingLong(s -> s.getSortedListRange().upperEndpoint())); + for (StateTag<?> stateTag : orderedListsToFetch) { + Range<Long> range = Preconditions.checkNotNull(stateTag.getSortedListRange()); + TagSortedListFetchRequest.Builder sorted_list = + keyedDataBuilder + .addSortedListsToFetchBuilder() + .setTag(stateTag.getTag()) + .setStateFamily(stateTag.getStateFamily()) + .setFetchMaxBytes(MAX_ORDERED_LIST_BYTES); + sorted_list.addFetchRanges( + SortedListRange.newBuilder() + .setStart(range.lowerEndpoint()) + .setLimit(range.upperEndpoint()) + .build()); + if (stateTag.getRequestPosition() != null) { + // We're asking for the next page. 
+ sorted_list.setRequestPosition((ByteString) stateTag.getRequestPosition()); } } - keyedDataBuilder.setMaxBytes(MAX_KEY_BYTES); return keyedDataBuilder.build(); @@ -472,14 +529,14 @@ class WindmillStateReader { private void consumeResponse( Windmill.KeyedGetDataRequest request, Windmill.KeyedGetDataResponse response, - Set<StateTag> toFetch) { + Set<StateTag<?>> toFetch) { bytesRead += response.getSerializedSize(); if (response.getFailed()) { // Set up all the futures for this key to throw an exception: KeyTokenInvalidException keyTokenInvalidException = new KeyTokenInvalidException(key.toStringUtf8()); - for (StateTag stateTag : toFetch) { + for (StateTag<?> stateTag : toFetch) { waiting.get(stateTag).future.setException(keyTokenInvalidException); } return; @@ -490,8 +547,8 @@ class WindmillStateReader { } for (Windmill.TagBag bag : response.getBagsList()) { - StateTag stateTag = - new StateTag( + StateTag<Long> stateTag = + StateTag.of( StateTag.Kind.BAG, bag.getTag(), bag.getStateFamily(), @@ -504,8 +561,8 @@ class WindmillStateReader { } for (Windmill.WatermarkHold hold : response.getWatermarkHoldsList()) { - StateTag stateTag = - new StateTag(StateTag.Kind.WATERMARK, hold.getTag(), hold.getStateFamily()); + StateTag<Long> stateTag = + StateTag.of(StateTag.Kind.WATERMARK, hold.getTag(), hold.getStateFamily()); if (!toFetch.remove(stateTag)) { throw new IllegalStateException( "Received response for unrequested tag " + stateTag + ". Pending tags: " + toFetch); @@ -514,13 +571,33 @@ class WindmillStateReader { } for (Windmill.TagValue value : response.getValuesList()) { - StateTag stateTag = new StateTag(StateTag.Kind.VALUE, value.getTag(), value.getStateFamily()); + StateTag<Long> stateTag = + StateTag.of(StateTag.Kind.VALUE, value.getTag(), value.getStateFamily()); if (!toFetch.remove(stateTag)) { throw new IllegalStateException( "Received response for unrequested tag " + stateTag + ". 
Pending tags: " + toFetch); } consumeTagValue(value, stateTag); } + for (Windmill.TagSortedListFetchResponse sorted_list : response.getTagSortedListsList()) { + SortedListRange sortedListRange = Iterables.getOnlyElement(sorted_list.getFetchRangesList()); + Range<Long> range = Range.closedOpen(sortedListRange.getStart(), sortedListRange.getLimit()); + StateTag<ByteString> stateTag = + StateTag.of( + StateTag.Kind.ORDERED_LIST, + sorted_list.getTag(), + sorted_list.getStateFamily(), + sorted_list.hasRequestPosition() ? sorted_list.getRequestPosition() : null) + .toBuilder() + .setSortedListRange(range) + .build(); + if (!toFetch.remove(stateTag)) { + throw new IllegalStateException( + "Received response for unrequested tag " + stateTag + ". Pending tags: " + toFetch); + } + + consumeSortedList(sorted_list, stateTag); + } if (!toFetch.isEmpty()) { throw new IllegalStateException( @@ -577,9 +654,31 @@ class WindmillStateReader { return valueList; } - private <T> void consumeBag(TagBag bag, StateTag stateTag) { + private <T> List<TimestampedValue<T>> sortedListPageValues( + Windmill.TagSortedListFetchResponse sortedListFetchResponse, Coder<T> elemCoder) { + if (sortedListFetchResponse.getEntriesCount() == 0) { + return new WeightedList<>(Collections.emptyList()); + } + + WeightedList<TimestampedValue<T>> entryList = + new WeightedList<>(new ArrayList<>(sortedListFetchResponse.getEntriesCount())); + for (SortedListEntry entry : sortedListFetchResponse.getEntriesList()) { + try { + T value = elemCoder.decode(entry.getValue().newInput(), Coder.Context.OUTER); + entryList.addWeighted( + TimestampedValue.of( + value, WindmillTimeUtils.windmillToHarnessTimestamp(entry.getSortKey())), + entry.getValue().size() + 8); + } catch (IOException e) { + throw new IllegalStateException("Unable to decode tag sorted list using " + elemCoder, e); + } + } + return entryList; + } + + private <T> void consumeBag(TagBag bag, StateTag<Long> stateTag) { boolean shouldRemove; - if 
(stateTag.requestPosition == null) { + if (stateTag.getRequestPosition() == null) { // This is the response for the first page. // Leave the future in the cache so subsequent requests for the first page // can return immediately. @@ -590,16 +689,18 @@ class WindmillStateReader { // continuation positions. shouldRemove = true; } - CoderAndFuture<T, ValuesAndContPosition<T>> coderAndFuture = getWaiting(stateTag, shouldRemove); - SettableFuture<ValuesAndContPosition<T>> future = coderAndFuture.getNonDoneFuture(stateTag); + CoderAndFuture<T, ValuesAndContPosition<T, Long>> coderAndFuture = + getWaiting(stateTag, shouldRemove); + SettableFuture<ValuesAndContPosition<T, Long>> future = + coderAndFuture.getNonDoneFuture(stateTag); Coder<T> coder = coderAndFuture.getAndClearCoder(); - List<T> values = this.<T>bagPageValues(bag, coder); + List<T> values = this.bagPageValues(bag, coder); future.set( - new ValuesAndContPosition<T>( + new ValuesAndContPosition<>( values, bag.hasContinuationPosition() ? 
bag.getContinuationPosition() : null)); } - private void consumeWatermark(Windmill.WatermarkHold watermarkHold, StateTag stateTag) { + private void consumeWatermark(Windmill.WatermarkHold watermarkHold, StateTag<Long> stateTag) { CoderAndFuture<Void, Instant> coderAndFuture = getWaiting(stateTag, false); SettableFuture<Instant> future = coderAndFuture.getNonDoneFuture(stateTag); // No coders for watermarks @@ -619,7 +720,7 @@ class WindmillStateReader { future.set(hold); } - private <T> void consumeTagValue(TagValue tagValue, StateTag stateTag) { + private <T> void consumeTagValue(TagValue tagValue, StateTag<Long> stateTag) { CoderAndFuture<T, T> coderAndFuture = getWaiting(stateTag, false); SettableFuture<T> future = coderAndFuture.getNonDoneFuture(stateTag); Coder<T> coder = coderAndFuture.getAndClearCoder(); @@ -639,6 +740,35 @@ class WindmillStateReader { } } + private <T> void consumeSortedList( + Windmill.TagSortedListFetchResponse sortedListFetchResponse, StateTag<ByteString> stateTag) { + boolean shouldRemove; + if (stateTag.getRequestPosition() == null) { + // This is the response for the first page.// Leave the future in the cache so subsequent + // requests for the first page + // can return immediately. + shouldRemove = false; + } else { + // This is a response for a subsequent page. + // Don't cache the future since we may need to make multiple requests with different + // continuation positions. + shouldRemove = true; + } + + CoderAndFuture<T, ValuesAndContPosition<TimestampedValue<T>, ByteString>> coderAndFuture = + getWaiting(stateTag, shouldRemove); + SettableFuture<ValuesAndContPosition<TimestampedValue<T>, ByteString>> future = + coderAndFuture.getNonDoneFuture(stateTag); + Coder<T> coder = coderAndFuture.getAndClearCoder(); + List<TimestampedValue<T>> values = this.sortedListPageValues(sortedListFetchResponse, coder); + future.set( + new ValuesAndContPosition<>( + values, + sortedListFetchResponse.hasContinuationPosition() + ? 
sortedListFetchResponse.getContinuationPosition() + : null)); + } + /** * An iterable over elements backed by paginated GetData requests to Windmill. The iterable may be * iterated over an arbitrary number of times and multiple iterators may be active simultaneously. @@ -655,7 +785,7 @@ class WindmillStateReader { * call to iterator. * </ol> */ - private static class BagPagingIterable<T> implements Iterable<T> { + private static class PagingIterable<ElemT, ContinuationT, ResultT> implements Iterable<ResultT> { /** * The reader we will use for scheduling continuation pages. * @@ -664,16 +794,19 @@ class WindmillStateReader { private final WindmillStateReader reader; /** Initial values returned for the first page. Never reclaimed. */ - private final List<T> firstPage; + private final List<ResultT> firstPage; /** State tag with continuation position set for second page. */ - private final StateTag secondPagePos; + private final StateTag<ContinuationT> secondPagePos; /** Coder for elements. 
*/ - private final Coder<T> elemCoder; + private final Coder<ElemT> elemCoder; - private BagPagingIterable( - WindmillStateReader reader, List<T> firstPage, StateTag secondPagePos, Coder<T> elemCoder) { + private PagingIterable( + WindmillStateReader reader, + List<ResultT> firstPage, + StateTag<ContinuationT> secondPagePos, + Coder<ElemT> elemCoder) { this.reader = reader; this.firstPage = firstPage; this.secondPagePos = secondPagePos; @@ -681,16 +814,16 @@ class WindmillStateReader { } @Override - public Iterator<T> iterator() { - return new AbstractIterator<T>() { - private Iterator<T> currentPage = firstPage.iterator(); - private StateTag nextPagePos = secondPagePos; - private Future<ValuesAndContPosition<T>> pendingNextPage = + public Iterator<ResultT> iterator() { + return new AbstractIterator<ResultT>() { + private Iterator<ResultT> currentPage = firstPage.iterator(); + private StateTag<ContinuationT> nextPagePos = secondPagePos; + private Future<ValuesAndContPosition<ResultT, ContinuationT>> pendingNextPage = // NOTE: The results of continuation page reads are never cached. 
- reader.continuationBagFuture(nextPagePos, elemCoder); + reader.continuationFuture(nextPagePos, elemCoder); @Override - protected T computeNext() { + protected ResultT computeNext() { while (true) { if (currentPage.hasNext()) { return currentPage.next(); @@ -699,7 +832,7 @@ class WindmillStateReader { return endOfData(); } - ValuesAndContPosition<T> valuesAndContPosition; + ValuesAndContPosition<ResultT, ContinuationT> valuesAndContPosition; try { valuesAndContPosition = pendingNextPage.get(); } catch (InterruptedException | ExecutionException e) { @@ -710,14 +843,14 @@ class WindmillStateReader { } currentPage = valuesAndContPosition.values.iterator(); nextPagePos = - new StateTag( - nextPagePos.kind, - nextPagePos.tag, - nextPagePos.stateFamily, + StateTag.of( + nextPagePos.getKind(), + nextPagePos.getTag(), + nextPagePos.getStateFamily(), valuesAndContPosition.continuationPosition); pendingNextPage = // NOTE: The results of continuation page reads are never cached. - reader.continuationBagFuture(nextPagePos, elemCoder); + reader.continuationFuture(nextPagePos, elemCoder); } } }; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternalsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternalsTest.java index 52bfcd4..367e1a8 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternalsTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternalsTest.java @@ -19,6 +19,7 @@ package org.apache.beam.runners.dataflow.worker; import static org.apache.beam.runners.dataflow.worker.DataflowMatchers.ByteStringMatcher.byteStringEq; import static org.apache.beam.sdk.testing.SystemNanoTimeSleeper.sleepMillis; +import static org.junit.Assert.assertArrayEquals; import static 
org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -27,17 +28,23 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.when; +import com.google.common.collect.Iterables; import java.io.Closeable; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.Map; import java.util.concurrent.TimeUnit; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateNamespaceForTest; import org.apache.beam.runners.core.StateTag; import org.apache.beam.runners.core.StateTags; import org.apache.beam.runners.dataflow.options.DataflowWorkerHarnessOptions; +import org.apache.beam.runners.dataflow.worker.WindmillStateInternals.IdTracker; +import org.apache.beam.runners.dataflow.worker.WindmillStateInternals.WindmillOrderedList; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagBag; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagSortedListUpdateRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagValue; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -46,15 +53,19 @@ import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.CombiningState; import org.apache.beam.sdk.state.GroupingState; +import org.apache.beam.sdk.state.OrderedListState; import org.apache.beam.sdk.state.ReadableState; import org.apache.beam.sdk.state.ValueState; import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.Sum; import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.util.CoderUtils; +import org.apache.beam.sdk.values.TimestampedValue; import 
org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Supplier; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Range; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.RangeSet; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Futures; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.SettableFuture; import org.hamcrest.Matchers; @@ -170,6 +181,341 @@ public class WindmillStateInternalsTest { return result; } + public static final Range<Long> FULL_ORDERED_LIST_RANGE = + Range.closedOpen(WindmillOrderedList.MIN_TS_MICROS, WindmillOrderedList.MAX_TS_MICROS); + + @Test + public void testOrderedListAddBeforeRead() throws Exception { + StateTag<OrderedListState<String>> addr = + StateTags.orderedList("orderedList", StringUtf8Coder.of()); + OrderedListState<String> orderedList = underTest.state(NAMESPACE, addr); + + SettableFuture<Iterable<TimestampedValue<String>>> future = SettableFuture.create(); + when(mockReader.orderedListFuture( + FULL_ORDERED_LIST_RANGE, + key(NAMESPACE, "orderedList"), + STATE_FAMILY, + StringUtf8Coder.of())) + .thenReturn(future); + + orderedList.readLater(); + + final TimestampedValue<String> helloValue = + TimestampedValue.of("hello", Instant.ofEpochMilli(100)); + final TimestampedValue<String> worldValue = + TimestampedValue.of("world", Instant.ofEpochMilli(75)); + final TimestampedValue<String> goodbyeValue = + TimestampedValue.of("goodbye", Instant.ofEpochMilli(50)); + + orderedList.add(helloValue); + waitAndSet(future, Arrays.asList(worldValue), 200); + assertThat(orderedList.read(), Matchers.contains(worldValue, helloValue)); + + orderedList.add(goodbyeValue); + assertThat(orderedList.read(), Matchers.contains(goodbyeValue, worldValue, helloValue)); + } + + @Test 
+ public void testOrderedListClearBeforeRead() throws Exception { + StateTag<OrderedListState<String>> addr = + StateTags.orderedList("orderedList", StringUtf8Coder.of()); + OrderedListState<String> orderedListState = underTest.state(NAMESPACE, addr); + + final TimestampedValue<String> helloElement = TimestampedValue.of("hello", Instant.EPOCH); + orderedListState.clear(); + orderedListState.add(helloElement); + assertThat(orderedListState.read(), Matchers.containsInAnyOrder(helloElement)); + + // Shouldn't need to read from windmill for this. + Mockito.verifyZeroInteractions(mockReader); + } + + @Test + public void testOrderedListIsEmptyFalse() throws Exception { + StateTag<OrderedListState<String>> addr = + StateTags.orderedList("orderedList", StringUtf8Coder.of()); + OrderedListState<String> orderedList = underTest.state(NAMESPACE, addr); + + SettableFuture<Iterable<TimestampedValue<String>>> future = SettableFuture.create(); + when(mockReader.orderedListFuture( + FULL_ORDERED_LIST_RANGE, + key(NAMESPACE, "orderedList"), + STATE_FAMILY, + StringUtf8Coder.of())) + .thenReturn(future); + ReadableState<Boolean> result = orderedList.isEmpty().readLater(); + Mockito.verify(mockReader) + .orderedListFuture( + FULL_ORDERED_LIST_RANGE, + key(NAMESPACE, "orderedList"), + STATE_FAMILY, + StringUtf8Coder.of()); + + waitAndSet(future, Arrays.asList(TimestampedValue.of("world", Instant.EPOCH)), 200); + assertThat(result.read(), Matchers.is(false)); + } + + @Test + public void testOrderedListIsEmptyTrue() throws Exception { + StateTag<OrderedListState<String>> addr = + StateTags.orderedList("orderedList", StringUtf8Coder.of()); + OrderedListState<String> orderedList = underTest.state(NAMESPACE, addr); + + SettableFuture<Iterable<TimestampedValue<String>>> future = SettableFuture.create(); + when(mockReader.orderedListFuture( + FULL_ORDERED_LIST_RANGE, + key(NAMESPACE, "orderedList"), + STATE_FAMILY, + StringUtf8Coder.of())) + .thenReturn(future); + ReadableState<Boolean> 
result = orderedList.isEmpty().readLater(); + Mockito.verify(mockReader) + .orderedListFuture( + FULL_ORDERED_LIST_RANGE, + key(NAMESPACE, "orderedList"), + STATE_FAMILY, + StringUtf8Coder.of()); + + waitAndSet(future, Collections.emptyList(), 200); + assertThat(result.read(), Matchers.is(true)); + } + + @Test + public void testOrderedListIsEmptyAfterClear() throws Exception { + StateTag<OrderedListState<String>> addr = + StateTags.orderedList("orderedList", StringUtf8Coder.of()); + OrderedListState<String> orderedList = underTest.state(NAMESPACE, addr); + + orderedList.clear(); + ReadableState<Boolean> result = orderedList.isEmpty(); + Mockito.verify(mockReader, never()) + .orderedListFuture( + FULL_ORDERED_LIST_RANGE, + key(NAMESPACE, "orderedList"), + STATE_FAMILY, + StringUtf8Coder.of()); + assertThat(result.read(), Matchers.is(true)); + + orderedList.add(TimestampedValue.of("hello", Instant.EPOCH)); + assertThat(result.read(), Matchers.is(false)); + } + + @Test + public void testOrderedListAddPersist() throws Exception { + StateTag<OrderedListState<String>> addr = + StateTags.orderedList("orderedList", StringUtf8Coder.of()); + OrderedListState<String> orderedList = underTest.state(NAMESPACE, addr); + + SettableFuture<Map<Range<Instant>, RangeSet<Long>>> orderedListFuture = SettableFuture.create(); + orderedListFuture.set(null); + SettableFuture<Map<Range<Instant>, RangeSet<Instant>>> deletionsFuture = + SettableFuture.create(); + deletionsFuture.set(null); + when(mockReader.valueFuture( + systemKey(NAMESPACE, "orderedList" + IdTracker.IDS_AVAILABLE_STR), + STATE_FAMILY, + IdTracker.IDS_AVAILABLE_CODER)) + .thenReturn(orderedListFuture); + when(mockReader.valueFuture( + systemKey(NAMESPACE, "orderedList" + IdTracker.DELETIONS_STR), + STATE_FAMILY, + IdTracker.SUBRANGE_DELETIONS_CODER)) + .thenReturn(deletionsFuture); + + orderedList.add(TimestampedValue.of("hello", Instant.ofEpochMilli(1))); + + Windmill.WorkItemCommitRequest.Builder commitBuilder = + 
Windmill.WorkItemCommitRequest.newBuilder(); + underTest.persist(commitBuilder); + + assertEquals(1, commitBuilder.getSortedListUpdatesCount()); + TagSortedListUpdateRequest updates = commitBuilder.getSortedListUpdates(0); + assertEquals(key(NAMESPACE, "orderedList"), updates.getTag()); + assertEquals(1, updates.getInsertsCount()); + assertEquals(1, updates.getInserts(0).getEntriesCount()); + + assertEquals("hello", updates.getInserts(0).getEntries(0).getValue().toStringUtf8()); + assertEquals(1000, updates.getInserts(0).getEntries(0).getSortKey()); + assertEquals(IdTracker.MIN_ID, updates.getInserts(0).getEntries(0).getId()); + } + + @Test + public void testOrderedListClearPersist() throws Exception { + StateTag<OrderedListState<String>> addr = + StateTags.orderedList("orderedList", StringUtf8Coder.of()); + OrderedListState<String> orderedListState = underTest.state(NAMESPACE, addr); + + orderedListState.add(TimestampedValue.of("hello", Instant.ofEpochMilli(1))); + orderedListState.clear(); + orderedListState.add(TimestampedValue.of("world", Instant.ofEpochMilli(2))); + orderedListState.add(TimestampedValue.of("world", Instant.ofEpochMilli(2))); + + Windmill.WorkItemCommitRequest.Builder commitBuilder = + Windmill.WorkItemCommitRequest.newBuilder(); + underTest.persist(commitBuilder); + + assertEquals(1, commitBuilder.getSortedListUpdatesCount()); + TagSortedListUpdateRequest updates = commitBuilder.getSortedListUpdates(0); + assertEquals(STATE_FAMILY, updates.getStateFamily()); + assertEquals(key(NAMESPACE, "orderedList"), updates.getTag()); + assertEquals(1, updates.getInsertsCount()); + assertEquals(2, updates.getInserts(0).getEntriesCount()); + + assertEquals("world", updates.getInserts(0).getEntries(0).getValue().toStringUtf8()); + assertEquals("world", updates.getInserts(0).getEntries(1).getValue().toStringUtf8()); + assertEquals(2000, updates.getInserts(0).getEntries(0).getSortKey()); + assertEquals(2000, updates.getInserts(0).getEntries(1).getSortKey()); + 
assertEquals(IdTracker.MIN_ID, updates.getInserts(0).getEntries(0).getId()); + assertEquals(IdTracker.MIN_ID + 1, updates.getInserts(0).getEntries(1).getId()); + Mockito.verifyNoMoreInteractions(mockReader); + } + + @Test + public void testOrderedListDeleteRangePersist() { + SettableFuture<Map<Range<Instant>, RangeSet<Long>>> orderedListFuture = SettableFuture.create(); + orderedListFuture.set(null); + SettableFuture<Map<Range<Instant>, RangeSet<Instant>>> deletionsFuture = + SettableFuture.create(); + deletionsFuture.set(null); + when(mockReader.valueFuture( + systemKey(NAMESPACE, "orderedList" + IdTracker.IDS_AVAILABLE_STR), + STATE_FAMILY, + IdTracker.IDS_AVAILABLE_CODER)) + .thenReturn(orderedListFuture); + when(mockReader.valueFuture( + systemKey(NAMESPACE, "orderedList" + IdTracker.DELETIONS_STR), + STATE_FAMILY, + IdTracker.SUBRANGE_DELETIONS_CODER)) + .thenReturn(deletionsFuture); + + StateTag<OrderedListState<String>> addr = + StateTags.orderedList("orderedList", StringUtf8Coder.of()); + OrderedListState<String> orderedListState = underTest.state(NAMESPACE, addr); + + orderedListState.add(TimestampedValue.of("hello", Instant.ofEpochMilli(1))); + orderedListState.add(TimestampedValue.of("hello", Instant.ofEpochMilli(2))); + orderedListState.add(TimestampedValue.of("hello", Instant.ofEpochMilli(2))); + orderedListState.add(TimestampedValue.of("world", Instant.ofEpochMilli(3))); + orderedListState.add(TimestampedValue.of("world", Instant.ofEpochMilli(4))); + orderedListState.clearRange(Instant.ofEpochMilli(2), Instant.ofEpochMilli(4)); + Windmill.WorkItemCommitRequest.Builder commitBuilder = + Windmill.WorkItemCommitRequest.newBuilder(); + underTest.persist(commitBuilder); + + assertEquals(1, commitBuilder.getSortedListUpdatesCount()); + TagSortedListUpdateRequest updates = commitBuilder.getSortedListUpdates(0); + assertEquals(STATE_FAMILY, updates.getStateFamily()); + assertEquals(key(NAMESPACE, "orderedList"), updates.getTag()); + assertEquals(1, 
updates.getInsertsCount()); + assertEquals(2, updates.getInserts(0).getEntriesCount()); + + assertEquals("hello", updates.getInserts(0).getEntries(0).getValue().toStringUtf8()); + assertEquals("world", updates.getInserts(0).getEntries(1).getValue().toStringUtf8()); + assertEquals(1000, updates.getInserts(0).getEntries(0).getSortKey()); + assertEquals(4000, updates.getInserts(0).getEntries(1).getSortKey()); + assertEquals(IdTracker.MIN_ID, updates.getInserts(0).getEntries(0).getId()); + assertEquals(IdTracker.MIN_ID + 1, updates.getInserts(0).getEntries(1).getId()); + } + + @Test + public void testOrderedListMergePendingAdds() { + SettableFuture<Map<Range<Instant>, RangeSet<Long>>> orderedListFuture = SettableFuture.create(); + orderedListFuture.set(null); + SettableFuture<Map<Range<Instant>, RangeSet<Instant>>> deletionsFuture = + SettableFuture.create(); + deletionsFuture.set(null); + when(mockReader.valueFuture( + systemKey(NAMESPACE, "orderedList" + IdTracker.IDS_AVAILABLE_STR), + STATE_FAMILY, + IdTracker.IDS_AVAILABLE_CODER)) + .thenReturn(orderedListFuture); + when(mockReader.valueFuture( + systemKey(NAMESPACE, "orderedList" + IdTracker.DELETIONS_STR), + STATE_FAMILY, + IdTracker.SUBRANGE_DELETIONS_CODER)) + .thenReturn(deletionsFuture); + + SettableFuture<Iterable<TimestampedValue<String>>> fromStorage = SettableFuture.create(); + when(mockReader.orderedListFuture( + FULL_ORDERED_LIST_RANGE, + key(NAMESPACE, "orderedList"), + STATE_FAMILY, + StringUtf8Coder.of())) + .thenReturn(fromStorage); + + StateTag<OrderedListState<String>> addr = + StateTags.orderedList("orderedList", StringUtf8Coder.of()); + OrderedListState<String> orderedListState = underTest.state(NAMESPACE, addr); + + orderedListState.add(TimestampedValue.of("second", Instant.ofEpochMilli(1))); + orderedListState.add(TimestampedValue.of("third", Instant.ofEpochMilli(2))); + orderedListState.add(TimestampedValue.of("fourth", Instant.ofEpochMilli(2))); + 
orderedListState.add(TimestampedValue.of("eighth", Instant.ofEpochMilli(10))); + orderedListState.add(TimestampedValue.of("ninth", Instant.ofEpochMilli(15))); + + fromStorage.set( + ImmutableList.of( + TimestampedValue.of("first", Instant.ofEpochMilli(-1)), + TimestampedValue.of("fifth", Instant.ofEpochMilli(5)), + TimestampedValue.of("sixth", Instant.ofEpochMilli(5)), + TimestampedValue.of("seventh", Instant.ofEpochMilli(5)), + TimestampedValue.of("tenth", Instant.ofEpochMilli(20)))); + + TimestampedValue[] expected = + Iterables.toArray( + ImmutableList.of( + TimestampedValue.of("first", Instant.ofEpochMilli(-1)), + TimestampedValue.of("second", Instant.ofEpochMilli(1)), + TimestampedValue.of("third", Instant.ofEpochMilli(2)), + TimestampedValue.of("fourth", Instant.ofEpochMilli(2)), + TimestampedValue.of("fifth", Instant.ofEpochMilli(5)), + TimestampedValue.of("sixth", Instant.ofEpochMilli(5)), + TimestampedValue.of("seventh", Instant.ofEpochMilli(5)), + TimestampedValue.of("eighth", Instant.ofEpochMilli(10)), + TimestampedValue.of("ninth", Instant.ofEpochMilli(15)), + TimestampedValue.of("tenth", Instant.ofEpochMilli(20))), + TimestampedValue.class); + + TimestampedValue[] read = Iterables.toArray(orderedListState.read(), TimestampedValue.class); + assertArrayEquals(expected, read); + } + + @Test + public void testOrderedListPersistEmpty() throws Exception { + StateTag<OrderedListState<String>> addr = + StateTags.orderedList("orderedList", StringUtf8Coder.of()); + OrderedListState<String> orderedListState = underTest.state(NAMESPACE, addr); + + orderedListState.clear(); + + Windmill.WorkItemCommitRequest.Builder commitBuilder = + Windmill.WorkItemCommitRequest.newBuilder(); + underTest.persist(commitBuilder); + + // 1 sorted list update = the clear + assertEquals(1, commitBuilder.getSortedListUpdatesCount()); + TagSortedListUpdateRequest updates = commitBuilder.getSortedListUpdates(0); + assertEquals(1, updates.getDeletesCount()); + 
assertEquals(WindmillOrderedList.MIN_TS_MICROS, updates.getDeletes(0).getRange().getStart()); + assertEquals(WindmillOrderedList.MAX_TS_MICROS, updates.getDeletes(0).getRange().getLimit()); + } + + @Test + public void testNewOrderedListNoFetch() throws Exception { + StateTag<OrderedListState<String>> addr = + StateTags.orderedList("orderedList", StringUtf8Coder.of()); + OrderedListState<String> orderedList = underTestNewKey.state(NAMESPACE, addr); + + assertThat(orderedList.read(), Matchers.emptyIterable()); + + // Shouldn't need to read from windmill for this. + Mockito.verifyZeroInteractions(mockReader); + } + + // test ordered list cleared before read + // test fetch + add + read + // test ids + @Test public void testBagAddBeforeRead() throws Exception { StateTag<BagState<String>> addr = StateTags.bag("bag", StringUtf8Coder.of()); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java index f2628ff..fff4bc3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java @@ -26,11 +26,16 @@ import java.io.IOException; import java.util.concurrent.Future; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.KeyedGetDataRequest; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.SortedListEntry; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill.SortedListRange; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.VarIntCoder; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import 
org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p26p0.com.google.protobuf.ByteString.Output; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Range; import org.hamcrest.Matchers; import org.joda.time.Instant; import org.junit.Before; @@ -199,6 +204,263 @@ public class WindmillStateReaderTest { } @Test + public void testReadSortedList() throws Exception { + long beginning = SortedListRange.getDefaultInstance().getStart(); + long end = SortedListRange.getDefaultInstance().getLimit(); + Future<Iterable<TimestampedValue<Integer>>> future = + underTest.orderedListFuture( + Range.closedOpen(beginning, end), STATE_KEY_1, STATE_FAMILY, INT_CODER); + Mockito.verifyNoMoreInteractions(mockWindmill); + + // Fetch the entire list. + Windmill.KeyedGetDataRequest.Builder expectedRequest = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) + .addSortedListsToFetch( + Windmill.TagSortedListFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addFetchRanges(SortedListRange.newBuilder().setStart(beginning).setLimit(end)) + .setFetchMaxBytes(WindmillStateReader.MAX_BAG_BYTES)); + + Windmill.KeyedGetDataResponse.Builder response = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagSortedLists( + Windmill.TagSortedListFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + SortedListEntry.newBuilder().setValue(intData(5)).setSortKey(5000).setId(5)) + .addEntries( + SortedListEntry.newBuilder().setValue(intData(6)).setSortKey(6000).setId(5)) + .addEntries( + SortedListEntry.newBuilder().setValue(intData(7)).setSortKey(7000).setId(7)) + .addEntries( + 
SortedListEntry.newBuilder().setValue(intData(8)).setSortKey(8000).setId(8)) + .addFetchRanges( + SortedListRange.newBuilder().setStart(beginning).setLimit(end))); + + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build())) + .thenReturn(response.build()); + + Iterable<TimestampedValue<Integer>> results = future.get(); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build()); + for (TimestampedValue<Integer> unused : results) { + // Iterate over the results to force loading all the pages. + } + Mockito.verifyNoMoreInteractions(mockWindmill); + + assertThat( + results, + Matchers.contains( + TimestampedValue.of(5, Instant.ofEpochMilli(5)), + TimestampedValue.of(6, Instant.ofEpochMilli(6)), + TimestampedValue.of(7, Instant.ofEpochMilli(7)), + TimestampedValue.of(8, Instant.ofEpochMilli(8)))); + assertNoReader(future); + } + + @Test + public void testReadSortedListRanges() throws Exception { + Future<Iterable<TimestampedValue<Integer>>> future1 = + underTest.orderedListFuture(Range.closedOpen(0L, 5L), STATE_KEY_1, STATE_FAMILY, INT_CODER); + Future<Iterable<TimestampedValue<Integer>>> future2 = + underTest.orderedListFuture(Range.closedOpen(5L, 6L), STATE_KEY_1, STATE_FAMILY, INT_CODER); + Future<Iterable<TimestampedValue<Integer>>> future3 = + underTest.orderedListFuture( + Range.closedOpen(6L, 10L), STATE_KEY_1, STATE_FAMILY, INT_CODER); + Mockito.verifyNoMoreInteractions(mockWindmill); + + // Fetch all three ranges in a single batched request. 
+ Windmill.KeyedGetDataRequest.Builder expectedRequest = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) + .addSortedListsToFetch( + Windmill.TagSortedListFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addFetchRanges(SortedListRange.newBuilder().setStart(0).setLimit(5)) + .setFetchMaxBytes(WindmillStateReader.MAX_BAG_BYTES)) + .addSortedListsToFetch( + Windmill.TagSortedListFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addFetchRanges(SortedListRange.newBuilder().setStart(5).setLimit(6)) + .setFetchMaxBytes(WindmillStateReader.MAX_BAG_BYTES)) + .addSortedListsToFetch( + Windmill.TagSortedListFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addFetchRanges(SortedListRange.newBuilder().setStart(6).setLimit(10)) + .setFetchMaxBytes(WindmillStateReader.MAX_BAG_BYTES)); + + Windmill.KeyedGetDataResponse.Builder response = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagSortedLists( + Windmill.TagSortedListFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + SortedListEntry.newBuilder().setValue(intData(5)).setSortKey(5000).setId(5)) + .addFetchRanges(SortedListRange.newBuilder().setStart(0).setLimit(5))) + .addTagSortedLists( + Windmill.TagSortedListFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + SortedListEntry.newBuilder().setValue(intData(6)).setSortKey(6000).setId(5)) + .addEntries( + SortedListEntry.newBuilder().setValue(intData(7)).setSortKey(7000).setId(7)) + .addFetchRanges(SortedListRange.newBuilder().setStart(5).setLimit(6))) + .addTagSortedLists( + Windmill.TagSortedListFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + 
SortedListEntry.newBuilder().setValue(intData(8)).setSortKey(8000).setId(8)) + .addFetchRanges(SortedListRange.newBuilder().setStart(6).setLimit(10))); + + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build())) + .thenReturn(response.build()); + + { + Iterable<TimestampedValue<Integer>> results = future1.get(); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build()); + for (TimestampedValue<Integer> unused : results) { + // Iterate over the results to force loading all the pages. + } + Mockito.verifyNoMoreInteractions(mockWindmill); + assertThat(results, Matchers.contains(TimestampedValue.of(5, Instant.ofEpochMilli(5)))); + assertNoReader(future1); + } + + { + Iterable<TimestampedValue<Integer>> results = future2.get(); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build()); + for (TimestampedValue<Integer> unused : results) { + // Iterate over the results to force loading all the pages. + } + Mockito.verifyNoMoreInteractions(mockWindmill); + assertThat( + results, + Matchers.contains( + TimestampedValue.of(6, Instant.ofEpochMilli(6)), + TimestampedValue.of(7, Instant.ofEpochMilli(7)))); + assertNoReader(future2); + } + + { + Iterable<TimestampedValue<Integer>> results = future3.get(); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build()); + for (TimestampedValue<Integer> unused : results) { + // Iterate over the results to force loading all the pages. 
+ } + Mockito.verifyNoMoreInteractions(mockWindmill); + assertThat(results, Matchers.contains(TimestampedValue.of(8, Instant.ofEpochMilli(8)))); + assertNoReader(future3); + } + } + + @Test + public void testReadSortedListWithContinuations() throws Exception { + long beginning = SortedListRange.getDefaultInstance().getStart(); + long end = SortedListRange.getDefaultInstance().getLimit(); + + Future<Iterable<TimestampedValue<Integer>>> future = + underTest.orderedListFuture( + Range.closedOpen(beginning, end), STATE_KEY_1, STATE_FAMILY, INT_CODER); + + Mockito.verifyNoMoreInteractions(mockWindmill); + + Windmill.KeyedGetDataRequest.Builder expectedRequest1 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) + .addSortedListsToFetch( + Windmill.TagSortedListFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addFetchRanges(SortedListRange.newBuilder().setStart(beginning).setLimit(end)) + .setFetchMaxBytes(WindmillStateReader.MAX_BAG_BYTES)); + + final ByteString CONT = ByteString.copyFrom("CONTINUATION", Charsets.UTF_8); + Windmill.KeyedGetDataResponse.Builder response1 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagSortedLists( + Windmill.TagSortedListFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + SortedListEntry.newBuilder().setValue(intData(5)).setSortKey(5000).setId(5)) + .setContinuationPosition(CONT) + .addFetchRanges( + SortedListRange.newBuilder().setStart(beginning).setLimit(end))); + + Windmill.KeyedGetDataRequest.Builder expectedRequest2 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) + .addSortedListsToFetch( + Windmill.TagSortedListFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + 
.setStateFamily(STATE_FAMILY) + .addFetchRanges(SortedListRange.newBuilder().setStart(beginning).setLimit(end)) + .setRequestPosition(CONT) + .setFetchMaxBytes(WindmillStateReader.MAX_BAG_BYTES)); + + Windmill.KeyedGetDataResponse.Builder response2 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagSortedLists( + Windmill.TagSortedListFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + SortedListEntry.newBuilder().setValue(intData(6)).setSortKey(6000).setId(5)) + .addEntries( + SortedListEntry.newBuilder().setValue(intData(7)).setSortKey(7000).setId(7)) + .addEntries( + SortedListEntry.newBuilder().setValue(intData(8)).setSortKey(8000).setId(8)) + .addFetchRanges(SortedListRange.newBuilder().setStart(beginning).setLimit(end)) + .setRequestPosition(CONT)); + + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest1.build())) + .thenReturn(response1.build()); + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest2.build())) + .thenReturn(response2.build()); + + Iterable<TimestampedValue<Integer>> results = future.get(); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest1.build()); + for (TimestampedValue<Integer> unused : results) { + // Iterate over the results to force loading all the pages. + } + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest2.build()); + Mockito.verifyNoMoreInteractions(mockWindmill); + + assertThat( + results, + Matchers.contains( + TimestampedValue.of(5, Instant.ofEpochMilli(5)), + TimestampedValue.of(6, Instant.ofEpochMilli(6)), + TimestampedValue.of(7, Instant.ofEpochMilli(7)), + TimestampedValue.of(8, Instant.ofEpochMilli(8)))); + // NOTE: The future will still contain a reference to the underlying reader. 
+ } + + @Test public void testReadValue() throws Exception { Future<Integer> future = underTest.valueFuture(STATE_KEY_1, STATE_FAMILY, INT_CODER); Mockito.verifyNoMoreInteractions(mockWindmill); diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto index b90eaa7..b0e8bda 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto @@ -131,7 +131,7 @@ message SortedListRange { message TagSortedListFetchRequest { optional bytes tag = 1; optional string state_family = 2; - optional SortedListRange fetch_range = 3; + repeated SortedListRange fetch_ranges = 3; // Sets a limit on the maximum response value bytes optional int64 fetch_max_bytes = 5 [default = 0x7fffffffffffffff]; @@ -146,7 +146,11 @@ message TagSortedListFetchResponse { optional string state_family = 2; repeated SortedListEntry entries = 3; optional bytes continuation_position = 4; -} + // Fetch ranges copied from request. + repeated SortedListRange fetch_ranges = 5; + // Request position copied from request. + optional bytes request_position = 6; + } message TagSortedListUpdateRequest { optional bytes tag = 1; @@ -253,6 +257,7 @@ message KeyedGetDataRequest { optional fixed64 sharding_key = 6; repeated TagValue values_to_fetch = 3; repeated TagBag bags_to_fetch = 8; + // Must be at most one sorted_lists_to_fetch entry for a given state family and tag. repeated TagSortedListFetchRequest sorted_lists_to_fetch = 9; repeated WatermarkHold watermark_holds_to_fetch = 5; @@ -282,6 +287,7 @@ message KeyedGetDataResponse { optional bool failed = 2; repeated TagValue values = 3; repeated TagBag bags = 6; + // There is one TagSortedListFetchResponse per state-family, tag pair. 
repeated TagSortedListFetchResponse tag_sorted_lists = 8; repeated WatermarkHold watermark_holds = 5; diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 8d19535..41d679b 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -2374,7 +2374,17 @@ public class ParDoTest implements Serializable { @Test @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesOrderedListState.class}) - public void testOrderedListState() { + public void testOrderedListStateBounded() { + testOrderedListStateImpl(false); + } + + @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesOrderedListState.class}) + public void testOrderedListStateUnbounded() { + testOrderedListStateImpl(true); + } + + void testOrderedListStateImpl(boolean unbounded) { final String stateId = "foo"; DoFn<KV<String, TimestampedValue<String>>, Iterable<TimestampedValue<String>>> fn = @@ -2408,6 +2418,7 @@ public class ParDoTest implements Serializable { KV.of("hello", TimestampedValue.of("b", Instant.ofEpochMilli(42))), KV.of("hello", TimestampedValue.of("b", Instant.ofEpochMilli(52))), KV.of("hello", TimestampedValue.of("c", Instant.ofEpochMilli(12))))) + .setIsBoundedInternal(unbounded ? 
IsBounded.UNBOUNDED : IsBounded.BOUNDED) .apply(ParDo.of(fn)); List<TimestampedValue<String>> expected = @@ -2423,7 +2434,17 @@ public class ParDoTest implements Serializable { @Test @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesOrderedListState.class}) - public void testOrderedListStateRangeFetch() { + public void testOrderedListStateRangeFetchBounded() { + testOrderedListStateRangeFetchImpl(false); + } + + @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesOrderedListState.class}) + public void testOrderedListStateRangeFetchUnbounded() { + testOrderedListStateRangeFetchImpl(true); + } + + void testOrderedListStateRangeFetchImpl(boolean unbounded) { final String stateId = "foo"; DoFn<KV<String, TimestampedValue<String>>, Iterable<TimestampedValue<String>>> fn = @@ -2459,6 +2480,7 @@ public class ParDoTest implements Serializable { KV.of("hello", TimestampedValue.of("b", Instant.ofEpochMilli(42))), KV.of("hello", TimestampedValue.of("b", Instant.ofEpochMilli(52))), KV.of("hello", TimestampedValue.of("c", Instant.ofEpochMilli(12))))) + .setIsBoundedInternal(unbounded ? 
IsBounded.UNBOUNDED : IsBounded.BOUNDED) .apply(ParDo.of(fn)); List<TimestampedValue<String>> expected1 = Lists.newArrayList(); @@ -2482,7 +2504,17 @@ public class ParDoTest implements Serializable { @Test @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesOrderedListState.class}) - public void testOrderedListStateRangeDelete() { + public void testOrderedListStateRangeDeleteBounded() { + testOrderedListStateRangeDeleteImpl(false); + } + + @Test + @Category({ValidatesRunner.class, UsesStatefulParDo.class, UsesOrderedListState.class}) + public void testOrderedListStateRangeDeleteUnbounded() { + testOrderedListStateRangeDeleteImpl(true); + } + + void testOrderedListStateRangeDeleteImpl(boolean unbounded) { final String stateId = "foo"; DoFn<KV<String, TimestampedValue<String>>, Iterable<TimestampedValue<String>>> fn = new DoFn<KV<String, TimestampedValue<String>>, Iterable<TimestampedValue<String>>>() { @@ -2525,6 +2557,7 @@ public class ParDoTest implements Serializable { KV.of("hello", TimestampedValue.of("b", Instant.ofEpochMilli(42))), KV.of("hello", TimestampedValue.of("b", Instant.ofEpochMilli(52))), KV.of("hello", TimestampedValue.of("c", Instant.ofEpochMilli(12))))) + .setIsBoundedInternal(unbounded ? IsBounded.UNBOUNDED : IsBounded.BOUNDED) .apply(ParDo.of(fn)); List<TimestampedValue<String>> expected =