This is an automated email from the ASF dual-hosted git repository. dcapwell pushed a commit to branch trunk in repository https://gitbox.apache.org/repos/asf/cassandra-accord.git
The following commit(s) were added to refs/heads/trunk by this push: new 08aaab6e CEP-15 (Accord) Original and recover coordinators may hit a race condition with PreApply where reads and writes are interleaved, causing one of the coordinators to see the writes from the other 08aaab6e is described below commit 08aaab6e33d43406e0649146144e4df67648602a Author: David Capwell <dcapw...@gmail.com> AuthorDate: Fri Apr 7 15:33:46 2023 -0700 CEP-15 (Accord) Original and recover coordinators may hit a race condition with PreApply where reads and writes are interleaved, causing one of the coordinators to see the writes from the other patch by David Capwell; reviewed by Ariel Weisberg for CASSANDRA-18422 --- .../java/accord/impl/InMemoryCommandStore.java | 29 ++ .../src/main/java/accord/local/CommandStore.java | 10 +- .../src/main/java/accord/local/Commands.java | 1 + .../src/main/java/accord/messages/ReadData.java | 90 +++++- .../java/accord/utils/async/AsyncExecutor.java | 38 +++ .../main/java/accord/utils/async/AsyncResult.java | 6 +- accord-core/src/test/java/accord/Utils.java | 31 +++ .../src/test/java/accord/burn/BurnTest.java | 42 +-- .../basic/SimulatedDelayedExecutorService.java | 82 +----- ...ecutorService.java => TaskExecutorService.java} | 119 +++++--- .../src/test/java/accord/impl/list/ListRead.java | 61 ++-- .../src/test/java/accord/impl/list/ListUpdate.java | 20 +- .../src/test/java/accord/impl/list/ListWrite.java | 27 +- .../test/java/accord/local/CheckedCommands.java | 60 ++++ .../test/java/accord/messages/PreAcceptTest.java | 29 +- .../test/java/accord/messages/ReadDataTest.java | 306 +++++++++++++++++++++ .../verify/StrictSerializabilityVerifier.java | 2 + .../src/main/groovy/accord.java-conventions.gradle | 1 + 18 files changed, 748 insertions(+), 206 deletions(-) diff --git a/accord-core/src/main/java/accord/impl/InMemoryCommandStore.java b/accord-core/src/main/java/accord/impl/InMemoryCommandStore.java index e81186cc..ff34236d 100644 --- a/accord-core/src/main/java/accord/impl/InMemoryCommandStore.java +++ b/accord-core/src/main/java/accord/impl/InMemoryCommandStore.java @@ -808,6 +808,29 @@ public abstract class InMemoryCommandStore implements CommandStore }; } + @Override + public <T> AsyncChain<T> submit(Callable<T> task) + { + return new AsyncChains.Head<T>() + { + @Override + protected void start(BiConsumer<? super T, Throwable> callback) + { + enqueueAndRun(() -> { + try + { + callback.accept(task.call(), null); + } + catch (Throwable t) + { + logger.error("Uncaught exception", t); + callback.accept(null, t); + } + }); + } + }; + } + @Override public void shutdown() {} } @@ -864,6 +887,12 @@ public abstract class InMemoryCommandStore implements CommandStore return AsyncChains.ofCallable(executor, () -> executeInContext(this, context, function)); } + @Override + public <T> AsyncChain<T> submit(Callable<T> task) + { + return AsyncChains.ofCallable(executor, task); + } + @Override public void shutdown() { diff --git a/accord-core/src/main/java/accord/local/CommandStore.java b/accord-core/src/main/java/accord/local/CommandStore.java index 65f8949c..479f5817 100644 --- a/accord-core/src/main/java/accord/local/CommandStore.java +++ b/accord-core/src/main/java/accord/local/CommandStore.java @@ -23,6 +23,7 @@ import accord.api.ProgressLog; import accord.api.DataStore; import accord.local.CommandStores.RangesForEpochHolder; import accord.utils.async.AsyncChain; +import accord.utils.async.AsyncExecutor; import java.util.function.Consumer; import java.util.function.Function; @@ -30,7 +31,7 @@ import java.util.function.Function; /** * Single threaded internal shard of accord transaction metadata */ -public interface CommandStore +public interface CommandStore extends AsyncExecutor { interface Factory { @@ -46,5 +47,12 @@ public interface CommandStore Agent agent(); AsyncChain<Void> execute(PreLoadContext context, Consumer<? super SafeCommandStore> consumer); <T> AsyncChain<T> submit(PreLoadContext context, Function<? super SafeCommandStore, T> apply); + + @Override + default void execute(Runnable command) + { + submit(command).begin(agent()); + } + void shutdown(); } diff --git a/accord-core/src/main/java/accord/local/Commands.java b/accord-core/src/main/java/accord/local/Commands.java index e1244622..66546a40 100644 --- a/accord-core/src/main/java/accord/local/Commands.java +++ b/accord-core/src/main/java/accord/local/Commands.java @@ -402,6 +402,7 @@ public class Commands attrs = set(safeStore, command, attrs, coordinateRanges, executeRanges, shard, route, null, Check, partialDeps, command.hasBeen(Committed) ? Add : TrySet); safeCommand.preapplied(attrs, executeAt, waitingOn, writes, result); + safeStore.notifyListeners(safeCommand); logger.trace("{}: apply, status set to Executed with executeAt: {}, deps: {}", txnId, executeAt, partialDeps); maybeExecute(safeStore, safeCommand, shard, true, true); diff --git a/accord-core/src/main/java/accord/messages/ReadData.java b/accord-core/src/main/java/accord/messages/ReadData.java index a4ebb4f3..d7558d35 100644 --- a/accord-core/src/main/java/accord/messages/ReadData.java +++ b/accord-core/src/main/java/accord/messages/ReadData.java @@ -29,6 +29,7 @@ import accord.local.*; import accord.api.Data; import accord.topology.Topologies; import accord.utils.Invariants; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -53,12 +54,41 @@ public class ReadData extends AbstractEpochRequest<ReadData.ReadNack> implements return new ReadData(txnId, scope, executeAtEpoch, waitForEpoch); } } + private class ObsoleteTracker implements CommandListener + { + @Override + public void onChange(SafeCommandStore safeStore, SafeCommand safeCommand) + { + switch (safeCommand.current().status()) + { + case PreApplied: + case Applied: + case Invalidated: + obsolete(); + safeCommand.removeListener(this); + } + } + + @Override + public PreLoadContext listenerPreLoadContext(TxnId caller) + { + return ReadData.this.listenerPreLoadContext(caller); + } + + @Override + public boolean isTransient() + { + return true; + } + } + private final ObsoleteTracker obsoleteTracker = new ObsoleteTracker(); public final long executeAtEpoch; public final Seekables<?, ?> readScope; // TODO (low priority, efficiency): this should be RoutingKeys, as we have the Keys locally, but for simplicity we use this to implement keys() private final long waitForEpoch; private Data data; - private transient boolean isObsolete; // TODO (low priority, semantics): respond with the Executed result we have stored? + private enum State { PENDING, RETURNED, OBSOLETE } + private transient State state = State.PENDING; // TODO (low priority, semantics): respond with the Executed result we have stored? private transient BitSet waitingOn; private transient int waitingOnCount; @@ -131,10 +161,8 @@ public class ReadData extends AbstractEpochRequest<ReadData.ReadNack> implements case ReadyToExecute: } - command = safeCommand.removeListener(this); - - if (!isObsolete) - read(safeStore, command.asCommitted()); + safeCommand.removeListener(this); + maybeRead(safeStore, safeCommand); } @Override @@ -145,7 +173,7 @@ public class ReadData extends AbstractEpochRequest<ReadData.ReadNack> implements logger.trace("{}: setting up read with status {} on {}", txnId, status, safeStore); switch (status) { default: - throw new AssertionError(); + throw new AssertionError("Unknown status: " + status); case Committed: case NotWitnessed: case PreAccepted: @@ -166,18 +194,34 @@ public class ReadData extends AbstractEpochRequest<ReadData.ReadNack> implements case ReadyToExecute: waitingOn.set(safeStore.commandStore().id()); ++waitingOnCount; - if (!isObsolete) - read(safeStore, safeCommand.current().asCommitted()); + maybeRead(safeStore, safeCommand); return null; case PreApplied: case Applied: case Invalidated: - isObsolete = true; + state = State.OBSOLETE; return Redundant; } } + private void maybeRead(SafeCommandStore safeStore, SafeCommand safeCommand) + { + switch (state) + { + case PENDING: + read(safeStore, safeCommand, safeCommand.current().asCommitted()); + break; + case OBSOLETE: + // nothing to see here + break; + case RETURNED: + throw new IllegalStateException("ReadOk was sent, yet ack called again"); + default: + throw new AssertionError("Unknown state: " + state); + } + } + @Override public ReadNack reduce(ReadNack r1, ReadNack r2) { @@ -219,12 +263,27 @@ public class ReadData extends AbstractEpochRequest<ReadData.ReadNack> implements // and prevents races where we respond before dispatching all the required reads (if the reads are // completing faster than the reads can be setup on all required shards) if (-1 == --waitingOnCount) - node.reply(replyTo, replyContext, new ReadOk(data)); + { + switch (state) + { + case RETURNED: + throw new IllegalStateException("ReadOk was sent, yet ack called again"); + case OBSOLETE: + logger.debug("After the read completed for txn {}, the result was marked obsolete", txnId); + break; + case PENDING: + state = State.RETURNED; + node.reply(replyTo, replyContext, new ReadOk(data)); + break; + default: + throw new AssertionError("Unknown state: " + state); + } + } } private synchronized void readComplete(CommandStore commandStore, Data result) { - Invariants.checkState(waitingOn.get(commandStore.id()), "Waiting on does not contain store %d; waitingOn=%s", commandStore.id(), waitingOn); + Invariants.checkState(waitingOn.get(commandStore.id()), "Txn %s's waiting on does not contain store %d; waitingOn=%s", txnId, commandStore.id(), waitingOn); logger.trace("{}: read completed on {}", txnId, commandStore); if (result != null) data = data == null ? result : data.merge(result); @@ -233,8 +292,9 @@ public class ReadData extends AbstractEpochRequest<ReadData.ReadNack> implements ack(); } - private void read(SafeCommandStore safeStore, Command.Committed command) + private void read(SafeCommandStore safeStore, SafeCommand safeCommand, Command.Committed command) { + safeCommand.addListener(obsoleteTracker); CommandStore unsafeStore = safeStore.commandStore(); logger.trace("{}: executing read", command.txnId()); command.read(safeStore).begin((next, throwable) -> { @@ -249,11 +309,11 @@ public class ReadData extends AbstractEpochRequest<ReadData.ReadNack> implements }); } - void obsolete() + synchronized void obsolete() { - if (!isObsolete) + if (state == State.PENDING) { - isObsolete = true; + state = State.OBSOLETE; node.reply(replyTo, replyContext, Redundant); } } diff --git a/accord-core/src/main/java/accord/utils/async/AsyncExecutor.java b/accord-core/src/main/java/accord/utils/async/AsyncExecutor.java new file mode 100644 index 00000000..50d2116e --- /dev/null +++ b/accord-core/src/main/java/accord/utils/async/AsyncExecutor.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package accord.utils.async; + +import java.util.concurrent.Callable; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + +public interface AsyncExecutor extends Executor +{ + default AsyncChain<?> submit(Runnable task) + { + return submit(Executors.callable(task)); + } + + default <T> AsyncChain<T> submit(Runnable task, T result) + { + return submit(Executors.callable(task, result)); + } + + <T> AsyncChain<T> submit(Callable<T> task); +} diff --git a/accord-core/src/main/java/accord/utils/async/AsyncResult.java b/accord-core/src/main/java/accord/utils/async/AsyncResult.java index 59f97b7b..3269f7ac 100644 --- a/accord-core/src/main/java/accord/utils/async/AsyncResult.java +++ b/accord-core/src/main/java/accord/utils/async/AsyncResult.java @@ -79,7 +79,11 @@ public interface AsyncResult<V> extends AsyncChain<V> default void setFailure(Throwable throwable) { if (!tryFailure(throwable)) - throw new IllegalStateException("Result has already been set on " + this); + { + IllegalStateException e = new IllegalStateException("Result has already been set on " + this); + e.addSuppressed(throwable); + throw e; + } } default BiConsumer<V, Throwable> settingCallback() diff --git a/accord-core/src/test/java/accord/Utils.java b/accord-core/src/test/java/accord/Utils.java index 42dab7d3..ce5a80c8 100644 --- a/accord-core/src/test/java/accord/Utils.java +++ b/accord-core/src/test/java/accord/Utils.java @@ -18,7 +18,16 @@ package accord; +import accord.api.MessageSink; +import accord.api.Scheduler; +import accord.impl.InMemoryCommandStores; +import accord.impl.IntKey; +import accord.impl.SimpleProgressLog; import accord.impl.SizeOfIntersectionSorter; +import accord.impl.TestAgent; +import accord.impl.mock.MockCluster; +import accord.impl.mock.MockConfigurationService; +import accord.local.ShardDistributor; import accord.primitives.Range; import accord.local.Node; import accord.impl.mock.MockStore; @@ -28,7 +37,11 @@ import accord.topology.Topologies; import accord.topology.Topology; import accord.primitives.Txn; import accord.primitives.Keys; +import accord.utils.DefaultRandom; +import accord.utils.EpochFunction; import accord.utils.Invariants; +import accord.utils.ThreadPoolScheduler; + import com.google.common.collect.Sets; import java.util.ArrayList; @@ -117,4 +130,22 @@ public class Utils { return new Topologies.Multi(SizeOfIntersectionSorter.SUPPLIER, topologies); } + + public static Node createNode(Node.Id nodeId, Topology topology, MessageSink messageSink, MockCluster.Clock clock) + { + MockStore store = new MockStore(); + Scheduler scheduler = new ThreadPoolScheduler(); + return new Node(nodeId, + messageSink, + new MockConfigurationService(messageSink, EpochFunction.noop(), topology), + clock, + () -> store, + new ShardDistributor.EvenSplit(8, ignore -> new IntKey.Splitter()), + new TestAgent(), + new DefaultRandom(), + scheduler, + SizeOfIntersectionSorter.SUPPLIER, + SimpleProgressLog::new, + InMemoryCommandStores.Synchronized::new); + } } diff --git a/accord-core/src/test/java/accord/burn/BurnTest.java b/accord-core/src/test/java/accord/burn/BurnTest.java index 18a923d8..78269a79 100644 --- a/accord-core/src/test/java/accord/burn/BurnTest.java +++ b/accord-core/src/test/java/accord/burn/BurnTest.java @@ -32,32 +32,39 @@ import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.LongSupplier; import java.util.function.Predicate; -import accord.utils.DefaultRandom; -import accord.utils.RandomSource; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import accord.api.Key; import accord.impl.IntHashKey; -import accord.impl.basic.Cluster; -import accord.impl.basic.PropagatingPendingQueue; -import accord.impl.basic.RandomDelayQueue.Factory; import accord.impl.TopologyFactory; +import accord.impl.basic.Cluster; import accord.impl.basic.Packet; import accord.impl.basic.PendingQueue; +import accord.impl.basic.PropagatingPendingQueue; +import accord.impl.basic.RandomDelayQueue.Factory; +import accord.impl.basic.SimulatedDelayedExecutorService; import accord.impl.list.ListQuery; import accord.impl.list.ListRead; import accord.impl.list.ListRequest; import accord.impl.list.ListResult; import accord.impl.list.ListUpdate; +import accord.local.CommandStore; import accord.local.Node.Id; -import accord.api.Key; -import accord.primitives.*; +import accord.primitives.Keys; +import accord.primitives.Range; +import accord.primitives.Ranges; +import accord.primitives.Txn; +import accord.utils.DefaultRandom; +import accord.utils.RandomSource; +import accord.utils.async.AsyncExecutor; import accord.verify.StrictSerializabilityVerifier; -import org.junit.jupiter.api.Test; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import static accord.impl.IntHashKey.forHash; import static accord.utils.Utils.toArray; @@ -65,7 +72,7 @@ public class BurnTest { private static final Logger logger = LoggerFactory.getLogger(BurnTest.class); - static List<Packet> generate(RandomSource random, List<Id> clients, List<Id> nodes, int keyCount, int operations) + static List<Packet> generate(RandomSource random, Function<? super CommandStore, AsyncExecutor> executor, List<Id> clients, List<Id> nodes, int keyCount, int operations) { List<Key> keys = new ArrayList<>(); for (int i = 0 ; i < keyCount ; ++i) @@ -73,6 +80,7 @@ public class BurnTest List<Packet> packets = new ArrayList<>(); int[] next = new int[keyCount]; + double readInCommandStore = random.nextDouble(); for (int count = 0 ; count < operations ; ++count) { @@ -90,7 +98,7 @@ public class BurnTest requestRanges.add(IntHashKey.range(forHash(i), forHash(j))); } Ranges ranges = Ranges.of(requestRanges.toArray(new Range[0])); - ListRead read = new ListRead(ranges, ranges); + ListRead read = new ListRead(random.decide(readInCommandStore) ? Function.identity() : executor, ranges, ranges); ListQuery query = new ListQuery(client, count); ListRequest request = new ListRequest(new Txn.InMemory(ranges, read, query, null)); packets.add(new Packet(client, node, count, request)); @@ -107,7 +115,7 @@ public class BurnTest while (readCount-- > 0) requestKeys.add(randomKey(random, keys, requestKeys)); - ListUpdate update = isWrite ? new ListUpdate() : null; + ListUpdate update = isWrite ? new ListUpdate(executor) : null; while (writeCount-- > 0) { int i = randomKeyIndex(random, keys, update.keySet()); @@ -117,7 +125,7 @@ public class BurnTest Keys readKeys = new Keys(requestKeys); if (isWrite) requestKeys.addAll(update.keySet()); - ListRead read = new ListRead(readKeys, new Keys(requestKeys)); + ListRead read = new ListRead(random.decide(readInCommandStore) ? Function.identity() : executor, readKeys, new Keys(requestKeys)); ListQuery query = new ListQuery(client, count); ListRequest request = new ListRequest(new Txn.InMemory(new Keys(requestKeys), read, query, update)); packets.add(new Packet(client, node, count, request)); @@ -191,10 +199,12 @@ public class BurnTest { List<Throwable> failures = Collections.synchronizedList(new ArrayList<>()); PendingQueue queue = new PropagatingPendingQueue(failures, new Factory(random).get()); + SimulatedDelayedExecutorService globalExecutor = new SimulatedDelayedExecutorService(queue, random.fork()); StrictSerializabilityVerifier strictSerializable = new StrictSerializabilityVerifier(keyCount); + Function<CommandStore, AsyncExecutor> executor = ignore -> globalExecutor; - Packet[] requests = toArray(generate(random, clients, nodes, keyCount, operations), Packet[]::new); + Packet[] requests = toArray(generate(random, executor, clients, nodes, keyCount, operations), Packet[]::new); int[] starts = new int[requests.length]; Packet[] replies = new Packet[requests.length]; diff --git a/accord-core/src/test/java/accord/impl/basic/SimulatedDelayedExecutorService.java b/accord-core/src/test/java/accord/impl/basic/SimulatedDelayedExecutorService.java index e400d085..850f717d 100644 --- a/accord-core/src/test/java/accord/impl/basic/SimulatedDelayedExecutorService.java +++ b/accord-core/src/test/java/accord/impl/basic/SimulatedDelayedExecutorService.java @@ -18,20 +18,15 @@ package accord.impl.basic; +import java.util.concurrent.Callable; +import java.util.concurrent.TimeUnit; + import accord.burn.random.FrequentLargeRange; import accord.burn.random.RandomLong; import accord.burn.random.RandomWalkRange; import accord.utils.RandomSource; -import java.util.Collections; -import java.util.List; -import java.util.concurrent.AbstractExecutorService; -import java.util.concurrent.Callable; -import java.util.concurrent.Executors; -import java.util.concurrent.FutureTask; -import java.util.concurrent.TimeUnit; - -public class SimulatedDelayedExecutorService extends AbstractExecutorService +public class SimulatedDelayedExecutorService extends TaskExecutorService { private final PendingQueue pending; private final RandomSource random; @@ -44,10 +39,9 @@ public class SimulatedDelayedExecutorService extends AbstractExecutorService // this is different from Apache Cassandra Simulator as this is computed differently for each executor // rather than being a global config double ratio = random.nextInt(1, 11) / 100.0D; - this.jitterInNano = new FrequentLargeRange( - new RandomWalkRange(random, microToNanos(0), microToNanos(50)), - new RandomWalkRange(random, microToNanos(50), msToNanos(5)), - ratio); + this.jitterInNano = new FrequentLargeRange(new RandomWalkRange(random, microToNanos(0), microToNanos(50)), + new RandomWalkRange(random, microToNanos(50), msToNanos(5)), + ratio); } private static int msToNanos(int value) @@ -61,63 +55,15 @@ public class SimulatedDelayedExecutorService extends AbstractExecutorService } @Override - protected <T> Task<T> newTaskFor(Runnable runnable, T value) - { - return newTaskFor(Executors.callable(runnable, value)); - } - - @Override - protected <T> Task<T> newTaskFor(Callable<T> callable) - { - return new Task<>(callable); - } - - private Task<?> newTaskFor(Runnable command) - { - return command instanceof Task ? (Task<?>) command : newTaskFor(command, null); - } - - @Override - public void execute(Runnable command) - { - pending.add(newTaskFor(command), jitterInNano.getLong(random), TimeUnit.NANOSECONDS); - } - - @Override - public void shutdown() - { - } - - @Override - public List<Runnable> shutdownNow() - { - return Collections.emptyList(); - } - - @Override - public boolean isShutdown() + public void execute(Task<?> task) { - return false; + pending.add(task, jitterInNano.getLong(random), TimeUnit.NANOSECONDS); } - @Override - public boolean isTerminated() - { - return false; - } - - @Override - public boolean awaitTermination(long timeout, TimeUnit unit) - { - return false; - } - - - private static class Task<T> extends FutureTask<T> implements Pending + public <T> Task<T> submit(Callable<T> fn, long delay, TimeUnit unit) { - public Task(Callable<T> fn) - { - super(fn); - } + Task<T> task = newTaskFor(fn); + pending.add(task, jitterInNano.getLong(random) + unit.toNanos(delay), TimeUnit.NANOSECONDS); + return task; } -} +} \ No newline at end of file diff --git a/accord-core/src/test/java/accord/impl/basic/SimulatedDelayedExecutorService.java b/accord-core/src/test/java/accord/impl/basic/TaskExecutorService.java similarity index 51% copy from accord-core/src/test/java/accord/impl/basic/SimulatedDelayedExecutorService.java copy to accord-core/src/test/java/accord/impl/basic/TaskExecutorService.java index e400d085..2ab766bb 100644 --- a/accord-core/src/test/java/accord/impl/basic/SimulatedDelayedExecutorService.java +++ b/accord-core/src/test/java/accord/impl/basic/TaskExecutorService.java @@ -18,46 +18,65 @@ package accord.impl.basic; -import accord.burn.random.FrequentLargeRange; -import accord.burn.random.RandomLong; -import accord.burn.random.RandomWalkRange; -import accord.utils.RandomSource; - -import java.util.Collections; import java.util.List; import java.util.concurrent.AbstractExecutorService; import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; -import java.util.concurrent.FutureTask; +import java.util.concurrent.RunnableFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; -public class SimulatedDelayedExecutorService extends AbstractExecutorService -{ - private final PendingQueue pending; - private final RandomSource random; - private final RandomLong jitterInNano; +import accord.utils.async.AsyncExecutor; +import accord.utils.async.AsyncResults; - public SimulatedDelayedExecutorService(PendingQueue pending, RandomSource random) +public abstract class TaskExecutorService extends AbstractExecutorService implements AsyncExecutor +{ + public static class Task<T> extends AsyncResults.SettableResult<T> implements Pending, RunnableFuture<T> { - this.pending = pending; - this.random = random; - // this is different from Apache Cassandra Simulator as this is computed differently for each executor - // rather than being a global config - double ratio = random.nextInt(1, 11) / 100.0D; - this.jitterInNano = new FrequentLargeRange( - new RandomWalkRange(random, microToNanos(0), microToNanos(50)), - new RandomWalkRange(random, microToNanos(50), msToNanos(5)), - ratio); - } + private final Callable<T> fn; - private static int msToNanos(int value) - { - return Math.toIntExact(TimeUnit.MILLISECONDS.toNanos(value)); - } + public Task(Callable<T> fn) + { + this.fn = fn; + } - private static int microToNanos(int value) - { - return Math.toIntExact(TimeUnit.MICROSECONDS.toNanos(value)); + @Override + public void run() + { + try + { + setSuccess(fn.call()); + } + catch (Throwable t) + { + setFailure(t); + } + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) + { + return false; + } + + @Override + public boolean isCancelled() + { + return false; + } + + @Override + public T get() throws InterruptedException, ExecutionException + { + throw new UnsupportedOperationException(); + } + + @Override + public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException + { + throw new UnsupportedOperationException(); + } } @Override @@ -77,21 +96,42 @@ public class SimulatedDelayedExecutorService extends AbstractExecutorService return command instanceof Task ? (Task<?>) command : newTaskFor(command, null); } + protected abstract void execute(Task<?> task); + + @Override + public final void execute(Runnable command) + { + execute(newTaskFor(command)); + } + + @Override + public Task<?> submit(Runnable task) + { + return (Task<?>) super.submit(task); + } + + @Override + public <T> Task<T> submit(Runnable task, T result) + { + return (Task<T>) super.submit(task, result); + } + @Override - public void execute(Runnable command) + public <T> Task<T> submit(Callable<T> task) { - pending.add(newTaskFor(command), jitterInNano.getLong(random), TimeUnit.NANOSECONDS); + return (Task<T>) super.submit(task); } @Override public void shutdown() { + } @Override public List<Runnable> shutdownNow() { - return Collections.emptyList(); + return null; } @Override @@ -107,17 +147,8 @@ public class SimulatedDelayedExecutorService extends AbstractExecutorService } @Override - public boolean awaitTermination(long timeout, TimeUnit unit) + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { return false; } - - - private static class Task<T> extends FutureTask<T> implements Pending - { - public Task(Callable<T> fn) - { - super(fn); - } - } -} +} \ No newline at end of file diff --git a/accord-core/src/test/java/accord/impl/list/ListRead.java b/accord-core/src/test/java/accord/impl/list/ListRead.java index c8ea05cc..67435b9e 100644 --- a/accord-core/src/test/java/accord/impl/list/ListRead.java +++ b/accord-core/src/test/java/accord/impl/list/ListRead.java @@ -18,29 +18,38 @@ package accord.impl.list; -import accord.api.*; +import java.util.Map; +import java.util.function.Function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import accord.api.Data; +import accord.api.DataStore; +import accord.api.Key; +import accord.api.Read; +import accord.local.CommandStore; import accord.local.SafeCommandStore; -import accord.primitives.*; +import accord.primitives.Range; import accord.primitives.Ranges; -import accord.primitives.Keys; +import accord.primitives.Seekable; +import accord.primitives.Seekables; import accord.primitives.Timestamp; import accord.primitives.Txn; import accord.utils.async.AsyncChain; -import accord.utils.async.AsyncChains; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.Map; +import accord.utils.async.AsyncExecutor; public class ListRead implements Read { private static final Logger logger = LoggerFactory.getLogger(ListRead.class); + private final Function<? super CommandStore, AsyncExecutor> executor; public final Seekables<?, ?> readKeys; public final Seekables<?, ?> keys; - public ListRead(Seekables<?, ?> readKeys, Seekables<?, ?> keys) + public ListRead(Function<? super CommandStore, AsyncExecutor> executor, Seekables<?, ?> readKeys, Seekables<?, ?> keys) { + this.executor = executor; this.readKeys = readKeys; this.keys = keys; } @@ -55,32 +64,34 @@ public class ListRead implements Read public AsyncChain<Data> read(Seekable key, Txn.Kind kind, SafeCommandStore commandStore, Timestamp executeAt, DataStore store) { ListStore s = (ListStore)store; - ListData result = new ListData(); - switch (key.domain()) - { - default: throw new AssertionError(); - case Key: - int[] data = s.get((Key)key); - logger.trace("READ on {} at {} key:{} -> {}", s.node, executeAt, key, data); - result.put((Key)key, data); - break; - case Range: - for (Map.Entry<Key, int[]> e : s.get((Range)key)) - result.put(e.getKey(), e.getValue()); - } - return AsyncChains.success(result); + return executor.apply(commandStore.commandStore()).submit(() -> { + ListData result = new ListData(); + switch (key.domain()) + { + default: throw new AssertionError(); + case Key: + int[] data = s.get((Key)key); + logger.trace("READ on {} at {} key:{} -> {}", s.node, executeAt, key, data); + result.put((Key)key, data); + break; + case Range: + for (Map.Entry<Key, int[]> e : s.get((Range)key)) + result.put(e.getKey(), e.getValue()); + } + return result; + }); } @Override public Read slice(Ranges ranges) { - return new ListRead(readKeys.slice(ranges), keys.slice(ranges)); + return new ListRead(executor, readKeys.slice(ranges), keys.slice(ranges)); } @Override public Read merge(Read other) { - return new ListRead(((Seekables)readKeys).with(((ListRead)other).readKeys), ((Seekables)keys).with(((ListRead)other).keys)); + return new ListRead(executor, ((Seekables)readKeys).with(((ListRead)other).readKeys), ((Seekables)keys).with(((ListRead)other).keys)); } @Override diff --git a/accord-core/src/test/java/accord/impl/list/ListUpdate.java b/accord-core/src/test/java/accord/impl/list/ListUpdate.java index 055d1ea1..6461e383 100644 --- a/accord-core/src/test/java/accord/impl/list/ListUpdate.java +++ b/accord-core/src/test/java/accord/impl/list/ListUpdate.java @@ -21,17 +21,27 @@ package accord.impl.list; import java.util.Arrays; import java.util.Map; import java.util.TreeMap; +import java.util.function.Function; import java.util.stream.Collectors; -import accord.api.Key; import accord.api.Data; +import accord.api.Key; import accord.api.Update; -import accord.primitives.Ranges; +import accord.local.CommandStore; import accord.primitives.Keys; +import accord.primitives.Ranges; import accord.primitives.Seekables; +import accord.utils.async.AsyncExecutor; public class ListUpdate extends TreeMap<Key, Integer> implements Update { + private final Function<? super CommandStore, AsyncExecutor> executor; + + public ListUpdate(Function<? super CommandStore, AsyncExecutor> executor) + { + this.executor = executor; + } + @Override public Seekables<?, ?> keys() { @@ -41,7 +51,7 @@ public class ListUpdate extends TreeMap<Key, Integer> implements Update @Override public ListWrite apply(Data read) { - ListWrite write = new ListWrite(); + ListWrite write = new ListWrite(executor); Map<Key, int[]> data = (ListData)read; for (Map.Entry<Key, Integer> e : entrySet()) write.put(e.getKey(), append(data.get(e.getKey()), e.getValue())); @@ -51,7 +61,7 @@ public class ListUpdate extends TreeMap<Key, Integer> implements Update @Override public Update slice(Ranges ranges) { - ListUpdate result = new ListUpdate(); + ListUpdate result = new ListUpdate(executor); for (Map.Entry<Key, Integer> e : entrySet()) { if (ranges.contains(e.getKey())) @@ -63,7 +73,7 @@ public class ListUpdate extends TreeMap<Key, Integer> implements Update @Override public Update merge(Update other) { - ListUpdate result = new ListUpdate(); + ListUpdate result = new ListUpdate(executor); result.putAll(this); result.putAll((ListUpdate) other); return result; diff --git a/accord-core/src/test/java/accord/impl/list/ListWrite.java b/accord-core/src/test/java/accord/impl/list/ListWrite.java index 20a2a876..41fefe91 100644 --- a/accord-core/src/test/java/accord/impl/list/ListWrite.java +++ b/accord-core/src/test/java/accord/impl/list/ListWrite.java @@ -20,34 +20,47 @@ package accord.impl.list; import java.util.Arrays; import java.util.TreeMap; +import java.util.function.Function; import java.util.stream.Collectors; -import accord.api.Key; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import accord.api.DataStore; +import accord.api.Key; import accord.api.Write; +import accord.local.CommandStore; import accord.local.SafeCommandStore; import accord.primitives.Seekable; import accord.primitives.Timestamp; import accord.primitives.Writes; import accord.utils.Timestamped; import accord.utils.async.AsyncChain; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; +import accord.utils.async.AsyncExecutor; public class ListWrite extends TreeMap<Key, int[]> implements Write { private static final Logger logger = LoggerFactory.getLogger(ListWrite.class); + private final Function<? super CommandStore, AsyncExecutor> executor; + + public ListWrite(Function<? super CommandStore, AsyncExecutor> executor) + { + this.executor = executor; + } + @Override public AsyncChain<Void> apply(Seekable key, SafeCommandStore commandStore, Timestamp executeAt, DataStore store) { ListStore s = (ListStore) store; if (!containsKey(key)) return Writes.SUCCESS; - int[] data = get(key); - s.data.merge((Key)key, new Timestamped<>(executeAt, data), Timestamped::merge); - logger.trace("WRITE on {} at {} key:{} -> {}", s.node, executeAt, key, data); - return Writes.SUCCESS; + return executor.apply(commandStore.commandStore()).submit(() -> { + int[] data = get(key); + s.data.merge((Key)key, new Timestamped<>(executeAt, data), Timestamped::merge); + logger.trace("WRITE on {} at {} key:{} -> {}", s.node, executeAt, key, data); + return null; + }); } @Override diff --git a/accord-core/src/test/java/accord/local/CheckedCommands.java b/accord-core/src/test/java/accord/local/CheckedCommands.java new file mode 100644 index 00000000..a8be25ea --- /dev/null +++ b/accord-core/src/test/java/accord/local/CheckedCommands.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package accord.local; + +import javax.annotation.Nullable; + +import accord.api.Result; +import accord.api.RoutingKey; +import accord.primitives.Ballot; +import accord.primitives.PartialDeps; +import accord.primitives.PartialRoute; +import accord.primitives.PartialTxn; +import accord.primitives.Route; +import accord.primitives.Seekables; +import accord.primitives.Timestamp; +import accord.primitives.TxnId; +import accord.primitives.Writes; + +public class CheckedCommands +{ + public static void preaccept(SafeCommandStore safeStore, TxnId txnId, PartialTxn partialTxn, Route<?> route, @Nullable RoutingKey progressKey) + { + Commands.AcceptOutcome result = Commands.preaccept(safeStore, txnId, partialTxn, route, progressKey); + if (result != Commands.AcceptOutcome.Success) throw new IllegalStateException("Command mutation rejected: " + result); + } + + public static void accept(SafeCommandStore safeStore, TxnId txnId, Ballot ballot, PartialRoute<?> route, Seekables<?, ?> keys, @Nullable RoutingKey progressKey, Timestamp executeAt, PartialDeps partialDeps) + { + Commands.AcceptOutcome result = Commands.accept(safeStore, txnId, ballot, route, keys, progressKey, executeAt, partialDeps); + if (result != Commands.AcceptOutcome.Success) throw new IllegalStateException("Command mutation rejected: " + result); + } + + public static void commit(SafeCommandStore safeStore, TxnId txnId, Route<?> route, @Nullable RoutingKey progressKey, @Nullable PartialTxn partialTxn, Timestamp executeAt, PartialDeps partialDeps) + { + Commands.CommitOutcome result = Commands.commit(safeStore, txnId, route, progressKey, partialTxn, executeAt, partialDeps); + if (result != Commands.CommitOutcome.Success) throw new IllegalStateException("Command mutation rejected: " + result); + } + + public static void apply(SafeCommandStore safeStore, TxnId txnId, long untilEpoch, Route<?> route, Timestamp executeAt, @Nullable PartialDeps partialDeps, Writes writes, Result result) + { + Commands.ApplyOutcome outcome = Commands.apply(safeStore, txnId, untilEpoch, route, executeAt, partialDeps, writes, result); + if (outcome != Commands.ApplyOutcome.Success) throw new IllegalStateException("Command mutation rejected: " + outcome); + } +} diff --git a/accord-core/src/test/java/accord/messages/PreAcceptTest.java b/accord-core/src/test/java/accord/messages/PreAcceptTest.java index ae339a1e..29be8036 100644 --- a/accord-core/src/test/java/accord/messages/PreAcceptTest.java +++ b/accord-core/src/test/java/accord/messages/PreAcceptTest.java @@ -59,29 +59,10 @@ public class PreAcceptTest private static final Id ID3 = id(3); private static final List<Id> IDS = listOf(ID1, ID2, ID3); private static final Topology TOPOLOGY = TopologyFactory.toTopology(IDS, 3, IntKey.range(0, 100)); - private static final Ranges RANGE = Ranges.single(IntKey.range(0, 100)); private static final Ranges FULL_RANGE = Ranges.single(IntKey.range(routing(Integer.MIN_VALUE), routing(Integer.MAX_VALUE))); private static final ReplyContext REPLY_CONTEXT = Network.replyCtxFor(0); - private static Node createNode(Id nodeId, MessageSink messageSink, Clock clock) - { - MockStore store = new MockStore(); - Scheduler scheduler = new ThreadPoolScheduler(); - return new Node(nodeId, - messageSink, - new MockConfigurationService(messageSink, EpochFunction.noop(), TOPOLOGY), - clock, - () -> store, - new ShardDistributor.EvenSplit(8, ignore -> new IntKey.Splitter()), - new TestAgent(), - new DefaultRandom(), - scheduler, - SizeOfIntersectionSorter.SUPPLIER, - SimpleProgressLog::new, - InMemoryCommandStores.Synchronized::new); - } - private static PreAccept preAccept(TxnId txnId, Txn txn, RoutingKey homeKey) { FullRoute<?> route = txn.keys().toRoute(homeKey); @@ -98,7 +79,7 @@ public class PreAcceptTest { RecordingMessageSink messageSink = new RecordingMessageSink(ID1, Network.BLACK_HOLE); Clock clock = new Clock(100); - Node node = createNode(ID1, messageSink, clock); + Node node = createNode(ID1, TOPOLOGY, messageSink, clock); messageSink.clearHistory(); try @@ -137,7 +118,7 @@ public class PreAcceptTest { RecordingMessageSink messageSink = new RecordingMessageSink(ID1, Network.BLACK_HOLE); Clock clock = new Clock(100); - Node node = createNode(ID1, messageSink, clock); + Node node = createNode(ID1, TOPOLOGY, messageSink, clock); try { Raw key = IntKey.key(10); @@ -165,7 +146,7 @@ public class PreAcceptTest { RecordingMessageSink messageSink = new RecordingMessageSink(ID1, Network.BLACK_HOLE); Clock clock = new Clock(100); - Node node = createNode(ID1, messageSink, clock); + Node node = createNode(ID1, TOPOLOGY, messageSink, clock); try { Raw key1 = IntKey.key(10); @@ -201,7 +182,7 @@ public class PreAcceptTest { RecordingMessageSink messageSink = new RecordingMessageSink(ID1, Network.BLACK_HOLE); Clock clock = new Clock(100); - Node node = createNode(ID1, messageSink, clock); + Node node = createNode(ID1, TOPOLOGY, messageSink, clock); messageSink.clearHistory(); Raw key = IntKey.key(10); try @@ -228,7 +209,7 @@ public class PreAcceptTest { RecordingMessageSink messageSink = new RecordingMessageSink(ID1, Network.BLACK_HOLE); Clock clock = new Clock(100); - Node node = createNode(ID1, messageSink, clock); + Node node = createNode(ID1, TOPOLOGY, messageSink, clock); try { diff --git a/accord-core/src/test/java/accord/messages/ReadDataTest.java b/accord-core/src/test/java/accord/messages/ReadDataTest.java new file mode 100644 index 00000000..8a6211ef --- /dev/null +++ b/accord-core/src/test/java/accord/messages/ReadDataTest.java @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package accord.messages; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Collectors; + +import org.junit.jupiter.api.Test; + +import accord.Utils; +import accord.api.Data; +import accord.api.Key; +import accord.api.MessageSink; +import accord.api.Query; +import accord.api.Read; +import accord.api.Result; +import accord.api.RoutingKey; +import accord.api.Update; +import accord.api.Write; +import accord.impl.IntKey; +import accord.impl.TopologyFactory; +import accord.impl.mock.MockCluster; +import accord.local.CheckedCommands; +import accord.local.Command; +import accord.local.CommandStore; +import accord.local.Node; +import accord.local.PreLoadContext; +import accord.local.SafeCommand; +import accord.primitives.Ballot; +import accord.primitives.FullRoute; +import accord.primitives.Keys; +import accord.primitives.PartialDeps; +import accord.primitives.PartialRoute; +import accord.primitives.PartialTxn; +import accord.primitives.Range; +import accord.primitives.Ranges; +import accord.primitives.Routable; +import accord.primitives.Timestamp; +import accord.primitives.Txn; +import accord.primitives.TxnId; +import accord.primitives.Writes; +import accord.topology.Topologies; +import accord.topology.Topology; +import accord.utils.async.AsyncChain; +import accord.utils.async.AsyncChains; +import accord.utils.async.AsyncResults; +import org.assertj.core.api.Assertions; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import static accord.Utils.createNode; +import static accord.Utils.id; +import static accord.utils.Utils.listOf; +import static org.mockito.ArgumentMatchers.any; + +class ReadDataTest +{ + private static final Node.Id ID1 = id(1); + private static final Node.Id ID2 = id(2); + private static final Node.Id ID3 = id(3); + private static final List<Node.Id> IDS = listOf(ID1, ID2, ID3); + private static final Range RANGE = IntKey.range(0, 100); + private static final Ranges RANGES = Ranges.single(RANGE); + private static final Topology TOPOLOGY = TopologyFactory.toTopology(IDS, 3, RANGE); + private static final Topologies TOPOLOGIES = Utils.topologies(TOPOLOGY); + + private void test(Consumer<State> fn) + { + MessageSink sink = Mockito.mock(MessageSink.class); + Node node = createNode(ID1, TOPOLOGY, sink, new MockCluster.Clock(100)); + + TxnId txnId = node.nextTxnId(Txn.Kind.Write, Routable.Domain.Key); + Keys keys = Keys.of(IntKey.key(1), IntKey.key(43)); + + AsyncResults.SettableResult<Data> readResult = new AsyncResults.SettableResult<>(); + + Read read = Mockito.mock(Read.class); + Mockito.when(read.slice(any())).thenReturn(read); + Mockito.when(read.merge(any())).thenReturn(read); + Mockito.when(read.read(any(), any(), any(), any(), any())).thenAnswer(new Answer<AsyncChain<Data>>() + { + private boolean called = false; + @Override + public AsyncChain<Data> answer(InvocationOnMock ignore) throws Throwable + { + if (called) throw new IllegalStateException("Multiple calls"); + return readResult; + } + }); + Query query = Mockito.mock(Query.class); + Update update = Mockito.mock(Update.class); + Mockito.when(update.slice(any())).thenReturn(update); + + Txn txn = new Txn.InMemory(keys, read, query, update); + PartialTxn partialTxn = txn.slice(RANGES, true); + + fn.accept(new State(node, sink, txnId, partialTxn, readResult)); + } + + @Test + public void readyToExecuteObsoleteFromTracker() + { + // status=ReadyToExecute, so read will happen right away; obsolete marked by ObsoleteTracker + test(state -> { + state.readyToExecute(); + + ReplyContext replyContext = state.process(); + Mockito.verifyNoInteractions(state.sink); + + state.apply(); + state.readResult.setSuccess(Mockito.mock(Data.class)); + Mockito.verify(state.sink).reply(Mockito.eq(state.node.id()), Mockito.eq(replyContext), Mockito.eq(ReadData.ReadNack.Redundant)); + }); + } + + @Test + public void commitObsoleteFromTracker() + { + // status=Commit, will listen waiting for ReadyToExecute; obsolete marked by status listener + test(state -> { + state.forEach(store -> check(store.execute(PreLoadContext.contextFor(state.txnId, state.keys), safe -> { + CheckedCommands.preaccept(safe, state.txnId, state.partialTxn, state.route, state.progressKey); + CheckedCommands.accept(safe, state.txnId, Ballot.ZERO, state.partialRoute, state.partialTxn.keys(), state.progressKey, state.executeAt, state.deps); + + SafeCommand safeCommand = safe.command(state.txnId); + safeCommand.commit(safeCommand.current(), state.executeAt, Command.WaitingOn.EMPTY); + }))); + + ReplyContext replyContext = state.process(); + + Mockito.verifyNoInteractions(state.sink); + + state.apply(); + state.readResult.setSuccess(Mockito.mock(Data.class)); + + Mockito.verify(state.sink).reply(Mockito.eq(state.node.id()), Mockito.eq(replyContext), Mockito.eq(ReadData.ReadNack.Redundant)); + }); + } + + @Test + public void mapReduceMarksObsolete() + { + // status=Commit, will listen waiting for ReadyToExecute; obsolete marked by status listener + test(state -> { + List<CommandStore> stores = stores(state); + // this test is a bit implementation specific... so if implementations change this may need an update + // since mapReduceConsume walks the store in id order, by making sure the stores involved in this test + // are in the "right" order, can make sure to hit a very specific edge case + Collections.sort(stores, Comparator.comparingInt(CommandStore::id)); + CommandStore store = stores.get(0); + + // ack doesn't get called due to waitingOnCount not being -1, can only happen once + // the process command completes + state.readResult.setSuccess(Mockito.mock(Data.class)); + state.readyToExecute(store); + + store = stores.get(1); + check(store.execute(PreLoadContext.contextFor(state.txnId, state.keys), safe -> { + SafeCommand command = safe.command(state.txnId); + command.commitInvalidated(command.current(), state.executeAt); + })); + + ReplyContext replyContext = state.process(); + + Mockito.verify(state.sink).reply(Mockito.eq(state.node.id()), Mockito.eq(replyContext), Mockito.eq(ReadData.ReadNack.Redundant)); + }); + } + + @Test + public void mapReduceAllStageMarksObsolete() + { + test(state -> { + List<CommandStore> stores = stores(state); + stores.forEach(store -> check(store.execute(PreLoadContext.contextFor(state.txnId, state.keys), safe -> { + SafeCommand command = safe.command(state.txnId); + command.commitInvalidated(command.current(), state.executeAt); + }))); + ReplyContext replyContext = state.process(); + + Mockito.verify(state.sink).reply(Mockito.eq(state.node.id()), Mockito.eq(replyContext), Mockito.eq(ReadData.ReadNack.Redundant)); + }); + } + + private static List<CommandStore> stores(State state) + { + List<CommandStore> stores = new ArrayList<>(2); + state.forEach(stores::add); + Assertions.assertThat(stores).hasSize(2); + // block duplicate stores + Map<Integer, Long> counts = stores.stream().map(CommandStore::id).collect(Collectors.groupingBy(Function.identity(), Collectors.counting())); + for (Map.Entry<Integer, Long> e : counts.entrySet()) + { + if (e.getValue() == 1) continue; + throw new AssertionError("Duplicate command store detected with id: " + e.getKey()); + } + return stores; + } + + private static void check(AsyncChain<Void> execute) + { + try + { + AsyncChains.getUninterruptibly(execute); + } + catch (ExecutionException e) + { + throw new AssertionError(e.getCause()); + } + } + + private static class State + { + private final Node node; + private final MessageSink sink; + private final TxnId txnId; + private final PartialTxn partialTxn; + private final Keys keys; + private final Key key; + private final FullRoute<?> route; + private final PartialRoute<?> partialRoute; + private final RoutingKey progressKey; + private final Timestamp executeAt; + private final PartialDeps deps; + private final AsyncResults.SettableResult<Data> readResult; + + State(Node node, MessageSink sink, TxnId txnId, PartialTxn partialTxn, AsyncResults.SettableResult<Data> readResult) + { + this.node = node; + this.sink = sink; + this.txnId = txnId; + this.partialTxn = partialTxn; + this.keys = (Keys) partialTxn.keys(); + this.key = keys.get(0); + this.route = keys.toRoute(key.toUnseekable()); + this.partialRoute = route.slice(RANGES); + this.progressKey = key.toUnseekable(); + this.executeAt = txnId; + this.deps = PartialDeps.builder(RANGES).build(); + this.readResult = readResult; + } + + void readyToExecute(CommandStore store) + { + check(store.execute(PreLoadContext.contextFor(txnId, keys), safe -> { + CheckedCommands.preaccept(safe, txnId, partialTxn, route, progressKey); + CheckedCommands.accept(safe, txnId, Ballot.ZERO, partialRoute, partialTxn.keys(), progressKey, executeAt, deps); + CheckedCommands.commit(safe, txnId, route, progressKey, partialTxn, executeAt, deps); + })); + } + + void readyToExecute() + { + forEach(this::readyToExecute); + } + + private void forEach(Consumer<CommandStore> fn) + { + keys.stream().map(node.commandStores()::unsafeForKey).distinct().forEach(fn); + } + + AsyncResults.SettableResult<Void> apply() + { + AsyncResults.SettableResult<Void> writeResult = new AsyncResults.SettableResult<>(); + Write write = Mockito.mock(Write.class); + Mockito.when(write.apply(any(), any(), any(), any())).thenReturn(writeResult); + Writes writes = new Writes(executeAt, keys, write); + + forEach(store -> check(store.execute(PreLoadContext.contextFor(txnId, keys), safe -> { + CheckedCommands.apply(safe, txnId, safe.latestEpoch(), route, executeAt, deps, writes, Mockito.mock(Result.class)); + }))); + return writeResult; + } + + ReplyContext process() + { + ReplyContext replyContext = Mockito.mock(ReplyContext.class); + ReadData readData = new ReadData(node.id(), TOPOLOGIES, txnId, keys, txnId); + readData.process(node, node.id(), replyContext); + return replyContext; + } + } +} \ No newline at end of file diff --git a/accord-core/src/test/java/accord/verify/StrictSerializabilityVerifier.java b/accord-core/src/test/java/accord/verify/StrictSerializabilityVerifier.java index 98e6217e..8ae2b895 100644 --- a/accord-core/src/test/java/accord/verify/StrictSerializabilityVerifier.java +++ b/accord-core/src/test/java/accord/verify/StrictSerializabilityVerifier.java @@ -610,6 +610,8 @@ public class StrictSerializabilityVerifier { if (maybeWrite >= 0) { + if (IntStream.of(sequence).anyMatch(i -> i == maybeWrite)) + throw new HistoryViolation(key, "Attempted to write " + maybeWrite + " which is already found in the seq; seq=" + Arrays.toString(sequence)); sequence = Arrays.copyOf(sequence, sequence.length + 1); sequence[sequence.length - 1] = maybeWrite; } diff --git a/buildSrc/src/main/groovy/accord.java-conventions.gradle b/buildSrc/src/main/groovy/accord.java-conventions.gradle index 9e663ee8..5817cc9c 100644 --- a/buildSrc/src/main/groovy/accord.java-conventions.gradle +++ b/buildSrc/src/main/groovy/accord.java-conventions.gradle @@ -47,6 +47,7 @@ dependencies { testRuntimeOnly 'org.junit.jupiter:junit-jupiter-engine:5.7.0' testImplementation group: 'org.assertj', name: 'assertj-core', version: '3.24.2' + testImplementation group: 'org.mockito', name: 'mockito-core', version: '4.7.0' } task copyMainDependencies(type: Copy) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org For additional commands, e-mail: commits-h...@cassandra.apache.org