This is an automated email from the ASF dual-hosted git repository.

samt pushed a commit to branch cep-21-tcm
in repository https://gitbox.apache.org/repos/asf/cassandra.git

commit 56d289b4fbe104cbf9a3320c0ed7b93991dcaa2d
Author: Sam Tunnicliffe <s...@apache.org>
AuthorDate: Wed Mar 1 10:03:44 2023 +0000

    [CEP-21] Include current epoch in internode header
    
    Co-authored-by: Marcus Eriksson <marc...@apache.org>
    Co-authored-by: Alex Petrov <oleksandr.pet...@gmail.com>
    Co-authored-by: Sam Tunnicliffe <s...@apache.org>
---
 .../org/apache/cassandra/db/CounterMutation.java   |   6 +
 src/java/org/apache/cassandra/db/Mutation.java     |   5 +
 src/java/org/apache/cassandra/net/Message.java     | 130 +++++++++++++++++----
 .../org/apache/cassandra/net/MessagingService.java |   6 +-
 4 files changed, 121 insertions(+), 26 deletions(-)

diff --git a/src/java/org/apache/cassandra/db/CounterMutation.java b/src/java/org/apache/cassandra/db/CounterMutation.java
index 4f91b83ca2..86086902aa 100644
--- a/src/java/org/apache/cassandra/db/CounterMutation.java
+++ b/src/java/org/apache/cassandra/db/CounterMutation.java
@@ -51,6 +51,7 @@ import static java.util.concurrent.TimeUnit.*;
 import static org.apache.cassandra.net.MessagingService.VERSION_30;
 import static org.apache.cassandra.net.MessagingService.VERSION_3014;
 import static org.apache.cassandra.net.MessagingService.VERSION_40;
+import static org.apache.cassandra.net.MessagingService.VERSION_50;
 import static org.apache.cassandra.utils.Clock.Global.nanoTime;
 
 public class CounterMutation implements IMutation
@@ -334,6 +335,7 @@ public class CounterMutation implements IMutation
     private int serializedSize30;
     private int serializedSize3014;
     private int serializedSize40;
+    private int serializedSize50;
 
     public int serializedSize(int version)
     {
@@ -351,6 +353,10 @@ public class CounterMutation implements IMutation
                 if (serializedSize40 == 0)
                     serializedSize40 = (int) serializer.serializedSize(this, VERSION_40);
                 return serializedSize40;
+            case VERSION_50:
+                if (serializedSize50 == 0)
+                    serializedSize50 = (int) serializer.serializedSize(this, VERSION_50);
+                return serializedSize50;
             default:
                 throw new IllegalStateException("Unknown serialization version: " + version);
         }
diff --git a/src/java/org/apache/cassandra/db/Mutation.java b/src/java/org/apache/cassandra/db/Mutation.java
index ad43b16d48..9a5cdabb50 100644
--- a/src/java/org/apache/cassandra/db/Mutation.java
+++ b/src/java/org/apache/cassandra/db/Mutation.java
@@ -317,6 +317,7 @@ public class Mutation implements IMutation, Supplier<Mutation>
     private int serializedSize30;
     private int serializedSize3014;
     private int serializedSize40;
+    private int serializedSize50;
 
     public int serializedSize(int version)
     {
@@ -334,6 +335,10 @@ public class Mutation implements IMutation, Supplier<Mutation>
                 if (serializedSize40 == 0)
                     serializedSize40 = (int) serializer.serializedSize(this, VERSION_40);
                 return serializedSize40;
+            case VERSION_50:
+                if (serializedSize50 == 0)
+                    serializedSize50 = (int) serializer.serializedSize(this, VERSION_50);
+                return serializedSize50;
             default:
                 throw new IllegalStateException("Unknown serialization version: " + version);
         }
