Repository: cassandra Updated Branches: refs/heads/cassandra-2.1 ece386439 -> 9872b74ef
http://git-wip-us.apache.org/repos/asf/cassandra/blob/9872b74e/src/java/org/apache/cassandra/transport/messages/BatchMessage.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/transport/messages/BatchMessage.java b/src/java/org/apache/cassandra/transport/messages/BatchMessage.java index ef30a22..ec96ed1 100644 --- a/src/java/org/apache/cassandra/transport/messages/BatchMessage.java +++ b/src/java/org/apache/cassandra/transport/messages/BatchMessage.java @@ -28,6 +28,7 @@ import io.netty.buffer.ByteBuf; import org.apache.cassandra.cql3.*; import org.apache.cassandra.cql3.statements.BatchStatement; import org.apache.cassandra.cql3.statements.ModificationStatement; +import org.apache.cassandra.cql3.statements.ParsedStatement; import org.apache.cassandra.db.ConsistencyLevel; import org.apache.cassandra.exceptions.InvalidRequestException; import org.apache.cassandra.exceptions.PreparedQueryNotFoundException; @@ -61,8 +62,11 @@ public class BatchMessage extends Message.Request throw new ProtocolException("Invalid query kind in BATCH messages. Must be 0 or 1 but got " + kind); variables.add(CBUtil.readValueList(body)); } - ConsistencyLevel consistency = CBUtil.readConsistencyLevel(body); - return new BatchMessage(toType(type), queryOrIds, variables, consistency); + QueryOptions options = version < 3 + ? 
QueryOptions.fromPreV3Batch(CBUtil.readConsistencyLevel(body)) + : QueryOptions.codec.decode(body, version); + + return new BatchMessage(toType(type), queryOrIds, variables, options); } public void encode(BatchMessage msg, ByteBuf dest, int version) @@ -84,7 +88,10 @@ public class BatchMessage extends Message.Request CBUtil.writeValueList(msg.values.get(i), dest); } - CBUtil.writeConsistencyLevel(msg.consistency, dest); + if (version < 3) + CBUtil.writeConsistencyLevel(msg.options.getConsistency(), dest); + else + QueryOptions.codec.encode(msg.options, dest, version); } public int encodedSize(BatchMessage msg, int version) @@ -99,7 +106,9 @@ public class BatchMessage extends Message.Request size += CBUtil.sizeOfValueList(msg.values.get(i)); } - size += CBUtil.sizeOfConsistencyLevel(msg.consistency); + size += version < 3 + ? CBUtil.sizeOfConsistencyLevel(msg.options.getConsistency()) + : QueryOptions.codec.encodedSize(msg.options, version); return size; } @@ -131,15 +140,15 @@ public class BatchMessage extends Message.Request public final BatchStatement.Type type; public final List<Object> queryOrIdList; public final List<List<ByteBuffer>> values; - public final ConsistencyLevel consistency; + public final QueryOptions options; - public BatchMessage(BatchStatement.Type type, List<Object> queryOrIdList, List<List<ByteBuffer>> values, ConsistencyLevel consistency) + public BatchMessage(BatchStatement.Type type, List<Object> queryOrIdList, List<List<ByteBuffer>> values, QueryOptions options) { super(Message.Type.BATCH); this.type = type; this.queryOrIdList = queryOrIdList; this.values = values; - this.consistency = consistency; + this.options = options; } public Message.Response execute(QueryState state) @@ -161,27 +170,39 @@ public class BatchMessage extends Message.Request } QueryHandler handler = state.getClientState().getCQLQueryHandler(); - List<ModificationStatement> statements = new ArrayList<ModificationStatement>(queryOrIdList.size()); + 
List<ParsedStatement.Prepared> prepared = new ArrayList<>(queryOrIdList.size()); for (int i = 0; i < queryOrIdList.size(); i++) { Object query = queryOrIdList.get(i); - CQLStatement statement; + ParsedStatement.Prepared p; if (query instanceof String) { - statement = QueryProcessor.parseStatement((String)query, state); + p = QueryProcessor.parseStatement((String)query, state); } else { - statement = handler.getPrepared((MD5Digest)query); - if (statement == null) + p = handler.getPrepared((MD5Digest)query); + if (p == null) throw new PreparedQueryNotFoundException((MD5Digest)query); } List<ByteBuffer> queryValues = values.get(i); - if (queryValues.size() != statement.getBoundTerms()) + if (queryValues.size() != p.statement.getBoundTerms()) throw new InvalidRequestException(String.format("There were %d markers(?) in CQL but %d bound variables", - statement.getBoundTerms(), + p.statement.getBoundTerms(), queryValues.size())); + + prepared.add(p); + } + + BatchQueryOptions batchOptions = BatchQueryOptions.withPerStatementVariables(options, values, queryOrIdList); + List<ModificationStatement> statements = new ArrayList<>(prepared.size()); + for (int i = 0; i < prepared.size(); i++) + { + ParsedStatement.Prepared p = prepared.get(i); + batchOptions.forStatement(i).prepare(p.boundNames); + CQLStatement statement = p.statement; + if (!(statement instanceof ModificationStatement)) throw new InvalidRequestException("Invalid statement in batch: only UPDATE, INSERT and DELETE statements are allowed."); @@ -202,7 +223,7 @@ public class BatchMessage extends Message.Request // Note: It's ok at this point to pass a bogus value for the number of bound terms in the BatchState ctor // (and no value would be really correct, so we prefer passing a clearly wrong one). 
BatchStatement batch = new BatchStatement(-1, type, statements, Attributes.none()); - Message.Response response = handler.processBatch(batch, state, new BatchQueryOptions(consistency, values, queryOrIdList)); + Message.Response response = handler.processBatch(batch, state, batchOptions); if (tracingId != null) response.setTracingId(tracingId); @@ -229,7 +250,7 @@ public class BatchMessage extends Message.Request if (i > 0) sb.append(", "); sb.append(queryOrIdList.get(i)).append(" with ").append(values.get(i).size()).append(" values"); } - sb.append("] at consistency ").append(consistency); + sb.append("] at consistency ").append(options.getConsistency()); return sb.toString(); } } http://git-wip-us.apache.org/repos/asf/cassandra/blob/9872b74e/src/java/org/apache/cassandra/transport/messages/EventMessage.java ---------------------------------------------------------------------- diff --git a/src/java/org/apache/cassandra/transport/messages/EventMessage.java b/src/java/org/apache/cassandra/transport/messages/EventMessage.java index 890b9d1..f3ab526 100644 --- a/src/java/org/apache/cassandra/transport/messages/EventMessage.java +++ b/src/java/org/apache/cassandra/transport/messages/EventMessage.java @@ -28,17 +28,17 @@ public class EventMessage extends Message.Response { public EventMessage decode(ByteBuf body, int version) { - return new EventMessage(Event.deserialize(body)); + return new EventMessage(Event.deserialize(body, version)); } public void encode(EventMessage msg, ByteBuf dest, int version) { - msg.event.serialize(dest); + msg.event.serialize(dest, version); } public int encodedSize(EventMessage msg, int version) { - return msg.event.serializedSize(); + return msg.event.serializedSize(version); } }; http://git-wip-us.apache.org/repos/asf/cassandra/blob/9872b74e/src/java/org/apache/cassandra/transport/messages/ExecuteMessage.java ---------------------------------------------------------------------- diff --git 
a/src/java/org/apache/cassandra/transport/messages/ExecuteMessage.java b/src/java/org/apache/cassandra/transport/messages/ExecuteMessage.java index caec43f..d618f43 100644 --- a/src/java/org/apache/cassandra/transport/messages/ExecuteMessage.java +++ b/src/java/org/apache/cassandra/transport/messages/ExecuteMessage.java @@ -27,6 +27,7 @@ import io.netty.buffer.ByteBuf; import org.apache.cassandra.cql3.CQLStatement; import org.apache.cassandra.cql3.QueryHandler; import org.apache.cassandra.cql3.QueryOptions; +import org.apache.cassandra.cql3.statements.ParsedStatement; import org.apache.cassandra.db.ConsistencyLevel; import org.apache.cassandra.exceptions.PreparedQueryNotFoundException; import org.apache.cassandra.service.QueryState; @@ -100,7 +101,9 @@ public class ExecuteMessage extends Message.Request try { QueryHandler handler = state.getClientState().getCQLQueryHandler(); - CQLStatement statement = handler.getPrepared(statementId); + ParsedStatement.Prepared prepared = handler.getPrepared(statementId); + options.prepare(prepared.boundNames); + CQLStatement statement = prepared.statement; if (statement == null) throw new PreparedQueryNotFoundException(statementId); http://git-wip-us.apache.org/repos/asf/cassandra/blob/9872b74e/test/unit/org/apache/cassandra/transport/SerDeserTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/transport/SerDeserTest.java b/test/unit/org/apache/cassandra/transport/SerDeserTest.java new file mode 100644 index 0000000..9b66efb --- /dev/null +++ b/test/unit/org/apache/cassandra/transport/SerDeserTest.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.transport; + +import java.nio.ByteBuffer; +import java.util.*; + +import io.netty.buffer.Unpooled; +import io.netty.buffer.ByteBuf; + +import org.junit.Test; +import org.apache.cassandra.cql3.*; +import org.apache.cassandra.db.ConsistencyLevel; +import org.apache.cassandra.db.marshal.*; +import org.apache.cassandra.serializers.CollectionSerializer; +import org.apache.cassandra.transport.Event.TopologyChange; +import org.apache.cassandra.transport.Event.SchemaChange; +import org.apache.cassandra.transport.Event.StatusChange; +import org.apache.cassandra.utils.ByteBufferUtil; +import org.apache.cassandra.utils.FBUtilities; +import org.apache.cassandra.utils.Pair; + +import static org.junit.Assert.assertEquals; +import static org.apache.cassandra.utils.ByteBufferUtil.bytes; + +/** + * Serialization/deserialization tests for protocol objects and messages. 
+ */ +public class SerDeserTest +{ + @Test + public void collectionSerDeserTest() throws Exception + { + collectionSerDeserTest(2); + collectionSerDeserTest(3); + } + + public void collectionSerDeserTest(int version) throws Exception + { + // Lists + ListType<?> lt = ListType.getInstance(Int32Type.instance); + List<Integer> l = Arrays.asList(2, 6, 1, 9); + + List<ByteBuffer> lb = new ArrayList<>(l.size()); + for (Integer i : l) + lb.add(Int32Type.instance.decompose(i)); + + assertEquals(l, lt.getSerializer().deserializeForNativeProtocol(CollectionSerializer.pack(lb, lb.size(), version), version)); + + // Sets + SetType<?> st = SetType.getInstance(UTF8Type.instance); + Set<String> s = new LinkedHashSet<>(); + s.addAll(Arrays.asList("bar", "foo", "zee")); + + List<ByteBuffer> sb = new ArrayList<>(s.size()); + for (String t : s) + sb.add(UTF8Type.instance.decompose(t)); + + assertEquals(s, st.getSerializer().deserializeForNativeProtocol(CollectionSerializer.pack(sb, sb.size(), version), version)); + + // Maps + MapType<?, ?> mt = MapType.getInstance(UTF8Type.instance, LongType.instance); + Map<String, Long> m = new LinkedHashMap<>(); + m.put("bar", 12L); + m.put("foo", 42L); + m.put("zee", 14L); + + List<ByteBuffer> mb = new ArrayList<>(m.size() * 2); + for (Map.Entry<String, Long> entry : m.entrySet()) + { + mb.add(UTF8Type.instance.decompose(entry.getKey())); + mb.add(LongType.instance.decompose(entry.getValue())); + } + + assertEquals(m, mt.getSerializer().deserializeForNativeProtocol(CollectionSerializer.pack(mb, m.size(), version), version)); + } + + @Test + public void eventSerDeserTest() throws Exception + { + eventSerDeserTest(2); + eventSerDeserTest(3); + } + + public void eventSerDeserTest(int version) throws Exception + { + List<Event> events = new ArrayList<>(); + + events.add(TopologyChange.newNode(FBUtilities.getBroadcastAddress(), 42)); + events.add(TopologyChange.removedNode(FBUtilities.getBroadcastAddress(), 42)); + 
events.add(TopologyChange.movedNode(FBUtilities.getBroadcastAddress(), 42)); + + events.add(StatusChange.nodeUp(FBUtilities.getBroadcastAddress(), 42)); + events.add(StatusChange.nodeDown(FBUtilities.getBroadcastAddress(), 42)); + + events.add(new SchemaChange(SchemaChange.Change.CREATED, "ks")); + events.add(new SchemaChange(SchemaChange.Change.UPDATED, "ks")); + events.add(new SchemaChange(SchemaChange.Change.DROPPED, "ks")); + + events.add(new SchemaChange(SchemaChange.Change.CREATED, SchemaChange.Target.TABLE, "ks", "table")); + events.add(new SchemaChange(SchemaChange.Change.UPDATED, SchemaChange.Target.TABLE, "ks", "table")); + events.add(new SchemaChange(SchemaChange.Change.DROPPED, SchemaChange.Target.TABLE, "ks", "table")); + + if (version >= 3) + { + events.add(new SchemaChange(SchemaChange.Change.CREATED, SchemaChange.Target.TYPE, "ks", "type")); + events.add(new SchemaChange(SchemaChange.Change.UPDATED, SchemaChange.Target.TYPE, "ks", "type")); + events.add(new SchemaChange(SchemaChange.Change.DROPPED, SchemaChange.Target.TYPE, "ks", "type")); + } + + for (Event ev : events) + { + ByteBuf buf = Unpooled.buffer(ev.serializedSize(version)); + ev.serialize(buf, version); + assertEquals(ev, Event.deserialize(buf, version)); + } + } + + private static ByteBuffer bb(String str) + { + return UTF8Type.instance.decompose(str); + } + + private static ColumnIdentifier ci(String name) + { + return new ColumnIdentifier(name, false); + } + + private static Constants.Literal lit(long v) + { + return Constants.Literal.integer(String.valueOf(v)); + } + + private static Constants.Literal lit(String v) + { + return Constants.Literal.string(v); + } + + private static ColumnSpecification columnSpec(String name, AbstractType<?> type) + { + return new ColumnSpecification("ks", "cf", ci(name), type); + } + + @Test + public void udtSerDeserTest() throws Exception + { + udtSerDeserTest(2); + udtSerDeserTest(3); + } + + public void udtSerDeserTest(int version) throws Exception + 
{ + ListType<?> lt = ListType.getInstance(Int32Type.instance); + SetType<?> st = SetType.getInstance(UTF8Type.instance); + MapType<?, ?> mt = MapType.getInstance(UTF8Type.instance, LongType.instance); + + UserType udt = new UserType("ks", + bb("myType"), + Arrays.asList(bb("f1"), bb("f2"), bb("f3"), bb("f4")), + Arrays.asList(LongType.instance, lt, st, mt)); + + Map<ColumnIdentifier, Term.Raw> value = new HashMap<>(); + value.put(ci("f1"), lit(42)); + value.put(ci("f2"), new Lists.Literal(Arrays.<Term.Raw>asList(lit(3), lit(1)))); + value.put(ci("f3"), new Sets.Literal(Arrays.<Term.Raw>asList(lit("foo"), lit("bar")))); + value.put(ci("f4"), new Maps.Literal(Arrays.<Pair<Term.Raw, Term.Raw>>asList( + Pair.<Term.Raw, Term.Raw>create(lit("foo"), lit(24)), + Pair.<Term.Raw, Term.Raw>create(lit("bar"), lit(12))))); + + UserTypes.Literal u = new UserTypes.Literal(value); + Term t = u.prepare("ks", columnSpec("myValue", udt)); + + QueryOptions options = QueryOptions.DEFAULT; + if (version == 2) + options = QueryOptions.fromProtocolV2(ConsistencyLevel.ONE, Collections.<ByteBuffer>emptyList()); + else if (version != 3) + throw new AssertionError("Invalid protocol version for test"); + + ByteBuffer serialized = t.bindAndGet(options); + + ByteBuffer[] fields = udt.split(serialized); + + assertEquals(4, fields.length); + + assertEquals(bytes(42L), fields[0]); + + // Note that no matter which protocol version was used in bindAndGet above, the collections inside + // a UDT should always be serialized with version 3 of the protocol. That is why we deliberately don't use 'version' + // below.
+ + assertEquals(Arrays.asList(3, 1), lt.getSerializer().deserializeForNativeProtocol(fields[1], 3)); + + LinkedHashSet<String> s = new LinkedHashSet<>(); + s.addAll(Arrays.asList("bar", "foo")); + assertEquals(s, st.getSerializer().deserializeForNativeProtocol(fields[2], 3)); + + LinkedHashMap<String, Long> m = new LinkedHashMap<>(); + m.put("bar", 12L); + m.put("foo", 24L); + assertEquals(m, mt.getSerializer().deserializeForNativeProtocol(fields[3], 3)); + } +}