http://git-wip-us.apache.org/repos/asf/kafka/blob/fc1cfe47/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java ---------------------------------------------------------------------- diff --git a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java index fec9251..6a17da8 100644 --- a/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/consumer/internals/FetcherTest.java @@ -886,7 +886,7 @@ public class FetcherTest { ListOffsetResponse.PartitionData partitionData = new ListOffsetResponse.PartitionData(error, timestamp, offset); Map<TopicPartition, ListOffsetResponse.PartitionData> allPartitionData = new HashMap<>(); allPartitionData.put(tp, partitionData); - return new ListOffsetResponse(allPartitionData, 1); + return new ListOffsetResponse(allPartitionData); } private FetchResponse fetchResponse(MemoryRecords records, Errors error, long hw, int throttleTime) {
http://git-wip-us.apache.org/repos/asf/kafka/blob/fc1cfe47/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java ---------------------------------------------------------------------- diff --git a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java index 699f6e2..d0b9639 100644 --- a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java +++ b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java @@ -30,6 +30,7 @@ import org.apache.kafka.common.record.Record; import org.junit.Test; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.nio.ByteBuffer; import java.nio.channels.GatheringByteChannel; @@ -51,133 +52,161 @@ public class RequestResponseTest { @Test public void testSerialization() throws Exception { - checkSerialization(createRequestHeader(), null); - checkSerialization(createResponseHeader(), null); - checkSerialization(createGroupCoordinatorRequest()); - checkSerialization(createGroupCoordinatorRequest().getErrorResponse(new UnknownServerException()), null); - checkSerialization(createGroupCoordinatorResponse(), null); - checkSerialization(createControlledShutdownRequest()); - checkSerialization(createControlledShutdownResponse(), null); - checkSerialization(createControlledShutdownRequest().getErrorResponse(new UnknownServerException()), null); - checkSerialization(createFetchRequest(3), 3); - checkSerialization(createFetchRequest(3).getErrorResponse(new UnknownServerException()), 3); - checkSerialization(createFetchResponse(), null); - checkSerialization(createHeartBeatRequest()); - checkSerialization(createHeartBeatRequest().getErrorResponse(new UnknownServerException()), null); - checkSerialization(createHeartBeatResponse(), null); - checkSerialization(createJoinGroupRequest(1), 1); - checkSerialization(createJoinGroupRequest(0).getErrorResponse(new UnknownServerException()), 0); - checkSerialization(createJoinGroupRequest(1).getErrorResponse(new UnknownServerException()), 1); - checkSerialization(createJoinGroupResponse(), null); - checkSerialization(createLeaveGroupRequest()); - checkSerialization(createLeaveGroupRequest().getErrorResponse(new UnknownServerException()), null); - checkSerialization(createLeaveGroupResponse(), null); - checkSerialization(createListGroupsRequest()); - checkSerialization(createListGroupsRequest().getErrorResponse(new UnknownServerException()), null); - checkSerialization(createListGroupsResponse(), null); - checkSerialization(createDescribeGroupRequest()); - checkSerialization(createDescribeGroupRequest().getErrorResponse(new UnknownServerException()), null); - checkSerialization(createDescribeGroupResponse(), null); - checkSerialization(createListOffsetRequest(1), 1); - checkSerialization(createListOffsetRequest(1).getErrorResponse(new UnknownServerException()), 1); - checkSerialization(createListOffsetResponse(1), 1); - checkSerialization(MetadataRequest.allTopics((short) 2), 2); - checkSerialization(createMetadataRequest(1, Arrays.asList("topic1")), 1); - checkSerialization(createMetadataRequest(1, Arrays.asList("topic1")).getErrorResponse(new UnknownServerException()), 1); - checkSerialization(createMetadataResponse(2), 2); - checkSerialization(createMetadataRequest(2, Arrays.asList("topic1")).getErrorResponse(new UnknownServerException()), 2); - checkSerialization(createOffsetCommitRequest(2), 2); - checkSerialization(createOffsetCommitRequest(2).getErrorResponse(new UnknownServerException()), 2); - checkSerialization(createOffsetCommitResponse(), null); - checkSerialization(OffsetFetchRequest.forAllPartitions("group1")); - checkSerialization(OffsetFetchRequest.forAllPartitions("group1").getErrorResponse(new NotCoordinatorForGroupException()), 2); - checkSerialization(createOffsetFetchRequest(0)); - checkSerialization(createOffsetFetchRequest(1)); - checkSerialization(createOffsetFetchRequest(2)); - checkSerialization(OffsetFetchRequest.forAllPartitions("group1")); - checkSerialization(createOffsetFetchRequest(0).getErrorResponse(new UnknownServerException()), 0); - checkSerialization(createOffsetFetchRequest(1).getErrorResponse(new UnknownServerException()), 1); - checkSerialization(createOffsetFetchRequest(2).getErrorResponse(new UnknownServerException()), 2); - checkSerialization(createOffsetFetchResponse(), null); - checkSerialization(createProduceRequest()); - checkSerialization(createProduceRequest().getErrorResponse(new UnknownServerException()), null); - checkSerialization(createProduceResponse(), null); - checkSerialization(createStopReplicaRequest(true)); - checkSerialization(createStopReplicaRequest(false)); - checkSerialization(createStopReplicaRequest(true).getErrorResponse(new UnknownServerException()), null); - checkSerialization(createStopReplicaResponse(), null); - checkSerialization(createLeaderAndIsrRequest()); - checkSerialization(createLeaderAndIsrRequest().getErrorResponse(new UnknownServerException()), null); - checkSerialization(createLeaderAndIsrResponse(), null); - checkSerialization(createSaslHandshakeRequest()); - checkSerialization(createSaslHandshakeRequest().getErrorResponse(new UnknownServerException()), null); - checkSerialization(createSaslHandshakeResponse(), null); - checkSerialization(createApiVersionRequest()); - checkSerialization(createApiVersionRequest().getErrorResponse(new UnknownServerException()), null); - checkSerialization(createApiVersionResponse(), null); - checkSerialization(createCreateTopicRequest(0), 0); - checkSerialization(createCreateTopicRequest(0).getErrorResponse(new UnknownServerException()), 0); - checkSerialization(createCreateTopicResponse(0), 0); - checkSerialization(createCreateTopicRequest(1), 1); - checkSerialization(createCreateTopicRequest(1).getErrorResponse(new UnknownServerException()), 1); - checkSerialization(createCreateTopicResponse(1), 1); - checkSerialization(createDeleteTopicsRequest()); - checkSerialization(createDeleteTopicsRequest().getErrorResponse(new UnknownServerException()), null); - checkSerialization(createDeleteTopicsResponse(), null); + checkRequest(createGroupCoordinatorRequest()); + checkErrorResponse(createGroupCoordinatorRequest(), new UnknownServerException()); + checkResponse(createGroupCoordinatorResponse(), 0); + checkRequest(createControlledShutdownRequest()); + checkResponse(createControlledShutdownResponse(), 1); + checkErrorResponse(createControlledShutdownRequest(), new UnknownServerException()); + checkRequest(createFetchRequest(3)); + checkErrorResponse(createFetchRequest(3), new UnknownServerException()); + checkResponse(createFetchResponse(), 0); + checkRequest(createHeartBeatRequest()); + checkErrorResponse(createHeartBeatRequest(), new UnknownServerException()); + checkResponse(createHeartBeatResponse(), 0); + checkRequest(createJoinGroupRequest(1)); + checkErrorResponse(createJoinGroupRequest(0), new UnknownServerException()); + checkErrorResponse(createJoinGroupRequest(1), new UnknownServerException()); + checkResponse(createJoinGroupResponse(), 0); + checkRequest(createLeaveGroupRequest()); + checkErrorResponse(createLeaveGroupRequest(), new UnknownServerException()); + checkResponse(createLeaveGroupResponse(), 0); + checkRequest(createListGroupsRequest()); + checkErrorResponse(createListGroupsRequest(), new UnknownServerException()); + checkResponse(createListGroupsResponse(), 0); + checkRequest(createDescribeGroupRequest()); + checkErrorResponse(createDescribeGroupRequest(), new UnknownServerException()); + checkResponse(createDescribeGroupResponse(), 0); + checkRequest(createListOffsetRequest(1)); + checkErrorResponse(createListOffsetRequest(1), new UnknownServerException()); + checkResponse(createListOffsetResponse(1), 1); + checkRequest(MetadataRequest.Builder.allTopics().build((short) 2)); + checkRequest(createMetadataRequest(1, Arrays.asList("topic1"))); + checkErrorResponse(createMetadataRequest(1, Arrays.asList("topic1")), new UnknownServerException()); + checkResponse(createMetadataResponse(), 2); + checkErrorResponse(createMetadataRequest(2, Arrays.asList("topic1")), new UnknownServerException()); + checkRequest(createOffsetCommitRequest(2)); + checkErrorResponse(createOffsetCommitRequest(2), new UnknownServerException()); + checkResponse(createOffsetCommitResponse(), 0); + checkRequest(OffsetFetchRequest.forAllPartitions("group1")); + checkErrorResponse(OffsetFetchRequest.forAllPartitions("group1"), new NotCoordinatorForGroupException()); + checkRequest(createOffsetFetchRequest(0)); + checkRequest(createOffsetFetchRequest(1)); + checkRequest(createOffsetFetchRequest(2)); + checkRequest(OffsetFetchRequest.forAllPartitions("group1")); + checkErrorResponse(createOffsetFetchRequest(0), new UnknownServerException()); + checkErrorResponse(createOffsetFetchRequest(1), new UnknownServerException()); + checkErrorResponse(createOffsetFetchRequest(2), new UnknownServerException()); + checkResponse(createOffsetFetchResponse(), 0); + checkRequest(createProduceRequest()); + checkErrorResponse(createProduceRequest(), new UnknownServerException()); + checkResponse(createProduceResponse(), 2); + checkRequest(createStopReplicaRequest(true)); + checkRequest(createStopReplicaRequest(false)); + checkErrorResponse(createStopReplicaRequest(true), new UnknownServerException()); + checkResponse(createStopReplicaResponse(), 0); + checkRequest(createLeaderAndIsrRequest()); + checkErrorResponse(createLeaderAndIsrRequest(), new UnknownServerException()); + checkResponse(createLeaderAndIsrResponse(), 0); + checkRequest(createSaslHandshakeRequest()); + checkErrorResponse(createSaslHandshakeRequest(), new UnknownServerException()); + checkResponse(createSaslHandshakeResponse(), 0); + checkRequest(createApiVersionRequest()); + checkErrorResponse(createApiVersionRequest(), new UnknownServerException()); + checkResponse(createApiVersionResponse(), 0); + checkRequest(createCreateTopicRequest(0)); + checkErrorResponse(createCreateTopicRequest(0), new UnknownServerException()); + checkResponse(createCreateTopicResponse(), 0); + checkRequest(createCreateTopicRequest(1)); + checkErrorResponse(createCreateTopicRequest(1), new UnknownServerException()); + checkResponse(createCreateTopicResponse(), 1); + checkRequest(createDeleteTopicsRequest()); + checkErrorResponse(createDeleteTopicsRequest(), new UnknownServerException()); + checkResponse(createDeleteTopicsResponse(), 0); checkOlderFetchVersions(); - checkSerialization(createMetadataResponse(0), 0); - checkSerialization(createMetadataResponse(1), 1); - checkSerialization(createMetadataRequest(1, Arrays.asList("topic1")).getErrorResponse(new UnknownServerException()), 1); - checkSerialization(createOffsetCommitRequest(0), 0); - checkSerialization(createOffsetCommitRequest(0).getErrorResponse(new UnknownServerException()), 0); - checkSerialization(createOffsetCommitRequest(1), 1); - checkSerialization(createOffsetCommitRequest(1).getErrorResponse(new UnknownServerException()), 1); - checkSerialization(createJoinGroupRequest(0), 0); - checkSerialization(createUpdateMetadataRequest(0, null), 0); - checkSerialization(createUpdateMetadataRequest(0, null).getErrorResponse(new UnknownServerException()), 0); - checkSerialization(createUpdateMetadataRequest(1, null), 1); - checkSerialization(createUpdateMetadataRequest(1, "rack1"), 1); - checkSerialization(createUpdateMetadataRequest(1, null).getErrorResponse(new UnknownServerException()), 1); - checkSerialization(createUpdateMetadataRequest(2, "rack1"), 2); - checkSerialization(createUpdateMetadataRequest(2, null), 2); - checkSerialization(createUpdateMetadataRequest(2, "rack1").getErrorResponse(new UnknownServerException()), 2); - checkSerialization(createUpdateMetadataRequest(3, "rack1")); - checkSerialization(createUpdateMetadataRequest(3, null)); - checkSerialization(createUpdateMetadataRequest(3, "rack1").getErrorResponse(new UnknownServerException()), 3); - checkSerialization(createUpdateMetadataResponse(), null); - checkSerialization(createListOffsetRequest(0), 0); - checkSerialization(createListOffsetRequest(0).getErrorResponse(new UnknownServerException()), 0); - checkSerialization(createListOffsetResponse(0), 0); + checkResponse(createMetadataResponse(), 0); + checkResponse(createMetadataResponse(), 1); + checkErrorResponse(createMetadataRequest(1, Arrays.asList("topic1")), new UnknownServerException()); + checkRequest(createOffsetCommitRequest(0)); + checkErrorResponse(createOffsetCommitRequest(0), new UnknownServerException()); + checkRequest(createOffsetCommitRequest(1)); + checkErrorResponse(createOffsetCommitRequest(1), new UnknownServerException()); + checkRequest(createJoinGroupRequest(0)); + checkRequest(createUpdateMetadataRequest(0, null)); + checkErrorResponse(createUpdateMetadataRequest(0, null), new UnknownServerException()); + checkRequest(createUpdateMetadataRequest(1, null)); + checkRequest(createUpdateMetadataRequest(1, "rack1")); + checkErrorResponse(createUpdateMetadataRequest(1, null), new UnknownServerException()); + checkRequest(createUpdateMetadataRequest(2, "rack1")); + checkRequest(createUpdateMetadataRequest(2, null)); + checkErrorResponse(createUpdateMetadataRequest(2, "rack1"), new UnknownServerException()); + checkRequest(createUpdateMetadataRequest(3, "rack1")); + checkRequest(createUpdateMetadataRequest(3, null)); + checkErrorResponse(createUpdateMetadataRequest(3, "rack1"), new UnknownServerException()); + checkResponse(createUpdateMetadataResponse(), 0); + checkRequest(createListOffsetRequest(0)); + checkErrorResponse(createListOffsetRequest(0), new UnknownServerException()); + checkResponse(createListOffsetResponse(0), 0); + } + + @Test + public void testRequestHeader() { + RequestHeader header = createRequestHeader(); + ByteBuffer buffer = toBuffer(header.toStruct()); + RequestHeader deserialized = RequestHeader.parse(buffer); + assertEquals(header.apiVersion(), deserialized.apiVersion()); + assertEquals(header.apiKey(), deserialized.apiKey()); + assertEquals(header.clientId(), deserialized.clientId()); + assertEquals(header.correlationId(), deserialized.correlationId()); + } + + @Test + public void testResponseHeader() { + ResponseHeader header = createResponseHeader(); + ByteBuffer buffer = toBuffer(header.toStruct()); + ResponseHeader deserialized = ResponseHeader.parse(buffer); + assertEquals(header.correlationId(), deserialized.correlationId()); } private void checkOlderFetchVersions() throws Exception { int latestVersion = ProtoUtils.latestVersion(ApiKeys.FETCH.id); for (int i = 0; i < latestVersion; ++i) { - checkSerialization(createFetchRequest(i).getErrorResponse(new UnknownServerException()), i); - checkSerialization(createFetchRequest(i), i); + checkErrorResponse(createFetchRequest(i), new UnknownServerException()); + checkRequest(createFetchRequest(i)); } } - private void checkSerialization(AbstractRequest req) throws Exception { - checkSerialization(req, Integer.valueOf(req.version())); + private void checkErrorResponse(AbstractRequest req, Throwable e) throws Exception { + checkResponse(req.getErrorResponse(e), req.version()); + } + + private void checkRequest(AbstractRequest req) throws Exception { + // Check that we can serialize, deserialize and serialize again + // We don't check for equality or hashCode because it is likely to fail for any request containing a HashMap + Struct struct = req.toStruct(); + AbstractRequest deserialized = (AbstractRequest) deserialize(req, struct, req.version()); + deserialized.toStruct(); } - private void checkSerialization(AbstractRequestResponse req, Integer version) throws Exception { - ByteBuffer buffer = ByteBuffer.allocate(req.sizeOf()); - req.writeTo(buffer); + private void checkResponse(AbstractResponse response, int version) throws Exception { + // Check that we can serialize, deserialize and serialize again + // We don't check for equality or hashCode because it is likely to fail for any response containing a HashMap + Struct struct = response.toStruct((short) version); + AbstractResponse deserialized = (AbstractResponse) deserialize(response, struct, (short) version); + deserialized.toStruct((short) version); + } + + private AbstractRequestResponse deserialize(AbstractRequestResponse req, Struct struct, short version) throws NoSuchMethodException, IllegalAccessException, InvocationTargetException { + ByteBuffer buffer = toBuffer(struct); + Method deserializer = req.getClass().getDeclaredMethod("parse", ByteBuffer.class, Short.TYPE); + return (AbstractRequestResponse) deserializer.invoke(null, buffer, version); + } + + private ByteBuffer toBuffer(Struct struct) { + ByteBuffer buffer = ByteBuffer.allocate(struct.sizeOf()); + struct.writeTo(buffer); buffer.rewind(); - AbstractRequestResponse deserialized; - if (version == null) { - Method deserializer = req.getClass().getDeclaredMethod("parse", ByteBuffer.class); - deserialized = (AbstractRequestResponse) deserializer.invoke(null, buffer); - } else { - Method deserializer = req.getClass().getDeclaredMethod("parse", ByteBuffer.class, Integer.TYPE); - deserialized = (AbstractRequestResponse) deserializer.invoke(null, buffer, version); - } - assertEquals("The original and deserialized of " + req.getClass().getSimpleName() + - "(version " + version + ") should be the same.", req, deserialized); - assertEquals("The original and deserialized of " + req.getClass().getSimpleName() + " should have the same hashcode.", - req.hashCode(), deserialized.hashCode()); + return buffer; } @Test @@ -186,14 +215,17 @@ public class RequestResponseTest { responseData.put(new TopicPartition("test", 0), new ProduceResponse.PartitionResponse(Errors.NONE, 10000, Record.NO_TIMESTAMP)); ProduceResponse v0Response = new ProduceResponse(responseData); - ProduceResponse v1Response = new ProduceResponse(responseData, 10, 1); - ProduceResponse v2Response = new ProduceResponse(responseData, 10, 2); + ProduceResponse v1Response = new ProduceResponse(responseData, 10); + ProduceResponse v2Response = new ProduceResponse(responseData, 10); assertEquals("Throttle time must be zero", 0, v0Response.getThrottleTime()); assertEquals("Throttle time must be 10", 10, v1Response.getThrottleTime()); assertEquals("Throttle time must be 10", 10, v2Response.getThrottleTime()); - assertEquals("Should use schema version 0", ProtoUtils.responseSchema(ApiKeys.PRODUCE.id, 0), v0Response.toStruct().schema()); - assertEquals("Should use schema version 1", ProtoUtils.responseSchema(ApiKeys.PRODUCE.id, 1), v1Response.toStruct().schema()); - assertEquals("Should use schema version 2", ProtoUtils.responseSchema(ApiKeys.PRODUCE.id, 2), v2Response.toStruct().schema()); + assertEquals("Should use schema version 0", ProtoUtils.responseSchema(ApiKeys.PRODUCE.id, 0), + v0Response.toStruct((short) 0).schema()); + assertEquals("Should use schema version 1", ProtoUtils.responseSchema(ApiKeys.PRODUCE.id, 1), + v1Response.toStruct((short) 1).schema()); + assertEquals("Should use schema version 2", ProtoUtils.responseSchema(ApiKeys.PRODUCE.id, 2), + v2Response.toStruct((short) 2).schema()); assertEquals("Response data does not match", responseData, v0Response.responses()); assertEquals("Response data does not match", responseData, v1Response.responses()); assertEquals("Response data does not match", responseData, v2Response.responses()); @@ -206,12 +238,14 @@ public class RequestResponseTest { MemoryRecords records = MemoryRecords.readableRecords(ByteBuffer.allocate(10)); responseData.put(new TopicPartition("test", 0), new FetchResponse.PartitionData(Errors.NONE, 1000000, records)); - FetchResponse v0Response = new FetchResponse(0, responseData, 0); - FetchResponse v1Response = new FetchResponse(1, responseData, 10); - assertEquals("Throttle time must be zero", 0, v0Response.getThrottleTime()); - assertEquals("Throttle time must be 10", 10, v1Response.getThrottleTime()); - assertEquals("Should use schema version 0", ProtoUtils.responseSchema(ApiKeys.FETCH.id, 0), v0Response.toStruct().schema()); - assertEquals("Should use schema version 1", ProtoUtils.responseSchema(ApiKeys.FETCH.id, 1), v1Response.toStruct().schema()); + FetchResponse v0Response = new FetchResponse(responseData, 0); + FetchResponse v1Response = new FetchResponse(responseData, 10); + assertEquals("Throttle time must be zero", 0, v0Response.throttleTimeMs()); + assertEquals("Throttle time must be 10", 10, v1Response.throttleTimeMs()); + assertEquals("Should use schema version 0", ProtoUtils.responseSchema(ApiKeys.FETCH.id, 0), + v0Response.toStruct((short) 0).schema()); + assertEquals("Should use schema version 1", ProtoUtils.responseSchema(ApiKeys.FETCH.id, 1), + v1Response.toStruct((short) 1).schema()); assertEquals("Response data does not match", responseData, v0Response.responseData()); assertEquals("Response data does not match", responseData, v1Response.responseData()); } @@ -239,19 +273,18 @@ public class RequestResponseTest { // read the body Struct responseBody = ProtoUtils.responseSchema(ApiKeys.FETCH.id, header.apiVersion()).read(buf); - FetchResponse parsedResponse = new FetchResponse(responseBody); - assertEquals(parsedResponse, fetchResponse); + assertEquals(fetchResponse.toStruct(header.apiVersion()), responseBody); - assertEquals(size, responseHeader.sizeOf() + parsedResponse.sizeOf()); + assertEquals(size, responseHeader.sizeOf() + responseBody.sizeOf()); } @Test public void testControlledShutdownResponse() { ControlledShutdownResponse response = createControlledShutdownResponse(); - ByteBuffer buffer = ByteBuffer.allocate(response.sizeOf()); - response.writeTo(buffer); - buffer.rewind(); - ControlledShutdownResponse deserialized = ControlledShutdownResponse.parse(buffer); + short version = ProtoUtils.latestVersion(ApiKeys.CONTROLLED_SHUTDOWN_KEY.id); + Struct struct = response.toStruct(version); + ByteBuffer buffer = toBuffer(struct); + ControlledShutdownResponse deserialized = ControlledShutdownResponse.parse(buffer, version); assertEquals(response.error(), deserialized.error()); assertEquals(response.partitionsRemaining(), deserialized.partitionsRemaining()); } @@ -259,9 +292,8 @@ public class RequestResponseTest { @Test public void testRequestHeaderWithNullClientId() { RequestHeader header = new RequestHeader((short) 10, (short) 1, null, 10); - ByteBuffer buffer = ByteBuffer.allocate(header.sizeOf()); - header.writeTo(buffer); - buffer.rewind(); + Struct headerStruct = header.toStruct(); + ByteBuffer buffer = toBuffer(headerStruct); RequestHeader deserialized = RequestHeader.parse(buffer); assertEquals(header.apiKey(), deserialized.apiKey()); assertEquals(header.apiVersion(), deserialized.apiVersion()); @@ -294,8 +326,7 @@ public class RequestResponseTest { LinkedHashMap<TopicPartition, FetchRequest.PartitionData> fetchData = new LinkedHashMap<>(); fetchData.put(new TopicPartition("test1", 0), new FetchRequest.PartitionData(100, 1000000)); fetchData.put(new TopicPartition("test2", 0), new FetchRequest.PartitionData(200, 1000000)); - return new FetchRequest.Builder(100, 100000, fetchData).setMaxBytes(1000). - setVersion((short) version).build(); + return FetchRequest.Builder.forConsumer(100, 100000, fetchData).setMaxBytes(1000).build((short) version); } private FetchResponse createFetchResponse() { @@ -320,7 +351,7 @@ public class RequestResponseTest { protocols.add(new JoinGroupRequest.ProtocolMetadata("consumer-range", metadata)); if (version == 0) { return new JoinGroupRequest.Builder("group1", 30000, "consumer1", "consumer", protocols). - setVersion((short) version).build(); + build((short) version); } else { return new JoinGroupRequest.Builder("group1", 10000, "consumer1", "consumer", protocols). setRebalanceTimeout(60000).build(); @@ -372,11 +403,11 @@ public class RequestResponseTest { Map<TopicPartition, ListOffsetRequest.PartitionData> offsetData = Collections.singletonMap( new TopicPartition("test", 0), new ListOffsetRequest.PartitionData(1000000L, 10)); - return new ListOffsetRequest.Builder().setOffsetData(offsetData).setVersion((short) version).build(); + return ListOffsetRequest.Builder.forConsumer((short) 0).setOffsetData(offsetData).build((short) version); } else if (version == 1) { Map<TopicPartition, Long> offsetData = Collections.singletonMap( new TopicPartition("test", 0), 1000000L); - return new ListOffsetRequest.Builder().setTargetTimes(offsetData).setVersion((short) version).build(); + return ListOffsetRequest.Builder.forConsumer((short) 1).setTargetTimes(offsetData).build((short) version); } else { throw new IllegalArgumentException("Illegal ListOffsetRequest version " + version); } @@ -386,24 +417,24 @@ public class RequestResponseTest { private ListOffsetResponse createListOffsetResponse(int version) { if (version == 0) { Map<TopicPartition, ListOffsetResponse.PartitionData> responseData = new HashMap<>(); - responseData.put(new TopicPartition("test", 0), new ListOffsetResponse.PartitionData(Errors.NONE, Arrays.asList(100L))); + responseData.put(new TopicPartition("test", 0), + new ListOffsetResponse.PartitionData(Errors.NONE, Arrays.asList(100L))); return new ListOffsetResponse(responseData); } else if (version == 1) { Map<TopicPartition, ListOffsetResponse.PartitionData> responseData = new HashMap<>(); - responseData.put(new TopicPartition("test", 0), new ListOffsetResponse.PartitionData(Errors.NONE, 10000L, 100L)); - return new ListOffsetResponse(responseData, 1); + responseData.put(new TopicPartition("test", 0), + new ListOffsetResponse.PartitionData(Errors.NONE, 10000L, 100L)); + return new ListOffsetResponse(responseData); } else { throw new IllegalArgumentException("Illegal ListOffsetResponse version " + version); } } private MetadataRequest createMetadataRequest(int version, List<String> topics) { - return new MetadataRequest.Builder(topics). - setVersion((short) version). - build(); + return new MetadataRequest.Builder(topics).build((short) version); } - private MetadataResponse createMetadataResponse(int version) { + private MetadataResponse createMetadataResponse() { Node node = new Node(1, "host1", 1001); List<Node> replicas = Arrays.asList(node); List<Node> isr = Arrays.asList(node); @@ -414,7 +445,7 @@ public class RequestResponseTest { allTopicMetadata.add(new MetadataResponse.TopicMetadata(Errors.LEADER_NOT_AVAILABLE, "topic2", false, Collections.<MetadataResponse.PartitionMetadata>emptyList())); - return new MetadataResponse(Arrays.asList(node), null, MetadataResponse.NO_CONTROLLER_ID, allTopicMetadata, version); + return new MetadataResponse(Arrays.asList(node), null, MetadataResponse.NO_CONTROLLER_ID, allTopicMetadata); } private OffsetCommitRequest createOffsetCommitRequest(int version) { @@ -425,8 +456,7 @@ public class RequestResponseTest { .setGenerationId(100) .setMemberId("consumer1") .setRetentionTime(1000000) - .setVersion((short) version) - .build(); + .build((short) version); } private OffsetCommitResponse createOffsetCommitResponse() { @@ -437,8 +467,7 @@ public class RequestResponseTest { private OffsetFetchRequest createOffsetFetchRequest(int version) { return new OffsetFetchRequest.Builder("group1", singletonList(new TopicPartition("test11", 1))) - .setVersion((short) version) - .build(); + .build((short) version); } private OffsetFetchResponse createOffsetFetchResponse() { @@ -540,8 +569,8 @@ public class RequestResponseTest { new UpdateMetadataRequest.Broker(0, endPoints1, rack), new UpdateMetadataRequest.Broker(1, endPoints2, rack) )); - return new UpdateMetadataRequest.Builder(1, 10, partitionStates, liveBrokers). - setVersion((short) version).build(); + return new UpdateMetadataRequest.Builder((short) version, 1, 10, partitionStates, + liveBrokers).build(); } private UpdateMetadataResponse createUpdateMetadataResponse() { @@ -584,14 +613,14 @@ public class RequestResponseTest { Map<String, CreateTopicsRequest.TopicDetails> request = new HashMap<>(); request.put("my_t1", request1); request.put("my_t2", request2); - return new CreateTopicsRequest.Builder(request, 0, validateOnly).setVersion((short) version).build(); + return new CreateTopicsRequest.Builder(request, 0, validateOnly).build((short) version); } - private CreateTopicsResponse createCreateTopicResponse(int version) { + private CreateTopicsResponse createCreateTopicResponse() { Map<String, CreateTopicsResponse.Error> errors = new HashMap<>(); errors.put("t1", new CreateTopicsResponse.Error(Errors.INVALID_TOPIC_EXCEPTION, null)); errors.put("t2", new CreateTopicsResponse.Error(Errors.LEADER_NOT_AVAILABLE, "Leader with id 5 is not available.")); - return new CreateTopicsResponse(errors, (short) version); + return new CreateTopicsResponse(errors); } private DeleteTopicsRequest createDeleteTopicsRequest() { http://git-wip-us.apache.org/repos/asf/kafka/blob/fc1cfe47/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java ---------------------------------------------------------------------- diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java index 76fb9b3..3a9e0ce 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java @@ -392,7 +392,7 @@ public class SaslAuthenticatorTest { selector.send(request.toSend(node, header)); ByteBuffer responseBuffer = waitForResponse(); ResponseHeader.parse(responseBuffer); - ApiVersionsResponse response = ApiVersionsResponse.parse(responseBuffer); + ApiVersionsResponse response = ApiVersionsResponse.parse(responseBuffer, (short) 0); assertEquals(Errors.UNSUPPORTED_VERSION, response.error()); // Send ApiVersionsRequest with a supported version. This should succeed. http://git-wip-us.apache.org/repos/asf/kafka/blob/fc1cfe47/core/src/main/scala/kafka/api/FetchRequest.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/kafka/api/FetchRequest.scala b/core/src/main/scala/kafka/api/FetchRequest.scala index 97da1f5..f049821 100644 --- a/core/src/main/scala/kafka/api/FetchRequest.scala +++ b/core/src/main/scala/kafka/api/FetchRequest.scala @@ -205,7 +205,7 @@ case class FetchRequest(versionId: Short = FetchRequest.CurrentVersion, responseData.put(new TopicPartition(topic, partition), new JFetchResponse.PartitionData(Errors.forException(e), -1, MemoryRecords.EMPTY)) } - val errorResponse = new JFetchResponse(versionId, responseData, 0) + val errorResponse = new JFetchResponse(responseData, 0) // Magic value does not matter here because the message set is empty requestChannel.sendResponse(new RequestChannel.Response(request, errorResponse)) } http://git-wip-us.apache.org/repos/asf/kafka/blob/fc1cfe47/core/src/main/scala/kafka/api/GenericRequestAndHeader.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/kafka/api/GenericRequestAndHeader.scala b/core/src/main/scala/kafka/api/GenericRequestAndHeader.scala deleted file mode 100644 index 3783c29..0000000 --- a/core/src/main/scala/kafka/api/GenericRequestAndHeader.scala +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE - * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file - * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the - * License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package kafka.api - -import java.nio.ByteBuffer - -import kafka.api.ApiUtils._ -import org.apache.kafka.common.requests.AbstractResponse - -private[kafka] abstract class GenericRequestAndHeader(val versionId: Short, - val correlationId: Int, - val clientId: String, - val body: AbstractResponse, - val name: String, - override val requestId: Option[Short] = None) - extends RequestOrResponse(requestId) { - - def writeTo(buffer: ByteBuffer) { - buffer.putShort(versionId) - buffer.putInt(correlationId) - writeShortString(buffer, clientId) - body.writeTo(buffer) - } - - def sizeInBytes(): Int = { - 2 /* version id */ + - 4 /* correlation id */ + - (2 + clientId.length) /* client id */ + - body.sizeOf() - } - - override def toString: String = { - describe(true) - } - - override def describe(details: Boolean): String = { - val strBuffer = new StringBuilder - strBuffer.append("Name: " + name) - strBuffer.append("; Version: " + versionId) - strBuffer.append("; CorrelationId: " + correlationId) - strBuffer.append("; ClientId: " + clientId) - strBuffer.append("; Body: " + body.toString) - strBuffer.toString() - } -} http://git-wip-us.apache.org/repos/asf/kafka/blob/fc1cfe47/core/src/main/scala/kafka/api/GenericResponseAndHeader.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/kafka/api/GenericResponseAndHeader.scala b/core/src/main/scala/kafka/api/GenericResponseAndHeader.scala deleted file mode 100644 index be0c080..0000000 --- a/core/src/main/scala/kafka/api/GenericResponseAndHeader.scala +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE - * file distributed with this work for additional information regarding copyright ownership. The ASF licenses this file - * to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the - * License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on - * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the - * specific language governing permissions and limitations under the License. - */ - -package kafka.api - -import java.nio.ByteBuffer - -import org.apache.kafka.common.requests.AbstractResponse - -private[kafka] abstract class GenericResponseAndHeader(val correlationId: Int, - val body: AbstractResponse, - val name: String, - override val requestId: Option[Short] = None) - extends RequestOrResponse(requestId) { - - def writeTo(buffer: ByteBuffer) { - buffer.putInt(correlationId) - body.writeTo(buffer) - } - - def sizeInBytes(): Int = { - 4 /* correlation id */ + - body.sizeOf() - } - - override def toString: String = { - describe(true) - } - - override def describe(details: Boolean): String = { - val strBuffer = new StringBuilder - strBuffer.append("Name: " + name) - strBuffer.append("; CorrelationId: " + correlationId) - strBuffer.append("; Body: " + body.toString) - strBuffer.toString() - } -} http://git-wip-us.apache.org/repos/asf/kafka/blob/fc1cfe47/core/src/main/scala/kafka/api/ProducerRequest.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/kafka/api/ProducerRequest.scala b/core/src/main/scala/kafka/api/ProducerRequest.scala index a87cdc9..bd48388 100644 --- a/core/src/main/scala/kafka/api/ProducerRequest.scala +++ b/core/src/main/scala/kafka/api/ProducerRequest.scala @@ -129,7 +129,7 @@ case class ProducerRequest(versionId: Short = ProducerRequest.CurrentVersion, } override def handleError(e: Throwable, requestChannel: RequestChannel, request: RequestChannel.Request): Unit = { - if(request.body.asInstanceOf[org.apache.kafka.common.requests.ProduceRequest].acks == 0) { + if (request.body[org.apache.kafka.common.requests.ProduceRequest].acks == 0) { requestChannel.closeConnection(request.processor, request) } else { http://git-wip-us.apache.org/repos/asf/kafka/blob/fc1cfe47/core/src/main/scala/kafka/controller/ControllerChannelManager.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/kafka/controller/ControllerChannelManager.scala b/core/src/main/scala/kafka/controller/ControllerChannelManager.scala index d8e6a95..a2fee6b 100755 --- a/core/src/main/scala/kafka/controller/ControllerChannelManager.scala +++ b/core/src/main/scala/kafka/controller/ControllerChannelManager.scala @@ -364,9 +364,9 @@ class ControllerBrokerRequestBatch(controller: KafkaController) extends Logging partitionStateInfo.allReplicas.map(Integer.valueOf).asJava) topicPartition -> partitionState } - val leaderAndIsrRequest = new LeaderAndIsrRequest. - Builder(controllerId, controllerEpoch, partitionStates.asJava, leaders.asJava) - controller.sendRequest(broker, ApiKeys.LEADER_AND_ISR, leaderAndIsrRequest, null) + val leaderAndIsrRequest = new LeaderAndIsrRequest.Builder(controllerId, controllerEpoch, partitionStates.asJava, + leaders.asJava) + controller.sendRequest(broker, ApiKeys.LEADER_AND_ISR, leaderAndIsrRequest) } leaderAndIsrRequestMap.clear() @@ -405,9 +405,8 @@ class ControllerBrokerRequestBatch(controller: KafkaController) extends Logging new UpdateMetadataRequest.Broker(broker.id, endPoints.asJava, broker.rack.orNull) } } - new UpdateMetadataRequest.Builder( - controllerId, controllerEpoch, partitionStates.asJava, liveBrokers.asJava). - setVersion(version) + new UpdateMetadataRequest.Builder(version, controllerId, controllerEpoch, partitionStates.asJava, + liveBrokers.asJava) } updateMetadataRequestBrokerSet.foreach { broker => http://git-wip-us.apache.org/repos/asf/kafka/blob/fc1cfe47/core/src/main/scala/kafka/network/RequestChannel.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/kafka/network/RequestChannel.scala b/core/src/main/scala/kafka/network/RequestChannel.scala index c063801..a5d7160 100644 --- a/core/src/main/scala/kafka/network/RequestChannel.scala +++ b/core/src/main/scala/kafka/network/RequestChannel.scala @@ -19,7 +19,7 @@ package kafka.network import java.net.InetAddress import java.nio.ByteBuffer -import java.util.HashMap +import java.util.Collections import java.util.concurrent._ import com.yammer.metrics.core.Gauge @@ -37,16 +37,18 @@ import org.apache.kafka.common.security.auth.KafkaPrincipal import org.apache.kafka.common.utils.Time import org.apache.log4j.Logger +import scala.reflect.{classTag, ClassTag} + object RequestChannel extends Logging { val AllDone = Request(processor = 1, connectionId = "2", Session(KafkaPrincipal.ANONYMOUS, InetAddress.getLocalHost), - buffer = getShutdownReceive, startTimeMs = 0, listenerName = new ListenerName(""), + buffer = shutdownReceive, startTimeMs = 0, listenerName = new ListenerName(""), securityProtocol = SecurityProtocol.PLAINTEXT) private val requestLogger = Logger.getLogger("kafka.request.logger") - private def getShutdownReceive = { - val emptyProduceRequest = new ProduceRequest.Builder(0, 0, new HashMap[TopicPartition, MemoryRecords]()).build() + private def shutdownReceive: ByteBuffer = { + val emptyProduceRequest = new ProduceRequest.Builder(0, 0, Collections.emptyMap[TopicPartition, MemoryRecords]).build() val emptyRequestHeader = new RequestHeader(ApiKeys.PRODUCE.id, emptyProduceRequest.version, "", 0) - AbstractRequestResponse.serialize(emptyRequestHeader, emptyProduceRequest) + emptyProduceRequest.serialize(emptyRequestHeader) } case class Session(principal: KafkaPrincipal, clientAddress: InetAddress) { @@ -84,12 +86,13 @@ object RequestChannel extends Logging { } } else null - val body: AbstractRequest = + val bodyAndSize: RequestAndSize = if (requestObj == null) try { // For unsupported version of ApiVersionsRequest, create a dummy request to enable an error response to be returned later - if (header.apiKey == ApiKeys.API_VERSIONS.id && !Protocol.apiVersionSupported(header.apiKey, header.apiVersion)) - new ApiVersionsRequest.Builder().build() + if (header.apiKey == ApiKeys.API_VERSIONS.id && !Protocol.apiVersionSupported(header.apiKey, header.apiVersion)) { + new RequestAndSize(new ApiVersionsRequest.Builder().build(), 0) + } else AbstractRequest.getRequest(header.apiKey, header.apiVersion, buffer) } catch { @@ -108,6 +111,14 @@ object RequestChannel extends Logging { s"$header -- $body" } + def body[T <: AbstractRequest : ClassTag] = { + bodyAndSize.request match { + case r: T => r + case r => + throw new ClassCastException(s"Expected request with type ${classTag[T].runtimeClass}, but found ${r.getClass}") + } + } + trace("Processor %d received request : %s".format(processor, requestDesc(true))) def updateRequestMetrics() { @@ -132,7 +143,7 @@ object RequestChannel extends Logging { val totalTime = endTimeMs - startTimeMs val fetchMetricNames = if (requestId == ApiKeys.FETCH.id) { - val isFromFollower = body.asInstanceOf[FetchRequest].isFromFollower + val isFromFollower = body[FetchRequest].isFromFollower Seq( if (isFromFollower) RequestMetrics.followFetchMetricName else RequestMetrics.consumerFetchMetricName @@ -163,11 +174,8 @@ object RequestChannel extends Logging { case class Response(processor: Int, request: Request, responseSend: Send, responseAction: ResponseAction) { request.responseCompleteTimeMs = Time.SYSTEM.milliseconds - def this(processor: Int, request: Request, responseSend: Send) = - this(processor, request, responseSend, if (responseSend == null) NoOpAction else SendAction) - - def this(request: Request, send: Send) = - this(request.processor, request, send) + def this(request: Request, responseSend: Send) = + this(request.processor, request, responseSend, if (responseSend == null) NoOpAction else SendAction) def this(request: Request, response: AbstractResponse) = this(request, response.toSend(request.connectionId, request.header)) http://git-wip-us.apache.org/repos/asf/kafka/blob/fc1cfe47/core/src/main/scala/kafka/server/KafkaApis.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala index 1308216..c90cace 100644 --- a/core/src/main/scala/kafka/server/KafkaApis.scala +++ b/core/src/main/scala/kafka/server/KafkaApis.scala @@ -108,7 +108,7 @@ class KafkaApis(val requestChannel: RequestChannel, request.requestObj.handleError(e, requestChannel, request) error("Error when handling request %s".format(request.requestObj), e) } else { - val response = request.body.getErrorResponse(e) + val response = request.body[AbstractRequest].getErrorResponse(e) /* If request doesn't have a default error response, we just close the connection. For example, when produce request has acks set to 0 */ @@ -128,7 +128,7 @@ class KafkaApis(val requestChannel: RequestChannel, // We can't have the ensureTopicExists check here since the controller sends it as an advisory to all brokers so they // stop serving data to clients for the topic being deleted val correlationId = request.header.correlationId - val leaderAndIsrRequest = request.body.asInstanceOf[LeaderAndIsrRequest] + val leaderAndIsrRequest = request.body[LeaderAndIsrRequest] try { def onLeadershipChange(updatedLeaders: Iterable[Partition], updatedFollowers: Iterable[Partition]) { @@ -167,7 +167,7 @@ class KafkaApis(val requestChannel: RequestChannel, // ensureTopicExists is only for client facing requests // We can't have the ensureTopicExists check here since the controller sends it as an advisory to all brokers so they // stop serving data to clients for the topic being deleted - val stopReplicaRequest = request.body.asInstanceOf[StopReplicaRequest] + val stopReplicaRequest = request.body[StopReplicaRequest] val response = if (authorize(request.session, ClusterAction, Resource.ClusterResource)) { @@ -195,7 +195,7 @@ class KafkaApis(val requestChannel: RequestChannel, def handleUpdateMetadataRequest(request: RequestChannel.Request) { val correlationId = request.header.correlationId - val updateMetadataRequest = request.body.asInstanceOf[UpdateMetadataRequest] + val updateMetadataRequest = request.body[UpdateMetadataRequest] val updateMetadataResponse = if (authorize(request.session, ClusterAction, Resource.ClusterResource)) { @@ -235,7 +235,7 @@ class KafkaApis(val requestChannel: RequestChannel, */ def handleOffsetCommitRequest(request: RequestChannel.Request) { val header = request.header - val offsetCommitRequest = request.body.asInstanceOf[OffsetCommitRequest] + val offsetCommitRequest = request.body[OffsetCommitRequest] // reject the request if not authorized to the group if (!authorize(request.session, Read, new Resource(Group, offsetCommitRequest.groupId))) { @@ -247,14 +247,13 @@ class KafkaApis(val requestChannel: RequestChannel, requestChannel.sendResponse(new RequestChannel.Response(request, response)) } else { val (existingAndAuthorizedForDescribeTopics, nonExistingOrUnauthorizedForDescribeTopics) = offsetCommitRequest.offsetData.asScala.toMap.partition { - case (topicPartition, _) => { + case (topicPartition, _) => val authorizedForDescribe = authorize(request.session, Describe, new Resource(auth.Topic, topicPartition.topic)) val exists = metadataCache.contains(topicPartition.topic) if (!authorizedForDescribe && exists) debug(s"Offset commit request with correlation id ${header.correlationId} from client ${header.clientId} " + s"on partition $topicPartition failing due to user not having DESCRIBE authorization, but returning UNKNOWN_TOPIC_OR_PARTITION") authorizedForDescribe && exists - } } val (authorizedTopics, unauthorizedForReadTopics) = existingAndAuthorizedForDescribeTopics.partition { @@ -349,8 +348,8 @@ class KafkaApis(val requestChannel: RequestChannel, * Handle a produce request */ def handleProducerRequest(request: RequestChannel.Request) { - val produceRequest = request.body.asInstanceOf[ProduceRequest] - val numBytesAppended = request.header.sizeOf + produceRequest.sizeOf + val produceRequest = request.body[ProduceRequest] + val numBytesAppended = request.header.toStruct.sizeOf + request.bodyAndSize.size val (existingAndAuthorizedForDescribeTopics, nonExistingOrUnauthorizedForDescribeTopics) = produceRequest.partitionRecords.asScala.partition { case (topicPartition, _) => authorize(request.session, Describe, new Resource(auth.Topic, topicPartition.topic)) && metadataCache.contains(topicPartition.topic) @@ -399,14 +398,7 @@ class KafkaApis(val requestChannel: RequestChannel, requestChannel.noOperation(request.processor, request) } } else { - val respBody = request.header.apiVersion match { - case 0 => new ProduceResponse(mergedResponseStatus.asJava) - case version@(1 | 2) => new ProduceResponse(mergedResponseStatus.asJava, delayTimeMs, version) - // This case shouldn't happen unless a new version of ProducerRequest is added without - // updating this part of the code to handle it properly. - case version => throw new IllegalArgumentException(s"Version `$version` of ProduceRequest is not handled. Code must be updated.") - } - + val respBody = new ProduceResponse(mergedResponseStatus.asJava, delayTimeMs) requestChannel.sendResponse(new RequestChannel.Response(request, respBody)) } } @@ -445,7 +437,7 @@ class KafkaApis(val requestChannel: RequestChannel, * Handle a fetch request */ def handleFetchRequest(request: RequestChannel.Request) { - val fetchRequest = request.body.asInstanceOf[FetchRequest] + val fetchRequest = request.body[FetchRequest] val versionId = request.header.apiVersion val clientId = request.header.clientId @@ -505,13 +497,13 @@ class KafkaApis(val requestChannel: RequestChannel, BrokerTopicStats.getBrokerAllTopicsStats().bytesOutRate.mark(data.records.sizeInBytes) } - val response = new FetchResponse(versionId, fetchedPartitionData, 0) + val response = new FetchResponse(fetchedPartitionData, 0) + val responseStruct = response.toStruct(versionId) - def fetchResponseCallback(delayTimeMs: Int) { - trace(s"Sending fetch response to client $clientId of " + - s"${convertedPartitionData.map { case (_, v) => v.records.sizeInBytes }.sum} bytes") - val fetchResponse = if (delayTimeMs > 0) new FetchResponse(versionId, fetchedPartitionData, delayTimeMs) else response - requestChannel.sendResponse(new RequestChannel.Response(request, fetchResponse)) + def fetchResponseCallback(throttleTimeMs: Int) { + trace(s"Sending fetch response to client $clientId of ${responseStruct.sizeOf} bytes.") + val responseSend = response.toSend(responseStruct, throttleTimeMs, request.connectionId, request.header) + requestChannel.sendResponse(new RequestChannel.Response(request, responseSend)) } // When this callback is triggered, the remote API call has completed @@ -521,9 +513,10 @@ class KafkaApis(val requestChannel: RequestChannel, // We've already evaluated against the quota and are good to go. Just need to record it now. val responseSize = sizeOfThrottledPartitions(versionId, fetchRequest, mergedPartitionData, quotas.leader) quotas.leader.record(responseSize) - fetchResponseCallback(0) + fetchResponseCallback(throttleTimeMs = 0) } else { - quotas.fetch.recordAndMaybeThrottle(request.session.sanitizedUser, clientId, response.sizeOf, fetchResponseCallback) + quotas.fetch.recordAndMaybeThrottle(request.session.sanitizedUser, clientId, responseStruct.sizeOf, + fetchResponseCallback) } } @@ -547,7 +540,7 @@ class KafkaApis(val requestChannel: RequestChannel, fetchRequest: FetchRequest, mergedPartitionData: Seq[(TopicPartition, FetchResponse.PartitionData)], quota: ReplicationQuotaManager): Int = { - val partitionData = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData]() + val partitionData = new util.LinkedHashMap[TopicPartition, FetchResponse.PartitionData] mergedPartitionData.foreach { case (tp, data) => if (quota.isThrottled(tp)) partitionData.put(tp, data) @@ -570,14 +563,14 @@ class KafkaApis(val requestChannel: RequestChannel, else handleListOffsetRequestV1(request) - val response = new ListOffsetResponse(mergedResponseMap.asJava, version) + val response = new ListOffsetResponse(mergedResponseMap.asJava) requestChannel.sendResponse(new RequestChannel.Response(request, response)) } private def handleListOffsetRequestV0(request : RequestChannel.Request) : Map[TopicPartition, ListOffsetResponse.PartitionData] = { val correlationId = request.header.correlationId val clientId = request.header.clientId - val offsetRequest = request.body.asInstanceOf[ListOffsetRequest] + val offsetRequest = request.body[ListOffsetRequest] val (authorizedRequestInfo, unauthorizedRequestInfo) = offsetRequest.offsetData.asScala.partition { case (topicPartition, _) => authorize(request.session, Describe, new Resource(auth.Topic, topicPartition.topic)) @@ -628,7 +621,7 @@ class KafkaApis(val requestChannel: RequestChannel, private def handleListOffsetRequestV1(request : RequestChannel.Request): Map[TopicPartition, ListOffsetResponse.PartitionData] = { val correlationId = request.header.correlationId val clientId = request.header.clientId - val offsetRequest = request.body.asInstanceOf[ListOffsetRequest] + val offsetRequest = request.body[ListOffsetRequest] val (authorizedRequestInfo, unauthorizedRequestInfo) = offsetRequest.partitionTimestamps.asScala.partition { case (topicPartition, _) => authorize(request.session, Describe, new Resource(auth.Topic, topicPartition.topic)) @@ -824,7 +817,7 @@ class KafkaApis(val requestChannel: RequestChannel, * Handle a topic metadata request */ def handleTopicMetadataRequest(request: RequestChannel.Request) { - val metadataRequest = request.body.asInstanceOf[MetadataRequest] + val metadataRequest = request.body[MetadataRequest] val requestVersion = request.header.apiVersion() val topics = @@ -889,8 +882,7 @@ class KafkaApis(val requestChannel: RequestChannel, brokers.map(_.getNode(request.listenerName)).asJava, clusterId, metadataCache.getControllerId.getOrElse(MetadataResponse.NO_CONTROLLER_ID), - completeTopicMetadata.asJava, - requestVersion + completeTopicMetadata.asJava ) requestChannel.sendResponse(new RequestChannel.Response(request, responseBody)) } @@ -900,7 +892,7 @@ class KafkaApis(val requestChannel: RequestChannel, */ def handleOffsetFetchRequest(request: RequestChannel.Request) { val header = request.header - val offsetFetchRequest = request.body.asInstanceOf[OffsetFetchRequest] + val offsetFetchRequest = request.body[OffsetFetchRequest] def authorizeTopicDescribe(partition: TopicPartition) = authorize(request.session, Describe, new Resource(auth.Topic, partition.topic)) @@ -938,7 +930,7 @@ class KafkaApis(val requestChannel: RequestChannel, }.toMap val unauthorizedPartitionData = unauthorizedPartitions.map(_ -> OffsetFetchResponse.UNKNOWN_PARTITION).toMap - new OffsetFetchResponse(Errors.NONE, (authorizedPartitionData ++ unauthorizedPartitionData).asJava, header.apiVersion) + new OffsetFetchResponse(Errors.NONE, (authorizedPartitionData ++ unauthorizedPartitionData).asJava) } else { // versions 1 and above read offsets from Kafka if (offsetFetchRequest.isAllPartitions) { @@ -948,7 +940,7 @@ class KafkaApis(val requestChannel: RequestChannel, else { // clients are not allowed to see offsets for topics that are not authorized for Describe val authorizedPartitionData = allPartitionData.filter { case (topicPartition, _) => authorizeTopicDescribe(topicPartition) } - new OffsetFetchResponse(Errors.NONE, authorizedPartitionData.asJava, header.apiVersion) + new OffsetFetchResponse(Errors.NONE, authorizedPartitionData.asJava) } } else { val (authorizedPartitions, unauthorizedPartitions) = offsetFetchRequest.partitions.asScala @@ -959,7 +951,7 @@ class KafkaApis(val requestChannel: RequestChannel, offsetFetchRequest.getErrorResponse(error) else { val unauthorizedPartitionData = unauthorizedPartitions.map(_ -> OffsetFetchResponse.UNKNOWN_PARTITION).toMap - new OffsetFetchResponse(Errors.NONE, (authorizedPartitionData ++ unauthorizedPartitionData).asJava, header.apiVersion) + new OffsetFetchResponse(Errors.NONE, (authorizedPartitionData ++ unauthorizedPartitionData).asJava) } } } @@ -970,7 +962,7 @@ class KafkaApis(val requestChannel: RequestChannel, } def handleGroupCoordinatorRequest(request: RequestChannel.Request) { - val groupCoordinatorRequest = request.body.asInstanceOf[GroupCoordinatorRequest] + val groupCoordinatorRequest = request.body[GroupCoordinatorRequest] if (!authorize(request.session, Describe, new Resource(Group, groupCoordinatorRequest.groupId))) { val responseBody = new GroupCoordinatorResponse(Errors.GROUP_AUTHORIZATION_FAILED, Node.noNode) @@ -1003,7 +995,7 @@ class KafkaApis(val requestChannel: RequestChannel, } def handleDescribeGroupRequest(request: RequestChannel.Request) { - val describeRequest = request.body.asInstanceOf[DescribeGroupsRequest] + val describeRequest = request.body[DescribeGroupsRequest] val groups = describeRequest.groupIds().asScala.map { groupId => if (!authorize(request.session, Describe, new Resource(Group, groupId))) { @@ -1036,12 +1028,12 @@ class KafkaApis(val requestChannel: RequestChannel, } def handleJoinGroupRequest(request: RequestChannel.Request) { - val joinGroupRequest = request.body.asInstanceOf[JoinGroupRequest] + val joinGroupRequest = request.body[JoinGroupRequest] // the callback for sending a join-group response def sendResponseCallback(joinResult: JoinGroupResult) { val members = joinResult.members map { case (memberId, metadataArray) => (memberId, ByteBuffer.wrap(metadataArray)) } - val responseBody = new JoinGroupResponse(request.header.apiVersion, joinResult.error, joinResult.generationId, + val responseBody = new JoinGroupResponse(joinResult.error, joinResult.generationId, joinResult.subProtocol, joinResult.memberId, joinResult.leaderId, members.asJava) trace("Sending join group response %s for correlation id %d to client %s." @@ -1051,7 +1043,6 @@ class KafkaApis(val requestChannel: RequestChannel, if (!authorize(request.session, Read, new Resource(Group, joinGroupRequest.groupId()))) { val responseBody = new JoinGroupResponse( - request.header.apiVersion, Errors.GROUP_AUTHORIZATION_FAILED, JoinGroupResponse.UNKNOWN_GENERATION_ID, JoinGroupResponse.UNKNOWN_PROTOCOL, @@ -1077,7 +1068,7 @@ class KafkaApis(val requestChannel: RequestChannel, } def handleSyncGroupRequest(request: RequestChannel.Request) { - val syncGroupRequest = request.body.asInstanceOf[SyncGroupRequest] + val syncGroupRequest = request.body[SyncGroupRequest] def sendResponseCallback(memberState: Array[Byte], error: Errors) { val responseBody = new SyncGroupResponse(error, ByteBuffer.wrap(memberState)) @@ -1098,7 +1089,7 @@ class KafkaApis(val requestChannel: RequestChannel, } def handleHeartbeatRequest(request: RequestChannel.Request) { - val heartbeatRequest = request.body.asInstanceOf[HeartbeatRequest] + val heartbeatRequest = request.body[HeartbeatRequest] // the callback for sending a heartbeat response def sendResponseCallback(error: Errors) { @@ -1123,7 +1114,7 @@ class KafkaApis(val requestChannel: RequestChannel, } def handleLeaveGroupRequest(request: RequestChannel.Request) { - val leaveGroupRequest = request.body.asInstanceOf[LeaveGroupRequest] + val leaveGroupRequest = request.body[LeaveGroupRequest] // the callback for sending a leave-group response def sendResponseCallback(error: Errors) { @@ -1157,11 +1148,11 @@ class KafkaApis(val requestChannel: RequestChannel, // If this is considered to leak information about the broker version a workaround is to use SSL // with client authentication which is performed at an earlier stage of the connection where the // ApiVersionRequest is not available. - val responseBody = if (Protocol.apiVersionSupported(ApiKeys.API_VERSIONS.id, request.header.apiVersion)) - ApiVersionsResponse.API_VERSIONS_RESPONSE - else - ApiVersionsResponse.fromError(Errors.UNSUPPORTED_VERSION) - requestChannel.sendResponse(new RequestChannel.Response(request, responseBody)) + val responseSend = + if (Protocol.apiVersionSupported(ApiKeys.API_VERSIONS.id, request.header.apiVersion)) + ApiVersionsResponse.API_VERSIONS_RESPONSE.toSend(request.connectionId, request.header) + else ApiVersionsResponse.unsupportedVersionSend(request.connectionId, request.header) + requestChannel.sendResponse(new RequestChannel.Response(request, responseSend)) } def close() { @@ -1170,10 +1161,10 @@ class KafkaApis(val requestChannel: RequestChannel, } def handleCreateTopicsRequest(request: RequestChannel.Request) { - val createTopicsRequest = request.body.asInstanceOf[CreateTopicsRequest] + val createTopicsRequest = request.body[CreateTopicsRequest] def sendResponseCallback(results: Map[String, CreateTopicsResponse.Error]): Unit = { - val responseBody = new CreateTopicsResponse(results.asJava, request.header.apiVersion) + val responseBody = new CreateTopicsResponse(results.asJava) trace(s"Sending create topics response $responseBody for correlation id ${request.header.correlationId} to client ${request.header.clientId}.") requestChannel.sendResponse(new RequestChannel.Response(request, responseBody)) } @@ -1220,7 +1211,7 @@ class KafkaApis(val requestChannel: RequestChannel, } def handleDeleteTopicsRequest(request: RequestChannel.Request) { - val deleteTopicRequest = request.body.asInstanceOf[DeleteTopicsRequest] + val deleteTopicRequest = request.body[DeleteTopicsRequest] val (existingAndAuthorizedForDescribeTopics, nonExistingOrUnauthorizedForDescribeTopics) = deleteTopicRequest.topics.asScala.partition { topic => authorize(request.session, Describe, new Resource(auth.Topic, topic)) && metadataCache.contains(topic) http://git-wip-us.apache.org/repos/asf/kafka/blob/fc1cfe47/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala index c99d7c5..7fb02a3 100644 --- a/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala +++ b/core/src/main/scala/kafka/server/ReplicaFetcherThread.scala @@ -261,17 +261,13 @@ class ReplicaFetcherThread(name: String, } - private def earliestOrLatestOffset(topicPartition: TopicPartition, earliestOrLatest: Long, consumerId: Int): Long = { + private def earliestOrLatestOffset(topicPartition: TopicPartition, earliestOrLatest: Long, replicaId: Int): Long = { val requestBuilder = if (brokerConfig.interBrokerProtocolVersion >= KAFKA_0_10_1_IV2) { val partitions = Map(topicPartition -> (earliestOrLatest: java.lang.Long)) - new ListOffsetRequest.Builder(consumerId). - setTargetTimes(partitions.asJava). - setVersion(1) + ListOffsetRequest.Builder.forReplica(1, replicaId).setTargetTimes(partitions.asJava) } else { val partitions = Map(topicPartition -> new ListOffsetRequest.PartitionData(earliestOrLatest, 1)) - new ListOffsetRequest.Builder(consumerId). - setOffsetData(partitions.asJava). - setVersion(0) + ListOffsetRequest.Builder.forReplica(0, replicaId).setOffsetData(partitions.asJava) } val clientResponse = sendRequest(requestBuilder) val response = clientResponse.responseBody.asInstanceOf[ListOffsetResponse] @@ -295,9 +291,8 @@ class ReplicaFetcherThread(name: String, requestMap.put(topicPartition, new JFetchRequest.PartitionData(partitionFetchState.offset, fetchSize)) } - val requestBuilder = new JFetchRequest.Builder(maxWait, minBytes, requestMap). - setReplicaId(replicaId).setMaxBytes(maxBytes) - requestBuilder.setVersion(fetchRequestVersion) + val requestBuilder = JFetchRequest.Builder.forReplica(fetchRequestVersion, replicaId, maxWait, minBytes, requestMap) + .setMaxBytes(maxBytes) new FetchRequest(requestBuilder) } http://git-wip-us.apache.org/repos/asf/kafka/blob/fc1cfe47/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala index 3285bf2..4f71258 100644 --- a/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala +++ b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala @@ -27,7 +27,7 @@ import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener import org.apache.kafka.clients.consumer._ import org.apache.kafka.clients.producer.{KafkaProducer, ProducerRecord} import org.apache.kafka.common.errors._ -import org.apache.kafka.common.protocol.{ApiKeys, Errors, SecurityProtocol} +import org.apache.kafka.common.protocol.{ApiKeys, Errors, ProtoUtils, SecurityProtocol} import org.apache.kafka.common.requests._ import CreateTopicsRequest.TopicDetails import org.apache.kafka.common.security.auth.KafkaPrincipal @@ -191,11 +191,12 @@ class AuthorizerIntegrationTest extends BaseRequestTest { private def createFetchRequest = { val partitionMap = new util.LinkedHashMap[TopicPartition, requests.FetchRequest.PartitionData] partitionMap.put(tp, new requests.FetchRequest.PartitionData(0, 100)) - new requests.FetchRequest.Builder(100, Int.MaxValue, partitionMap).setReplicaId(5000).build() + val version = ProtoUtils.latestVersion(ApiKeys.FETCH.id) + requests.FetchRequest.Builder.forReplica(version, 5000, 100, Int.MaxValue, partitionMap).build() } private def createListOffsetsRequest = { - new requests.ListOffsetRequest.Builder().setTargetTimes( + requests.ListOffsetRequest.Builder.forConsumer(0).setTargetTimes( Map(tp -> (0L: java.lang.Long)).asJava). build() } @@ -214,7 +215,8 @@ class AuthorizerIntegrationTest extends BaseRequestTest { val brokers = Set(new requests.UpdateMetadataRequest.Broker(brokerId, Seq(new requests.UpdateMetadataRequest.EndPoint("localhost", 0, securityProtocol, ListenerName.forSecurityProtocol(securityProtocol))).asJava, null)).asJava - new requests.UpdateMetadataRequest.Builder(brokerId, Int.MaxValue, partitionState, brokers).build() + val version = ProtoUtils.latestVersion(ApiKeys.UPDATE_METADATA_KEY.id) + new requests.UpdateMetadataRequest.Builder(version, brokerId, Int.MaxValue, partitionState, brokers).build() } private def createJoinGroupRequest = { @@ -770,17 +772,18 @@ class AuthorizerIntegrationTest extends BaseRequestTest { @Test def testUnauthorizedDeleteWithoutDescribe() { - val response = send(deleteTopicsRequest, ApiKeys.DELETE_TOPICS) - val deleteResponse = DeleteTopicsResponse.parse(response) - + val response = connectAndSend(deleteTopicsRequest, ApiKeys.DELETE_TOPICS) + val version = ProtoUtils.latestVersion(ApiKeys.DELETE_TOPICS.id) + val deleteResponse = DeleteTopicsResponse.parse(response, version) assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, deleteResponse.errors.asScala.head._2) } @Test def testUnauthorizedDeleteWithDescribe() { addAndVerifyAcls(Set(new Acl(KafkaPrincipal.ANONYMOUS, Allow, Acl.WildCardHost, Describe)), deleteTopicResource) - val response = send(deleteTopicsRequest, ApiKeys.DELETE_TOPICS) - val deleteResponse = DeleteTopicsResponse.parse(response) + val response = connectAndSend(deleteTopicsRequest, ApiKeys.DELETE_TOPICS) + val version = ProtoUtils.latestVersion(ApiKeys.DELETE_TOPICS.id) + val deleteResponse = DeleteTopicsResponse.parse(response, version) assertEquals(Errors.TOPIC_AUTHORIZATION_FAILED, deleteResponse.errors.asScala.head._2) } @@ -788,8 +791,9 @@ class AuthorizerIntegrationTest extends BaseRequestTest { @Test def testDeleteWithWildCardAuth() { addAndVerifyAcls(Set(new Acl(KafkaPrincipal.ANONYMOUS, Allow, Acl.WildCardHost, Delete)), new Resource(Topic, "*")) - val response = send(deleteTopicsRequest, ApiKeys.DELETE_TOPICS) - val deleteResponse = DeleteTopicsResponse.parse(response) + val response = connectAndSend(deleteTopicsRequest, ApiKeys.DELETE_TOPICS) + val version = ProtoUtils.latestVersion(ApiKeys.DELETE_TOPICS.id) + val deleteResponse = DeleteTopicsResponse.parse(response, version) assertEquals(Errors.NONE, deleteResponse.errors.asScala.head._2) } @@ -807,8 +811,9 @@ class AuthorizerIntegrationTest extends BaseRequestTest { isAuthorized: Boolean, isAuthorizedTopicDescribe: Boolean, topicExists: Boolean = true): AbstractResponse = { - val resp = send(request, apiKey) - val response = RequestKeyToResponseDeserializer(apiKey).getMethod("parse", classOf[ByteBuffer]).invoke(null, resp).asInstanceOf[AbstractResponse] + val resp = connectAndSend(request, apiKey) + val response = RequestKeyToResponseDeserializer(apiKey).getMethod("parse", classOf[ByteBuffer], classOf[Short]).invoke( + null, resp, request.version: java.lang.Short).asInstanceOf[AbstractResponse] val error = RequestKeyToError(apiKey).asInstanceOf[(AbstractResponse) => Errors](response) val authorizationErrorCodes = resources.flatMap { resourceType => @@ -877,7 +882,7 @@ class AuthorizerIntegrationTest extends BaseRequestTest { private def sendOffsetFetchRequest(request: requests.OffsetFetchRequest, socketServer: SocketServer): requests.OffsetFetchResponse = { - val response = send(request, ApiKeys.OFFSET_FETCH, socketServer) + val response = connectAndSend(request, ApiKeys.OFFSET_FETCH, socketServer) requests.OffsetFetchResponse.parse(response, request.version) } http://git-wip-us.apache.org/repos/asf/kafka/blob/fc1cfe47/core/src/test/scala/unit/kafka/network/SocketServerTest.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index 37bc238..3875604 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -33,7 +33,7 @@ import org.apache.kafka.common.metrics.Metrics import org.apache.kafka.common.network.{ListenerName, NetworkSend} import org.apache.kafka.common.protocol.{ApiKeys, SecurityProtocol} import org.apache.kafka.common.record.MemoryRecords -import org.apache.kafka.common.requests.{ProduceRequest, RequestHeader} +import org.apache.kafka.common.requests.{AbstractRequest, ProduceRequest, RequestHeader} import org.apache.kafka.common.security.auth.KafkaPrincipal import org.apache.kafka.common.utils.Time import org.junit.Assert._ @@ -89,13 +89,11 @@ class SocketServerTest extends JUnitSuite { } def processRequest(channel: RequestChannel, request: RequestChannel.Request) { - val byteBuffer = ByteBuffer.allocate(request.header.sizeOf + request.body.sizeOf) - request.header.writeTo(byteBuffer) - request.body.writeTo(byteBuffer) + val byteBuffer = request.body[AbstractRequest].serialize(request.header) byteBuffer.rewind() val send = new NetworkSend(request.connectionId, byteBuffer) - channel.sendResponse(new RequestChannel.Response(request.processor, request, send)) + channel.sendResponse(new RequestChannel.Response(request, send)) } def connect(s: SocketServer = server, protocol: SecurityProtocol = SecurityProtocol.PLAINTEXT) = { @@ -119,14 +117,11 @@ class SocketServerTest extends JUnitSuite { val ackTimeoutMs = 10000 val ack = 0: Short - val emptyRequest = new ProduceRequest.Builder( - ack, ackTimeoutMs, new HashMap[TopicPartition, MemoryRecords]()).build() + val emptyRequest = new ProduceRequest.Builder(ack, ackTimeoutMs, new HashMap[TopicPartition, MemoryRecords]()).build() val emptyHeader = new RequestHeader(apiKey, emptyRequest.version, clientId, correlationId) - - val byteBuffer = ByteBuffer.allocate(emptyHeader.sizeOf + emptyRequest.sizeOf) - emptyHeader.writeTo(byteBuffer) - emptyRequest.writeTo(byteBuffer) + val byteBuffer = emptyRequest.serialize(emptyHeader) byteBuffer.rewind() + val serializedBytes = new Array[Byte](byteBuffer.remaining) byteBuffer.get(serializedBytes) serializedBytes @@ -289,13 +284,10 @@ class SocketServerTest extends JUnitSuite { val clientId = "" val ackTimeoutMs = 10000 val ack = 0: Short - val emptyRequest = new ProduceRequest.Builder( - ack, ackTimeoutMs, new HashMap[TopicPartition, MemoryRecords]()).build() + val emptyRequest = new ProduceRequest.Builder(ack, ackTimeoutMs, new HashMap[TopicPartition, MemoryRecords]()).build() val emptyHeader = new RequestHeader(apiKey, emptyRequest.version, clientId, correlationId) - val byteBuffer = ByteBuffer.allocate(emptyHeader.sizeOf() + emptyRequest.sizeOf()) - emptyHeader.writeTo(byteBuffer) - emptyRequest.writeTo(byteBuffer) + val byteBuffer = emptyRequest.serialize(emptyHeader) byteBuffer.rewind() val serializedBytes = new Array[Byte](byteBuffer.remaining) byteBuffer.get(serializedBytes) @@ -355,7 +347,7 @@ class SocketServerTest extends JUnitSuite { // detected. If the buffer is larger than 102400 bytes, a second write is attempted and it fails with an // IOException. val send = new NetworkSend(request.connectionId, ByteBuffer.allocate(550000)) - channel.sendResponse(new RequestChannel.Response(request.processor, request, send)) + channel.sendResponse(new RequestChannel.Response(request, send)) TestUtils.waitUntilTrue(() => totalTimeHistCount() == expectedTotalTimeCount, s"request metrics not updated, expected: $expectedTotalTimeCount, actual: ${totalTimeHistCount()}") http://git-wip-us.apache.org/repos/asf/kafka/blob/fc1cfe47/core/src/test/scala/unit/kafka/server/AbstractCreateTopicsRequestTest.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/unit/kafka/server/AbstractCreateTopicsRequestTest.scala b/core/src/test/scala/unit/kafka/server/AbstractCreateTopicsRequestTest.scala index ef98531..448fce1 100644 --- a/core/src/test/scala/unit/kafka/server/AbstractCreateTopicsRequestTest.scala +++ b/core/src/test/scala/unit/kafka/server/AbstractCreateTopicsRequestTest.scala @@ -82,13 +82,13 @@ class AbstractCreateTopicsRequestTest extends BaseRequestTest { protected def error(error: Errors, errorMessage: Option[String] = None): CreateTopicsResponse.Error = new CreateTopicsResponse.Error(error, errorMessage.orNull) - protected def duplicateFirstTopic(request: CreateTopicsRequest) = { + protected def toStructWithDuplicateFirstTopic(request: CreateTopicsRequest): Struct = { val struct = request.toStruct val topics = struct.getArray("create_topic_requests") val firstTopic = topics(0).asInstanceOf[Struct] val newTopics = firstTopic :: topics.toList struct.set("create_topic_requests", newTopics.toArray) - new CreateTopicsRequest(struct, request.version) + struct } protected def addPartitionsAndReplicationFactorToFirstTopic(request: CreateTopicsRequest) = { @@ -102,8 +102,10 @@ class AbstractCreateTopicsRequestTest extends BaseRequestTest { protected def validateErrorCreateTopicsRequests(request: CreateTopicsRequest, expectedResponse: Map[String, CreateTopicsResponse.Error], - checkErrorMessage: Boolean = true): Unit = { - val response = sendCreateTopicRequest(request) + checkErrorMessage: Boolean = true, + requestStruct: Option[Struct] = None): Unit = { + val response = requestStruct.map(sendCreateTopicRequestStruct(_, request.version)).getOrElse( + sendCreateTopicRequest(request)) val errors = response.errors.asScala assertEquals("The response size should match", expectedResponse.size, response.errors.size) @@ -133,14 +135,20 @@ class AbstractCreateTopicsRequestTest extends BaseRequestTest { assignments.map { case (k, v) => (k: Integer, v.map { i => i: Integer }.asJava) }.asJava } + protected def sendCreateTopicRequestStruct(requestStruct: Struct, apiVersion: Short, + socketServer: SocketServer = controllerSocketServer): CreateTopicsResponse = { + val response = connectAndSendStruct(requestStruct, ApiKeys.CREATE_TOPICS, apiVersion, socketServer) + CreateTopicsResponse.parse(response, apiVersion) + } + protected def sendCreateTopicRequest(request: CreateTopicsRequest, socketServer: SocketServer = controllerSocketServer): CreateTopicsResponse = { - val response = send(request, ApiKeys.CREATE_TOPICS, socketServer) + val response = connectAndSend(request, ApiKeys.CREATE_TOPICS, socketServer) CreateTopicsResponse.parse(response, request.version) } protected def sendMetadataRequest(request: MetadataRequest, destination: SocketServer = anySocketServer): MetadataResponse = { val version = ProtoUtils.latestVersion(ApiKeys.METADATA.id) - val response = send(request, ApiKeys.METADATA, destination = destination) + val response = connectAndSend(request, ApiKeys.METADATA, destination = destination) MetadataResponse.parse(response, version) } http://git-wip-us.apache.org/repos/asf/kafka/blob/fc1cfe47/core/src/test/scala/unit/kafka/server/ApiVersionsRequestTest.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/unit/kafka/server/ApiVersionsRequestTest.scala b/core/src/test/scala/unit/kafka/server/ApiVersionsRequestTest.scala index ffe82d1..248b91e 100644 --- a/core/src/test/scala/unit/kafka/server/ApiVersionsRequestTest.scala +++ b/core/src/test/scala/unit/kafka/server/ApiVersionsRequestTest.scala @@ -17,8 +17,7 @@ package kafka.server -import org.apache.kafka.common.protocol.types.Struct -import org.apache.kafka.common.protocol.{ApiKeys, Errors, ProtoUtils} +import org.apache.kafka.common.protocol.{ApiKeys, Errors} import org.apache.kafka.common.requests.ApiVersionsResponse.ApiVersion import org.apache.kafka.common.requests.{ApiVersionsRequest, ApiVersionsResponse} import org.junit.Assert._ @@ -51,14 +50,13 @@ class ApiVersionsRequestTest extends BaseRequestTest { @Test def testApiVersionsRequestWithUnsupportedVersion() { - val apiVersionsRequest = new ApiVersionsRequest( - new Struct(ProtoUtils.currentRequestSchema(ApiKeys.API_VERSIONS.id)), Short.MaxValue) - val apiVersionsResponse = sendApiVersionsRequest(apiVersionsRequest) + val apiVersionsRequest = new ApiVersionsRequest(0) + val apiVersionsResponse = sendApiVersionsRequest(apiVersionsRequest, Some(Short.MaxValue)) assertEquals(Errors.UNSUPPORTED_VERSION, apiVersionsResponse.error) } - private def sendApiVersionsRequest(request: ApiVersionsRequest): ApiVersionsResponse = { - val response = send(request, ApiKeys.API_VERSIONS) - ApiVersionsResponse.parse(response) + private def sendApiVersionsRequest(request: ApiVersionsRequest, apiVersion: Option[Short] = None): ApiVersionsResponse = { + val response = connectAndSend(request, ApiKeys.API_VERSIONS, apiVersion = apiVersion) + ApiVersionsResponse.parse(response, 0) } }