diff --git a/src/java/org/apache/cassandra/net/Message.java b/src/java/org/apache/cassandra/net/Message.java
index fa14134933..03562540b4 100644
--- a/src/java/org/apache/cassandra/net/Message.java
+++ b/src/java/org/apache/cassandra/net/Message.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Supplier;
 
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.primitives.Ints;
@@ -40,11 +41,11 @@ import org.apache.cassandra.io.util.DataInputBuffer;
 import org.apache.cassandra.io.util.DataInputPlus;
 import org.apache.cassandra.io.util.DataOutputPlus;
 import org.apache.cassandra.locator.InetAddressAndPort;
+import org.apache.cassandra.tcm.Epoch;
+import org.apache.cassandra.tcm.ClusterMetadata;
 import org.apache.cassandra.tracing.Tracing;
 import org.apache.cassandra.tracing.Tracing.TraceType;
-import org.apache.cassandra.utils.MonotonicClockTranslation;
-import org.apache.cassandra.utils.NoSpamLogger;
-import org.apache.cassandra.utils.TimeUUID;
+import org.apache.cassandra.utils.*;
 
 import static java.util.concurrent.TimeUnit.MINUTES;
 import static java.util.concurrent.TimeUnit.NANOSECONDS;
@@ -67,6 +68,8 @@ public class Message<T>
     private static final Logger logger = LoggerFactory.getLogger(Message.class);
     private static final NoSpamLogger noSpam1m = NoSpamLogger.getLogger(logger, 1, TimeUnit.MINUTES);
 
+    private static final Supplier<Epoch> epochSupplier = () -> ClusterMetadata.current().epoch;
+
     public final Header header;
     public final T payload;
 
@@ -97,6 +100,11 @@ public class Message<T>
         return header.id;
     }
 
+    public Epoch epoch()
+    {
+        return header.epoch;
+    }
+
     public Verb verb()
     {
         return header.verb;
@@ -190,14 +198,14 @@ public class Message<T>
      */
     public static <T> Message<T> out(Verb verb, T payload)
     {
-        assert !verb.isResponse();
+        assert !verb.isResponse() : verb;
 
         return outWithParam(nextId(), verb, payload, null, null);
     }
 
     public static <T> Message<T> synthetic(InetAddressAndPort from, Verb verb, T payload)
     {
-        return new Message<>(new Header(-1, verb, from, -1, -1, 0, NO_PARAMS), payload);
+        return new Message<>(new Header(-1, epochSupplier.get(), verb, from, -1, -1, 0, NO_PARAMS), payload);
     }
 
     public static <T> Message<T> out(Verb verb, T payload, long expiresAtNanos)
@@ -242,7 +250,7 @@ public class Message<T>
         if (expiresAtNanos == 0)
             expiresAtNanos = verb.expiresAtNanos(createdAtNanos);
 
-        return new Message<>(new Header(id, verb, from, createdAtNanos, expiresAtNanos, flags, buildParams(paramType, paramValue)), payload);
+        return new Message<>(new Header(id, epochSupplier.get(), verb, from, createdAtNanos, expiresAtNanos, flags, buildParams(paramType, paramValue)), payload);
     }
 
     public static <T> Message<T> internalResponse(Verb verb, T payload)
@@ -260,7 +268,20 @@ public class Message<T>
         assert verb.isResponse();
         long createdAtNanos = approxTime.now();
         long expiresAtNanos = verb.expiresAtNanos(createdAtNanos);
-        return new Message<>(new Header(0, verb, from, createdAtNanos, expiresAtNanos, 0, NO_PARAMS), payload);
+        return new Message<>(new Header(0, epochSupplier.get(), verb, from, createdAtNanos, expiresAtNanos, 0, NO_PARAMS), payload);
+    }
+
+    /**
+     * A way to generate messages originating from arbitrary node. Should ONLY be used for testing purposes, as meddling with an
+     * originator can be dangerous.
+     */
+    @VisibleForTesting
+    public static <T> Message<T> remoteResponseForTests(long id, InetAddressAndPort from, Verb verb, T payload)
+    {
+        assert verb.isResponse();
+        long createdAtNanos = approxTime.now();
+        long expiresAtNanos = verb.expiresAtNanos(createdAtNanos);
+        return new Message<>(new Header(id, epochSupplier.get(), verb, from, createdAtNanos, expiresAtNanos, 0, NO_PARAMS), payload);
     }
 
     /** Builds a response Message with provided payload, and all the right fields inferred from request Message */
@@ -306,6 +327,11 @@ public class Message<T>
         return new Message<>(header.withFlag(flag), payload);
     }
 
+    public Message<T> withEpoch(Epoch epoch)
+    {
+        return new Message<>(header.withEpoch(epoch), payload);
+    }
+
     public Message<T> withParam(ParamType type, Object value)
     {
         return new Message<>(header.withParam(type, value), payload);
@@ -416,6 +442,7 @@ public class Message<T>
     public static class Header
     {
         public final long id;
+        public final Epoch epoch;
         public final Verb verb;
         public final InetAddressAndPort from;
         public final long createdAtNanos;
@@ -423,9 +450,10 @@ public class Message<T>
         private final int flags;
         private final Map<ParamType, Object> params;
 
-        private Header(long id, Verb verb, InetAddressAndPort from, long createdAtNanos, long expiresAtNanos, int flags, Map<ParamType, Object> params)
+        private Header(long id, Epoch epoch, Verb verb, InetAddressAndPort from, long createdAtNanos, long expiresAtNanos, int flags, Map<ParamType, Object> params)
         {
             this.id = id;
+            this.epoch = epoch;
             this.verb = verb;
             this.from = from;
             this.expiresAtNanos = expiresAtNanos;
@@ -436,17 +464,22 @@ public class Message<T>
 
         Header withFlag(MessageFlag flag)
         {
-            return new Header(id, verb, from, createdAtNanos, expiresAtNanos, flag.addTo(flags), params);
+            return new Header(id, epoch, verb, from, createdAtNanos, expiresAtNanos, flag.addTo(flags), params);
+        }
+
+        Header withEpoch(Epoch epoch)
+        {
+            return new Header(id, epoch, verb, from, createdAtNanos, expiresAtNanos, flags, params);
         }
 
         Header withParam(ParamType type, Object value)
         {
-            return new Header(id, verb, from, createdAtNanos, expiresAtNanos, flags, addParam(params, type, value));
+            return new Header(id, epoch, verb, from, createdAtNanos, expiresAtNanos, flags, addParam(params, type, value));
         }
 
         Header withParams(Map<ParamType, Object> values)
         {
-            return new Header(id, verb, from, createdAtNanos, expiresAtNanos, flags, addParams(params, values));
+            return new Header(id, epoch, verb, from, createdAtNanos, expiresAtNanos, flags, addParams(params, values));
         }
 
         boolean callBackOnFailure()
@@ -513,6 +546,7 @@ public class Message<T>
         private long createdAtNanos;
         private long expiresAtNanos;
         private long id;
+        private Epoch epoch;
 
         private boolean hasId;
 
@@ -614,6 +648,12 @@ public class Message<T>
             return this;
         }
 
+        public Builder<T> withEpoch(Epoch epoch)
+        {
+            this.epoch = epoch;
+            return this;
+        }
+
         public Message<T> build()
         {
             if (verb == null)
@@ -622,8 +662,10 @@ public class Message<T>
                 throw new IllegalArgumentException();
             if (payload == null)
                 throw new IllegalArgumentException();
+            if (epoch == null)
+                epoch = epochSupplier.get();
 
-            return new Message<>(new Header(hasId ? id : nextId(), verb, from, createdAtNanos, expiresAtNanos, flags, params), payload);
+            return new Message<>(new Header(hasId ? id : nextId(), epoch, verb, from, createdAtNanos, expiresAtNanos, flags, params), payload);
         }
     }
 
@@ -636,6 +678,7 @@ public class Message<T>
                                .withExpiresAt(message.expiresAtNanos())
                                .withFlags(message.header.flags)
                                .withParams(message.header.params)
+                               .withEpoch(message.header.epoch)
                                .withPayload(message.payload);
     }
 
@@ -674,6 +717,8 @@ public class Message<T>
      * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
      * | Message ID (vint)             |
      * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
+     * | Epoch (vint)                  |
+     * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
      * | Creation timestamp (int)      |
      * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
      * | Expiry (vint)                 |
@@ -738,7 +783,7 @@ public class Message<T>
          */
         int inferMessageSize(ByteBuffer buf, int index, int limit, int version) throws InvalidLegacyProtocolMagic
         {
-            int size = version >= VERSION_40 ? inferMessageSizePost40(buf, index, limit) : inferMessageSizePre40(buf, index, limit);
+            int size = version >= VERSION_40 ? inferMessageSizePost40(buf, index, limit, version) : inferMessageSizePre40(buf, index, limit);
             if (size > DatabaseDescriptor.getInternodeMaxMessageSizeInBytes())
                 throw new OversizedMessageException(size);
             return size;
@@ -773,6 +818,9 @@ public class Message<T>
         private void serializeHeaderPost40(Header header, DataOutputPlus out, int version) throws IOException
         {
             out.writeUnsignedVInt(header.id);
+            if (version >= VERSION_50)
+                Epoch.messageSerializer.serialize(header.epoch, out, version);
+
             // int cast cuts off the high-order half of the timestamp, which we can assume remains
             // the same between now and when the recipient reconstructs it.
             out.writeInt((int) approxTime.translate().toMillisSinceEpoch(header.createdAtNanos));
@@ -785,6 +833,10 @@ public class Message<T>
         private Header deserializeHeaderPost40(DataInputPlus in, InetAddressAndPort peer, int version) throws IOException
         {
             long id = in.readUnsignedVInt();
+            Epoch epoch = Epoch.EMPTY;
+            if (version >= VERSION_50)
+                epoch = Epoch.messageSerializer.deserialize(in, version);
+
             long currentTimeNanos = approxTime.now();
             MonotonicClockTranslation timeSnapshot = approxTime.translate();
             long creationTimeNanos = calculateCreationTimeNanos(in.readInt(), timeSnapshot, currentTimeNanos);
@@ -792,12 +844,16 @@ public class Message<T>
             Verb verb = Verb.fromId(in.readUnsignedVInt32());
             int flags = in.readUnsignedVInt32();
             Map<ParamType, Object> params = deserializeParams(in, version);
-            return new Header(id, verb, peer, creationTimeNanos, expiresAtNanos, flags, params);
+            return new Header(id, epoch, verb, peer, creationTimeNanos, expiresAtNanos, flags, params);
         }
 
-        private void skipHeaderPost40(DataInputPlus in) throws IOException
+        private void skipHeaderPost40(DataInputPlus in, int version) throws IOException
         {
             skipUnsignedVInt(in); // id
+            if (version >= VERSION_50)
+            {
+                skipUnsignedVInt(in); // epoch
+            }
             in.skipBytesFully(4); // createdAt
             skipUnsignedVInt(in); // expiresIn
             skipUnsignedVInt(in); // verb
@@ -809,6 +865,8 @@ public class Message<T>
         {
             long size = 0;
             size += sizeofUnsignedVInt(header.id);
+            if (version >= VERSION_50)
+                size += sizeofUnsignedVInt(header.epoch.getEpoch());
             size += CREATION_TIME_SIZE;
             size += sizeofUnsignedVInt(NANOSECONDS.toMillis(header.expiresAtNanos - header.createdAtNanos));
             size += sizeofUnsignedVInt(header.verb.id);
@@ -826,6 +884,13 @@ public class Message<T>
             long id = getUnsignedVInt(buf, index);
             index += computeUnsignedVIntSize(id);
 
+            Epoch epoch = Epoch.EMPTY;
+            if (version >= VERSION_50)
+            {
+                long epochl = getUnsignedVInt(buf, index);
+                index += computeUnsignedVIntSize(epochl);
+                epoch = Epoch.create(epochl);
+            }
             int createdAtMillis = buf.getInt(index);
             index += sizeof(createdAtMillis);
 
@@ -843,7 +908,7 @@ public class Message<T>
             long createdAtNanos = calculateCreationTimeNanos(createdAtMillis, timeSnapshot, currentTimeNanos);
             long expiresAtNanos = getExpiresAtNanos(createdAtNanos, currentTimeNanos, TimeUnit.MILLISECONDS.toNanos(expiresInMillis));
 
-            return new Header(id, verb, from, createdAtNanos, expiresAtNanos, flags, params);
+            return new Header(id, epoch, verb, from, createdAtNanos, expiresAtNanos, flags, params);
         }
 
         private <T> void serializePost40(Message<T> message, DataOutputPlus out, int version) throws IOException
@@ -863,7 +928,7 @@ public class Message<T>
 
         private <T> Message<T> deserializePost40(DataInputPlus in, Header header, int version) throws IOException
         {
-            skipHeaderPost40(in);
+            skipHeaderPost40(in, version);
             skipUnsignedVInt(in); // payload size, not needed by payload deserializer
             T payload = (T) header.verb.serializer().deserialize(in, version);
             return new Message<>(header, payload);
@@ -878,7 +943,7 @@ public class Message<T>
             return Ints.checkedCast(size);
         }
 
-        private int inferMessageSizePost40(ByteBuffer buf, int readerIndex, int readerLimit)
+        private int inferMessageSizePost40(ByteBuffer buf, int readerIndex, int readerLimit, int version)
         {
             int index = readerIndex;
 
@@ -887,6 +952,13 @@ public class Message<T>
                 return -1; // not enough bytes to read id
             index += idSize;
 
+            if (version >= VERSION_50)
+            {
+                int epochSize = computeUnsignedVIntSize(buf, index, readerLimit);
+                if (epochSize < 0)
+                    return -1; // not enough bytes to read epoch
+                index += epochSize;
+            }
             index += CREATION_TIME_SIZE;
             if (index > readerLimit)
                 return -1;
@@ -946,7 +1018,7 @@ public class Message<T>
             Verb verb = Verb.fromId(in.readInt());
             Map<ParamType, Object> params = deserializeParams(in, version);
             int flags = removeFlagsFromLegacyParams(params);
-            return new Header(id, verb, from, creationTimeNanos, verb.expiresAtNanos(creationTimeNanos), flags, params);
+            return new Header(id, Epoch.EMPTY, verb, from, creationTimeNanos, verb.expiresAtNanos(creationTimeNanos), flags, params);
         }
 
         private static final int PRE_40_MESSAGE_PREFIX_SIZE = 12; // protocol magic + id + createdAt
@@ -955,7 +1027,7 @@ public class Message<T>
         {
             in.skipBytesFully(PRE_40_MESSAGE_PREFIX_SIZE); // magic, id, createdAt
             in.skipBytesFully(in.readByte());              // from
-            in.skipBytesFully(4);                          // verb
+            in.skipBytesFully(4);                       // verb
             skipParamsPre40(in);                           // params
         }
 
@@ -995,7 +1067,7 @@ public class Message<T>
             long createdAtNanos = calculateCreationTimeNanos(createdAtMillis, timeSnapshot, currentTimeNanos);
             long expiresAtNanos = verb.expiresAtNanos(createdAtNanos);
 
-            return new Header(id, verb, from, createdAtNanos, expiresAtNanos, flags, params);
+            return new Header(id, Epoch.EMPTY, verb, from, createdAtNanos, expiresAtNanos, flags, params);
         }
 
         private <T> void serializePre40(Message<T> message, DataOutputPlus out, int version) throws IOException
@@ -1113,7 +1185,7 @@ public class Message<T>
             params.put(ParamType.FAILURE_RESPONSE, LegacyFlag.instance);
             params.put(ParamType.FAILURE_REASON, post40.payload);
 
-            Header header = new Header(post40.id(), post40.verb().toPre40Verb(), post40.from(), post40.createdAtNanos(), post40.expiresAtNanos(), 0, params);
+            Header header = new Header(post40.id(), Epoch.EMPTY, post40.verb().toPre40Verb(), post40.from(), post40.createdAtNanos(), post40.expiresAtNanos(), 0, params);
             return new Message<>(header, NoPayload.noPayload);
         }
 
@@ -1128,7 +1200,7 @@ public class Message<T>
             if (null == reason)
                 reason = RequestFailureReason.UNKNOWN;
 
-            Header header = new Header(pre40.id(), Verb.FAILURE_RSP, pre40.from(), pre40.createdAtNanos(), pre40.expiresAtNanos(), pre40.header.flags, params);
+            Header header = new Header(pre40.id(), epochSupplier.get(), Verb.FAILURE_RSP, pre40.from(), pre40.createdAtNanos(), pre40.expiresAtNanos(), pre40.header.flags, params);
             return new Message<>(header, reason);
         }
 
@@ -1425,6 +1497,7 @@ public class Message<T>
     private int serializedSize30;
     private int serializedSize3014;
     private int serializedSize40;
+    private int serializedSize50;
 
     /**
      * Serialized size of the entire message, for the provided messaging version. Caches the calculated value.
@@ -1445,14 +1518,19 @@ public class Message<T>
                 if (serializedSize40 == 0)
                     serializedSize40 = serializer.serializedSize(this, VERSION_40);
                 return serializedSize40;
+            case VERSION_50:
+                if (serializedSize50 == 0)
+                    serializedSize50 = serializer.serializedSize(this, VERSION_50);
+                return serializedSize50;
             default:
-                throw new IllegalStateException();
+                throw new IllegalStateException("Illegal messaging version: " + version);
         }
     }
 
     private int payloadSize30   = -1;
     private int payloadSize3014 = -1;
     private int payloadSize40   = -1;
+    private int payloadSize50   = -1;
 
     private int payloadSize(int version)
     {
@@ -1470,6 +1548,10 @@ public class Message<T>
                 if (payloadSize40 < 0)
                     payloadSize40 = serializer.payloadSize(this, VERSION_40);
                 return payloadSize40;
+            case VERSION_50:
+                if (payloadSize50 < 0)
+                    payloadSize50 = serializer.payloadSize(this, VERSION_50);
+                return payloadSize50;
             default:
                 throw new IllegalStateException();
         }
diff --git a/src/java/org/apache/cassandra/net/MessagingService.java b/src/java/org/apache/cassandra/net/MessagingService.java
index 86f8ebd5e4..6da2a6a425 100644
--- a/src/java/org/apache/cassandra/net/MessagingService.java
+++ b/src/java/org/apache/cassandra/net/MessagingService.java
@@ -210,8 +210,9 @@ public class MessagingService extends MessagingServiceMBeanImpl implements Messa
     public static final int VERSION_30 = 10;
     public static final int VERSION_3014 = 11;
     public static final int VERSION_40 = 12;
+    public static final int VERSION_50 = 13;
     public static final int minimum_version = VERSION_30;
-    public static final int current_version = VERSION_40;
+    public static final int current_version = VERSION_50;
     static AcceptVersions accept_messaging = new AcceptVersions(minimum_version, current_version);
     static AcceptVersions accept_streaming = new AcceptVersions(current_version, current_version);
     static Map<Integer, Integer> versionOrdinalMap = Arrays.stream(Version.values()).collect(Collectors.toMap(v -> v.value, v -> v.ordinal()));
@@ -236,7 +237,8 @@ public class MessagingService extends MessagingServiceMBeanImpl implements Messa
     {
         VERSION_30(10),
         VERSION_3014(11),
-        VERSION_40(12);
+        VERSION_40(12),
+        VERSION_50(13);
 
         public final int value;
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org
For additional commands, e-mail: commits-h...@cassandra.apache.org

Reply via email to