dajac commented on code in PR #13231: URL: https://github.com/apache/kafka/pull/13231#discussion_r1118380508
########## clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java: ########## @@ -34,22 +43,40 @@ public class AddPartitionsToTxnRequest extends AbstractRequest { private final AddPartitionsToTxnRequestData data; - private List<TopicPartition> cachedPartitions = null; + private final short version; Review Comment: The version is already in the base class. Do we really need it here? ########## clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java: ########## @@ -1328,7 +1328,7 @@ Priority priority() { @Override public void handleResponse(AbstractResponse response) { AddPartitionsToTxnResponse addPartitionsToTxnResponse = (AddPartitionsToTxnResponse) response; - Map<TopicPartition, Errors> errors = addPartitionsToTxnResponse.errors(); + Map<TopicPartition, Errors> errors = addPartitionsToTxnResponse.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID); Review Comment: nit: I suppose that `errors` should never be `null` here. I wonder if we should still check it. What do you think? 
########## clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java: ########## @@ -34,22 +43,40 @@ public class AddPartitionsToTxnRequest extends AbstractRequest { private final AddPartitionsToTxnRequestData data; - private List<TopicPartition> cachedPartitions = null; + private final short version; public static class Builder extends AbstractRequest.Builder<AddPartitionsToTxnRequest> { public final AddPartitionsToTxnRequestData data; + + public static Builder forClient(String transactionalId, + long producerId, + short producerEpoch, + List<TopicPartition> partitions) { + + AddPartitionsToTxnTopicCollection topics = buildTxnTopicCollection(partitions); + + return new Builder(ApiKeys.ADD_PARTITIONS_TO_TXN.oldestVersion(), + (short) 3, + new AddPartitionsToTxnRequestData() + .setV3AndBelowTransactionalId(transactionalId) + .setV3AndBelowProducerId(producerId) + .setV3AndBelowProducerEpoch(producerEpoch) + .setV3AndBelowTopics(topics)); Review Comment: nit: The indentation of the arguments looks inconsistent. ########## clients/src/main/resources/common/message/AddPartitionsToTxnRequest.json: ########## @@ -23,17 +23,39 @@ // Version 2 adds the support for new error code PRODUCER_FENCED. // // Version 3 enables flexible versions. - "validVersions": "0-3", + // + // Version 4 adds VerifyOnly field to check if partitions are already in transaction and adds support to batch multiple transactions. Review Comment: I think that we should explain that v4 is only for other brokers and clients are supposed to use version <= v3. 
########## clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java: ########## @@ -1303,11 +1305,13 @@ public void testCommitWithTopicAuthorizationFailureInAddPartitionsInFlight() thr Map<TopicPartition, Errors> errors = new HashMap<>(); errors.put(tp0, Errors.TOPIC_AUTHORIZATION_FAILED); errors.put(tp1, Errors.OPERATION_NOT_ATTEMPTED); + AddPartitionsToTxnResult result = AddPartitionsToTxnResponse.resultForTransaction(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID, errors); + AddPartitionsToTxnResponseData data = new AddPartitionsToTxnResponseData().setResultsByTopicV3AndBelow(result.topicResults()).setThrottleTimeMs(0); client.respond(body -> { AddPartitionsToTxnRequest request = (AddPartitionsToTxnRequest) body; - assertEquals(new HashSet<>(request.partitions()), new HashSet<>(errors.keySet())); + assertEquals(new HashSet<>(AddPartitionsToTxnRequest.getPartitions(request.data().v3AndBelowTopics())), new HashSet<>(errors.keySet())); Review Comment: `AddPartitionsToTxnRequest.getPartitions(request.data().v3AndBelowTopics()))` looks a bit weird from an encapsulation perspective. Why not just keep `partitions` as before if you need it? 
########## core/src/main/scala/kafka/server/KafkaApis.scala: ########## @@ -2384,68 +2385,116 @@ class KafkaApis(val requestChannel: RequestChannel, if (config.interBrokerProtocolVersion.isLessThan(version)) throw new UnsupportedVersionException(s"inter.broker.protocol.version: ${config.interBrokerProtocolVersion.version} is less than the required version: ${version.version}") } - - def handleAddPartitionToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + def handleAddPartitionsToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { ensureInterBrokerVersion(IBP_0_11_0_IV0) - val addPartitionsToTxnRequest = request.body[AddPartitionsToTxnRequest] - val transactionalId = addPartitionsToTxnRequest.data.transactionalId - val partitionsToAdd = addPartitionsToTxnRequest.partitions.asScala - if (!authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - addPartitionsToTxnRequest.getErrorResponse(requestThrottleMs, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception)) - else { - val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]() - val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]() - val authorizedPartitions = mutable.Set[TopicPartition]() - - val authorizedTopics = authHelper.filterByAuthorized(request.context, WRITE, TOPIC, - partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic) - for (topicPartition <- partitionsToAdd) { - if (!authorizedTopics.contains(topicPartition.topic)) - unauthorizedTopicErrors += topicPartition -> Errors.TOPIC_AUTHORIZATION_FAILED - else if (!metadataCache.contains(topicPartition)) - nonExistingTopicErrors += topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION - else - authorizedPartitions.add(topicPartition) + val addPartitionsToTxnRequest = + if (request.context.apiVersion() < 4) + 
request.body[AddPartitionsToTxnRequest].normalizeRequest() + else + request.body[AddPartitionsToTxnRequest] + val version = addPartitionsToTxnRequest.version + val responses = new AddPartitionsToTxnResultCollection() + val partitionsByTransaction = addPartitionsToTxnRequest.partitionsByTransaction() + + // Newer versions of the request should only come from other brokers. + if (version >= 4) authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + + // V4 requests introduced batches of transactions. We need all transactions to be handled before sending the + // response so there are a few differences in handling errors and sending responses. + def createResponse(requestThrottleMs: Int): AbstractResponse = { + if (version < 4) { + // There will only be one response in data. Add it to the response data object. + val data = new AddPartitionsToTxnResponseData() + responses.forEach(result => { Review Comment: nit: `responses.forEach(result => {` -> `responses.forEach { result => ` ########## clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java: ########## @@ -118,11 +123,78 @@ public AddPartitionsToTxnRequestData data() { @Override public AddPartitionsToTxnResponse getErrorResponse(int throttleTimeMs, Throwable e) { - final HashMap<TopicPartition, Errors> errors = new HashMap<>(); - for (TopicPartition partition : partitions()) { - errors.put(partition, Errors.forException(e)); + Errors error = Errors.forException(e); + AddPartitionsToTxnResponseData response = new AddPartitionsToTxnResponseData(); + if (version < 4) { + response.setResultsByTopicV3AndBelow(errorResponseForTopics(data.v3AndBelowTopics(), error)); + } else { + AddPartitionsToTxnResultCollection results = new AddPartitionsToTxnResultCollection(); + for (AddPartitionsToTxnTransaction transaction : data().transactions()) { + results.add(errorResponseForTransaction(transaction.transactionalId(), error)); + } + response.setResultsByTransaction(results); + 
response.setErrorCode(error.code()); Review Comment: When there is a global error, do we expect to set both the top level error and to return an error for each transaction? If we always set both, what is the purpose of the top level error? ########## core/src/main/scala/kafka/server/KafkaApis.scala: ########## @@ -2384,68 +2385,116 @@ class KafkaApis(val requestChannel: RequestChannel, if (config.interBrokerProtocolVersion.isLessThan(version)) throw new UnsupportedVersionException(s"inter.broker.protocol.version: ${config.interBrokerProtocolVersion.version} is less than the required version: ${version.version}") } - - def handleAddPartitionToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + def handleAddPartitionsToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { ensureInterBrokerVersion(IBP_0_11_0_IV0) - val addPartitionsToTxnRequest = request.body[AddPartitionsToTxnRequest] - val transactionalId = addPartitionsToTxnRequest.data.transactionalId - val partitionsToAdd = addPartitionsToTxnRequest.partitions.asScala - if (!authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - addPartitionsToTxnRequest.getErrorResponse(requestThrottleMs, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception)) - else { - val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]() - val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]() - val authorizedPartitions = mutable.Set[TopicPartition]() - - val authorizedTopics = authHelper.filterByAuthorized(request.context, WRITE, TOPIC, - partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic) - for (topicPartition <- partitionsToAdd) { - if (!authorizedTopics.contains(topicPartition.topic)) - unauthorizedTopicErrors += topicPartition -> Errors.TOPIC_AUTHORIZATION_FAILED - else if (!metadataCache.contains(topicPartition)) - 
nonExistingTopicErrors += topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION - else - authorizedPartitions.add(topicPartition) + val addPartitionsToTxnRequest = + if (request.context.apiVersion() < 4) + request.body[AddPartitionsToTxnRequest].normalizeRequest() + else + request.body[AddPartitionsToTxnRequest] + val version = addPartitionsToTxnRequest.version + val responses = new AddPartitionsToTxnResultCollection() + val partitionsByTransaction = addPartitionsToTxnRequest.partitionsByTransaction() + + // Newer versions of the request should only come from other brokers. + if (version >= 4) authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + + // V4 requests introduced batches of transactions. We need all transactions to be handled before sending the + // response so there are a few differences in handling errors and sending responses. + def createResponse(requestThrottleMs: Int): AbstractResponse = { + if (version < 4) { + // There will only be one response in data. Add it to the response data object. + val data = new AddPartitionsToTxnResponseData() + responses.forEach(result => { + data.setResultsByTopicV3AndBelow(result.topicResults()) + data.setThrottleTimeMs(requestThrottleMs) + }) + new AddPartitionsToTxnResponse(data) + } else { + new AddPartitionsToTxnResponse(new AddPartitionsToTxnResponseData().setThrottleTimeMs(requestThrottleMs).setResultsByTransaction(responses)) } + } - if (unauthorizedTopicErrors.nonEmpty || nonExistingTopicErrors.nonEmpty) { - // Any failed partition check causes the entire request to fail. We send the appropriate error codes for the - // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error code for the partitions which succeeded - // the authorization check to indicate that they were not added to the transaction. 
- val partitionErrors = unauthorizedTopicErrors ++ nonExistingTopicErrors ++ - authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED) - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - new AddPartitionsToTxnResponse(requestThrottleMs, partitionErrors.asJava)) + val txns = addPartitionsToTxnRequest.data.transactions + def maybeSendResponse(): Unit = { + var canSend = false + responses.synchronized { + if (responses.size() == txns.size()) { + canSend = true + } + } + if (canSend) { + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + } + + txns.forEach( transaction => { Review Comment: nit: `txns.forEach( transaction => {` -> `txns.forEach { transaction =>`. ########## core/src/main/scala/kafka/server/KafkaApis.scala: ########## @@ -2384,68 +2385,116 @@ class KafkaApis(val requestChannel: RequestChannel, if (config.interBrokerProtocolVersion.isLessThan(version)) throw new UnsupportedVersionException(s"inter.broker.protocol.version: ${config.interBrokerProtocolVersion.version} is less than the required version: ${version.version}") } - - def handleAddPartitionToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + def handleAddPartitionsToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { ensureInterBrokerVersion(IBP_0_11_0_IV0) - val addPartitionsToTxnRequest = request.body[AddPartitionsToTxnRequest] - val transactionalId = addPartitionsToTxnRequest.data.transactionalId - val partitionsToAdd = addPartitionsToTxnRequest.partitions.asScala - if (!authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - addPartitionsToTxnRequest.getErrorResponse(requestThrottleMs, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception)) - else { - val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]() - val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]() 
- val authorizedPartitions = mutable.Set[TopicPartition]() - - val authorizedTopics = authHelper.filterByAuthorized(request.context, WRITE, TOPIC, - partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic) - for (topicPartition <- partitionsToAdd) { - if (!authorizedTopics.contains(topicPartition.topic)) - unauthorizedTopicErrors += topicPartition -> Errors.TOPIC_AUTHORIZATION_FAILED - else if (!metadataCache.contains(topicPartition)) - nonExistingTopicErrors += topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION - else - authorizedPartitions.add(topicPartition) + val addPartitionsToTxnRequest = + if (request.context.apiVersion() < 4) + request.body[AddPartitionsToTxnRequest].normalizeRequest() + else + request.body[AddPartitionsToTxnRequest] + val version = addPartitionsToTxnRequest.version + val responses = new AddPartitionsToTxnResultCollection() + val partitionsByTransaction = addPartitionsToTxnRequest.partitionsByTransaction() + + // Newer versions of the request should only come from other brokers. + if (version >= 4) authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + + // V4 requests introduced batches of transactions. We need all transactions to be handled before sending the + // response so there are a few differences in handling errors and sending responses. + def createResponse(requestThrottleMs: Int): AbstractResponse = { + if (version < 4) { + // There will only be one response in data. Add it to the response data object. 
+ val data = new AddPartitionsToTxnResponseData() + responses.forEach(result => { + data.setResultsByTopicV3AndBelow(result.topicResults()) + data.setThrottleTimeMs(requestThrottleMs) + }) + new AddPartitionsToTxnResponse(data) + } else { + new AddPartitionsToTxnResponse(new AddPartitionsToTxnResponseData().setThrottleTimeMs(requestThrottleMs).setResultsByTransaction(responses)) } + } - if (unauthorizedTopicErrors.nonEmpty || nonExistingTopicErrors.nonEmpty) { - // Any failed partition check causes the entire request to fail. We send the appropriate error codes for the - // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error code for the partitions which succeeded - // the authorization check to indicate that they were not added to the transaction. - val partitionErrors = unauthorizedTopicErrors ++ nonExistingTopicErrors ++ - authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED) - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - new AddPartitionsToTxnResponse(requestThrottleMs, partitionErrors.asJava)) + val txns = addPartitionsToTxnRequest.data.transactions + def maybeSendResponse(): Unit = { + var canSend = false + responses.synchronized { + if (responses.size() == txns.size()) { + canSend = true + } + } + if (canSend) { + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + } + + txns.forEach( transaction => { + val transactionalId = transaction.transactionalId + val partitionsToAdd = partitionsByTransaction.get(transactionalId).asScala + + // Versions < 4 come from clients and must be authorized to write for the given transaction and for the given topics. 
+ if (version < 4 && !authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) { + responses.add(addPartitionsToTxnRequest.errorResponseForTransaction(transactionalId, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED)) + maybeSendResponse() } else { - def sendResponseCallback(error: Errors): Unit = { - def createResponse(requestThrottleMs: Int): AbstractResponse = { - val finalError = - if (addPartitionsToTxnRequest.version < 2 && error == Errors.PRODUCER_FENCED) { + val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]() + val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]() + val authorizedPartitions = mutable.Set[TopicPartition]() + + val authorizedTopics = if (version < 4) authHelper.filterByAuthorized(request.context, WRITE, TOPIC, + partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic) else partitionsToAdd.map(_.topic).toSet Review Comment: nit: I find this line quite hard to read. Should we put it on multiple lines? ########## core/src/main/scala/kafka/server/KafkaApis.scala: ########## @@ -2384,68 +2385,116 @@ class KafkaApis(val requestChannel: RequestChannel, if (config.interBrokerProtocolVersion.isLessThan(version)) throw new UnsupportedVersionException(s"inter.broker.protocol.version: ${config.interBrokerProtocolVersion.version} is less than the required version: ${version.version}") } - - def handleAddPartitionToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + def handleAddPartitionsToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { ensureInterBrokerVersion(IBP_0_11_0_IV0) - val addPartitionsToTxnRequest = request.body[AddPartitionsToTxnRequest] - val transactionalId = addPartitionsToTxnRequest.data.transactionalId - val partitionsToAdd = addPartitionsToTxnRequest.partitions.asScala - if (!authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) - requestHelper.sendResponseMaybeThrottle(request, 
requestThrottleMs => - addPartitionsToTxnRequest.getErrorResponse(requestThrottleMs, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception)) - else { - val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]() - val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]() - val authorizedPartitions = mutable.Set[TopicPartition]() - - val authorizedTopics = authHelper.filterByAuthorized(request.context, WRITE, TOPIC, - partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic) - for (topicPartition <- partitionsToAdd) { - if (!authorizedTopics.contains(topicPartition.topic)) - unauthorizedTopicErrors += topicPartition -> Errors.TOPIC_AUTHORIZATION_FAILED - else if (!metadataCache.contains(topicPartition)) - nonExistingTopicErrors += topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION - else - authorizedPartitions.add(topicPartition) + val addPartitionsToTxnRequest = + if (request.context.apiVersion() < 4) + request.body[AddPartitionsToTxnRequest].normalizeRequest() + else + request.body[AddPartitionsToTxnRequest] + val version = addPartitionsToTxnRequest.version + val responses = new AddPartitionsToTxnResultCollection() + val partitionsByTransaction = addPartitionsToTxnRequest.partitionsByTransaction() + + // Newer versions of the request should only come from other brokers. + if (version >= 4) authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + + // V4 requests introduced batches of transactions. We need all transactions to be handled before sending the + // response so there are a few differences in handling errors and sending responses. + def createResponse(requestThrottleMs: Int): AbstractResponse = { + if (version < 4) { + // There will only be one response in data. Add it to the response data object. 
+ val data = new AddPartitionsToTxnResponseData() + responses.forEach(result => { + data.setResultsByTopicV3AndBelow(result.topicResults()) + data.setThrottleTimeMs(requestThrottleMs) + }) + new AddPartitionsToTxnResponse(data) + } else { + new AddPartitionsToTxnResponse(new AddPartitionsToTxnResponseData().setThrottleTimeMs(requestThrottleMs).setResultsByTransaction(responses)) } + } - if (unauthorizedTopicErrors.nonEmpty || nonExistingTopicErrors.nonEmpty) { - // Any failed partition check causes the entire request to fail. We send the appropriate error codes for the - // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error code for the partitions which succeeded - // the authorization check to indicate that they were not added to the transaction. - val partitionErrors = unauthorizedTopicErrors ++ nonExistingTopicErrors ++ - authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED) - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - new AddPartitionsToTxnResponse(requestThrottleMs, partitionErrors.asJava)) + val txns = addPartitionsToTxnRequest.data.transactions + def maybeSendResponse(): Unit = { + var canSend = false + responses.synchronized { + if (responses.size() == txns.size()) { + canSend = true + } + } + if (canSend) { + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + } + + txns.forEach( transaction => { + val transactionalId = transaction.transactionalId + val partitionsToAdd = partitionsByTransaction.get(transactionalId).asScala + + // Versions < 4 come from clients and must be authorized to write for the given transaction and for the given topics. 
+ if (version < 4 && !authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) { + responses.add(addPartitionsToTxnRequest.errorResponseForTransaction(transactionalId, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED)) + maybeSendResponse() } else { - def sendResponseCallback(error: Errors): Unit = { - def createResponse(requestThrottleMs: Int): AbstractResponse = { - val finalError = - if (addPartitionsToTxnRequest.version < 2 && error == Errors.PRODUCER_FENCED) { + val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]() + val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]() + val authorizedPartitions = mutable.Set[TopicPartition]() + + val authorizedTopics = if (version < 4) authHelper.filterByAuthorized(request.context, WRITE, TOPIC, + partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic) else partitionsToAdd.map(_.topic).toSet + for (topicPartition <- partitionsToAdd) { + if (!authorizedTopics.contains(topicPartition.topic)) + unauthorizedTopicErrors += topicPartition -> Errors.TOPIC_AUTHORIZATION_FAILED + else if (!metadataCache.contains(topicPartition)) + nonExistingTopicErrors += topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION + else + authorizedPartitions.add(topicPartition) + } + + if (unauthorizedTopicErrors.nonEmpty || nonExistingTopicErrors.nonEmpty) { + // Any failed partition check causes the entire transaction to fail. We send the appropriate error codes for the + // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error code for the partitions which succeeded + // the authorization check to indicate that they were not added to the transaction. 
+ val partitionErrors = unauthorizedTopicErrors ++ nonExistingTopicErrors ++ + authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED) + responses.add(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, partitionErrors.asJava)) + maybeSendResponse() + } else { + def sendResponseCallback(error: Errors): Unit = { + val finalError = { + if (version < 2 && error == Errors.PRODUCER_FENCED) { // For older clients, they could not understand the new PRODUCER_FENCED error code, // so we need to return the old INVALID_PRODUCER_EPOCH to have the same client handling logic. Errors.INVALID_PRODUCER_EPOCH } else { error } - - val responseBody: AddPartitionsToTxnResponse = new AddPartitionsToTxnResponse(requestThrottleMs, - partitionsToAdd.map{tp => (tp, finalError)}.toMap.asJava) - trace(s"Completed $transactionalId's AddPartitionsToTxnRequest with partitions $partitionsToAdd: errors: $error from client ${request.header.clientId}") - responseBody + } + responses.synchronized { + responses.add(addPartitionsToTxnRequest.errorResponseForTransaction(transactionalId, finalError)) + } + maybeSendResponse() + } + + def sendVerifyResponseCallback(errors: Map[TopicPartition, Errors]): Unit = { + responses.synchronized { + responses.add(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, errors.asJava)) + } + maybeSendResponse() Review Comment: This pattern is used in a few places. Would having an `addResponseAndMaybeSend` helper method make sense? ########## core/src/main/scala/kafka/server/KafkaApis.scala: ########## @@ -2493,7 +2542,9 @@ class KafkaApis(val requestChannel: RequestChannel, addOffsetsToTxnRequest.data.producerId, addOffsetsToTxnRequest.data.producerEpoch, Set(offsetTopicPartition), + false, sendResponseCallback, + null, Review Comment: This tends to confirm that having two callbacks is weird. 
########## core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala: ########## @@ -55,22 +65,149 @@ class AddPartitionsToTxnRequestServerTest extends BaseRequestTest { val producerId = 1000L val producerEpoch: Short = 0 - val request = new AddPartitionsToTxnRequest.Builder( - transactionalId, - producerId, - producerEpoch, - List(createdTopicPartition, nonExistentTopic).asJava) - .build() + val request = + if (version < 4) { + AddPartitionsToTxnRequest.Builder.forClient( + transactionalId, + producerId, + producerEpoch, + List(createdTopicPartition, nonExistentTopic).asJava) + .build() + } else { + val topics = new AddPartitionsToTxnTopicCollection() + topics.add(new AddPartitionsToTxnTopic() + .setName(createdTopicPartition.topic()) + .setPartitions(Collections.singletonList(createdTopicPartition.partition()))) + topics.add(new AddPartitionsToTxnTopic() + .setName(nonExistentTopic.topic()) + .setPartitions(Collections.singletonList(nonExistentTopic.partition()))) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId) + .setProducerId(producerId) + .setProducerEpoch(producerEpoch) + .setVerifyOnly(false) + .setTopics(topics)) + AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + } val leaderId = brokers.head.config.brokerId val response = connectAndReceive[AddPartitionsToTxnResponse](request, brokerSocketServer(leaderId)) + + val errors = + if (version < 4) + response.errors.get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID) + else + response.errorsForTransaction(response.getTransactionTopicResults(transactionalId)) + + assertEquals(2, errors.size) + + assertTrue(errors.containsKey(createdTopicPartition)) + assertEquals(Errors.OPERATION_NOT_ATTEMPTED, errors.get(createdTopicPartition)) + + assertTrue(errors.containsKey(nonExistentTopic)) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, errors.get(nonExistentTopic)) + 
} + + @Test + def testOneSuccessOneErrorInBatchedRequest(): Unit = { + val tp0 = new TopicPartition(topic1, 0) + val transactionalId1 = "foobar" + val transactionalId2 = "barfoo" // "barfoo" maps to the same transaction coordinator + val producerId2 = 1000L + val producerEpoch2: Short = 0 + + val txn2Topics = new AddPartitionsToTxnTopicCollection() + txn2Topics.add(new AddPartitionsToTxnTopic() + .setName(tp0.topic()) + .setPartitions(Collections.singletonList(tp0.partition()))) Review Comment: nit: `()` are not necessary. There are a few similar cases in this file. ########## core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala: ########## @@ -55,22 +65,149 @@ class AddPartitionsToTxnRequestServerTest extends BaseRequestTest { val producerId = 1000L val producerEpoch: Short = 0 - val request = new AddPartitionsToTxnRequest.Builder( - transactionalId, - producerId, - producerEpoch, - List(createdTopicPartition, nonExistentTopic).asJava) - .build() + val request = + if (version < 4) { + AddPartitionsToTxnRequest.Builder.forClient( + transactionalId, + producerId, + producerEpoch, + List(createdTopicPartition, nonExistentTopic).asJava) + .build() + } else { + val topics = new AddPartitionsToTxnTopicCollection() + topics.add(new AddPartitionsToTxnTopic() + .setName(createdTopicPartition.topic()) + .setPartitions(Collections.singletonList(createdTopicPartition.partition()))) + topics.add(new AddPartitionsToTxnTopic() + .setName(nonExistentTopic.topic()) + .setPartitions(Collections.singletonList(nonExistentTopic.partition()))) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId) + .setProducerId(producerId) + .setProducerEpoch(producerEpoch) + .setVerifyOnly(false) + .setTopics(topics)) + AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + } val leaderId = brokers.head.config.brokerId val response = 
connectAndReceive[AddPartitionsToTxnResponse](request, brokerSocketServer(leaderId)) + + val errors = + if (version < 4) + response.errors.get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID) + else + response.errorsForTransaction(response.getTransactionTopicResults(transactionalId)) Review Comment: nit: `response.errorsForTransaction(response.getTransactionTopicResults(transactionalId))` looks really weird from an encapsulation perspective. Can't we use `response.errors.get(transactionalId)`? ########## clients/src/main/resources/common/message/AddPartitionsToTxnRequest.json: ########## @@ -23,17 +23,39 @@ // Version 2 adds the support for new error code PRODUCER_FENCED. // // Version 3 enables flexible versions. - "validVersions": "0-3", + // + // Version 4 adds VerifyOnly field to check if partitions are already in transaction and adds support to batch multiple transactions. + // The AddPartitionsToTxnRequest version 4 API is added as part of KIP-890 and is still + // under developement. Hence, the API is not exposed by default by brokers + // unless explicitely enabled. + "latestVersionUnstable": true, + "validVersions": "0-4", "flexibleVersions": "3+", "fields": [ - { "name": "TransactionalId", "type": "string", "versions": "0+", "entityType": "transactionalId", - "about": "The transactional id corresponding to the transaction."}, - { "name": "ProducerId", "type": "int64", "versions": "0+", "entityType": "producerId", + { "name": "Transactions", "type": "[]AddPartitionsToTxnTransaction", "versions": "4+", + "about": "List of transactions to add partitions to.", "fields": [ + { "name": "TransactionalId", "type": "string", "versions": "4+", "mapKey": true, "entityType": "transactionalId", + "about": "The transactional id corresponding to the transaction." }, + { "name": "ProducerId", "type": "int64", "versions": "4+", "entityType": "producerId", + "about": "Current producer id in use by the transactional id." 
}, + { "name": "ProducerEpoch", "type": "int16", "versions": "4+", + "about": "Current epoch associated with the producer id." }, + { "name": "VerifyOnly", "type": "bool", "versions": "4+", "default": false, + "about": "Boolean to signify if we want to check if the partition is in the transaction rather than add it." }, + { "name": "Topics", "type": "[]AddPartitionsToTxnTopic", "versions": "4+", + "about": "The partitions to add to the transaction." } + ]}, + { "name": "V3AndBelowTransactionalId", "type": "string", "versions": "0-3", "entityType": "transactionalId", Review Comment: Is it common to prefix old fields by their version? It is the first time I see it. ########## clients/src/main/resources/common/message/AddPartitionsToTxnResponse.json: ########## @@ -22,22 +22,37 @@ // Version 2 adds the support for new error code PRODUCER_FENCED. // // Version 3 enables flexible versions. - "validVersions": "0-3", + // + // Version 4 adds support to batch multiple transactions and a top level error code. + "validVersions": "0-4", "flexibleVersions": "3+", "fields": [ { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+", "about": "Duration in milliseconds for which the request was throttled due to a quota violation, or zero if the request did not violate any quota." }, - { "name": "Results", "type": "[]AddPartitionsToTxnTopicResult", "versions": "0+", - "about": "The results for each topic.", "fields": [ + { "name": "ErrorCode", "type": "int16", "versions": "4+", Review Comment: Should we make it ignorable in case we would set it by mistake? 
########## core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala: ########## @@ -55,22 +65,149 @@ class AddPartitionsToTxnRequestServerTest extends BaseRequestTest { val producerId = 1000L val producerEpoch: Short = 0 - val request = new AddPartitionsToTxnRequest.Builder( - transactionalId, - producerId, - producerEpoch, - List(createdTopicPartition, nonExistentTopic).asJava) - .build() + val request = + if (version < 4) { + AddPartitionsToTxnRequest.Builder.forClient( + transactionalId, + producerId, + producerEpoch, + List(createdTopicPartition, nonExistentTopic).asJava) + .build() + } else { + val topics = new AddPartitionsToTxnTopicCollection() + topics.add(new AddPartitionsToTxnTopic() + .setName(createdTopicPartition.topic()) + .setPartitions(Collections.singletonList(createdTopicPartition.partition()))) + topics.add(new AddPartitionsToTxnTopic() + .setName(nonExistentTopic.topic()) + .setPartitions(Collections.singletonList(nonExistentTopic.partition()))) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId) + .setProducerId(producerId) + .setProducerEpoch(producerEpoch) + .setVerifyOnly(false) + .setTopics(topics)) + AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + } val leaderId = brokers.head.config.brokerId val response = connectAndReceive[AddPartitionsToTxnResponse](request, brokerSocketServer(leaderId)) + + val errors = + if (version < 4) + response.errors.get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID) + else + response.errorsForTransaction(response.getTransactionTopicResults(transactionalId)) + + assertEquals(2, errors.size) + + assertTrue(errors.containsKey(createdTopicPartition)) + assertEquals(Errors.OPERATION_NOT_ATTEMPTED, errors.get(createdTopicPartition)) + + assertTrue(errors.containsKey(nonExistentTopic)) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, errors.get(nonExistentTopic)) + 
} + + @Test + def testOneSuccessOneErrorInBatchedRequest(): Unit = { + val tp0 = new TopicPartition(topic1, 0) + val transactionalId1 = "foobar" + val transactionalId2 = "barfoo" // "barfoo" maps to the same transaction coordinator + val producerId2 = 1000L + val producerEpoch2: Short = 0 + + val txn2Topics = new AddPartitionsToTxnTopicCollection() + txn2Topics.add(new AddPartitionsToTxnTopic() + .setName(tp0.topic()) + .setPartitions(Collections.singletonList(tp0.partition()))) + + val (coordinatorId, txn1) = setUpTransactions(transactionalId1, false, Set(tp0)) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(txn1) + transactions.add(new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId2) + .setProducerId(producerId2) + .setProducerEpoch(producerEpoch2) + .setVerifyOnly(false) + .setTopics(txn2Topics)) + + val request = AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + + val response = connectAndReceive[AddPartitionsToTxnResponse](request, brokerSocketServer(coordinatorId)) + + val errors = response.errors() + + assertTrue(errors.containsKey(transactionalId1)) + assertTrue(errors.get(transactionalId1).containsKey(tp0)) + assertEquals(Errors.NONE, errors.get(transactionalId1).get(tp0)) + + assertTrue(errors.containsKey(transactionalId2)) + assertTrue(errors.get(transactionalId1).containsKey(tp0)) + assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, errors.get(transactionalId2).get(tp0)) Review Comment: nit: How about using `assertEquals(Map(... define the expected map ...), errors)`? This is simpler and has the benefit of ensuring that we only have what we expect in `errors`. Note that the assertion at L148 is not correct: `transactionalId2` should be used. 
########## core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala: ########## @@ -55,22 +65,149 @@ class AddPartitionsToTxnRequestServerTest extends BaseRequestTest { val producerId = 1000L val producerEpoch: Short = 0 - val request = new AddPartitionsToTxnRequest.Builder( - transactionalId, - producerId, - producerEpoch, - List(createdTopicPartition, nonExistentTopic).asJava) - .build() + val request = + if (version < 4) { + AddPartitionsToTxnRequest.Builder.forClient( + transactionalId, + producerId, + producerEpoch, + List(createdTopicPartition, nonExistentTopic).asJava) + .build() + } else { + val topics = new AddPartitionsToTxnTopicCollection() + topics.add(new AddPartitionsToTxnTopic() + .setName(createdTopicPartition.topic()) + .setPartitions(Collections.singletonList(createdTopicPartition.partition()))) + topics.add(new AddPartitionsToTxnTopic() + .setName(nonExistentTopic.topic()) + .setPartitions(Collections.singletonList(nonExistentTopic.partition()))) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId) + .setProducerId(producerId) + .setProducerEpoch(producerEpoch) + .setVerifyOnly(false) + .setTopics(topics)) + AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + } val leaderId = brokers.head.config.brokerId val response = connectAndReceive[AddPartitionsToTxnResponse](request, brokerSocketServer(leaderId)) + + val errors = + if (version < 4) + response.errors.get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID) + else + response.errorsForTransaction(response.getTransactionTopicResults(transactionalId)) + + assertEquals(2, errors.size) + + assertTrue(errors.containsKey(createdTopicPartition)) + assertEquals(Errors.OPERATION_NOT_ATTEMPTED, errors.get(createdTopicPartition)) + + assertTrue(errors.containsKey(nonExistentTopic)) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, errors.get(nonExistentTopic)) + 
} + + @Test + def testOneSuccessOneErrorInBatchedRequest(): Unit = { + val tp0 = new TopicPartition(topic1, 0) + val transactionalId1 = "foobar" + val transactionalId2 = "barfoo" // "barfoo" maps to the same transaction coordinator + val producerId2 = 1000L + val producerEpoch2: Short = 0 + + val txn2Topics = new AddPartitionsToTxnTopicCollection() + txn2Topics.add(new AddPartitionsToTxnTopic() + .setName(tp0.topic()) + .setPartitions(Collections.singletonList(tp0.partition()))) + + val (coordinatorId, txn1) = setUpTransactions(transactionalId1, false, Set(tp0)) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(txn1) + transactions.add(new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId2) + .setProducerId(producerId2) + .setProducerEpoch(producerEpoch2) + .setVerifyOnly(false) + .setTopics(txn2Topics)) + + val request = AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + + val response = connectAndReceive[AddPartitionsToTxnResponse](request, brokerSocketServer(coordinatorId)) + + val errors = response.errors() + + assertTrue(errors.containsKey(transactionalId1)) + assertTrue(errors.get(transactionalId1).containsKey(tp0)) + assertEquals(Errors.NONE, errors.get(transactionalId1).get(tp0)) + + assertTrue(errors.containsKey(transactionalId2)) + assertTrue(errors.get(transactionalId1).containsKey(tp0)) + assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, errors.get(transactionalId2).get(tp0)) + } - assertEquals(2, response.errors.size) + @Test + def testVerifyOnly(): Unit = { + val tp0 = new TopicPartition(topic1, 0) - assertTrue(response.errors.containsKey(createdTopicPartition)) - assertEquals(Errors.OPERATION_NOT_ATTEMPTED, response.errors.get(createdTopicPartition)) + val transactionalId = "foobar" + val (coordinatorId, txn) = setUpTransactions(transactionalId, true, Set(tp0)) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(txn) + + val verifyRequest = 
AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + + val verifyResponse = connectAndReceive[AddPartitionsToTxnResponse](verifyRequest, brokerSocketServer(coordinatorId)) + + val verifyErrors = verifyResponse.errors() + + assertTrue(verifyErrors.containsKey(transactionalId)) + assertTrue(verifyErrors.get(transactionalId).containsKey(tp0)) + assertEquals(Errors.INVALID_TXN_STATE, verifyErrors.get(transactionalId).get(tp0)) + } + + private def setUpTransactions(transactionalId: String, verifyOnly: Boolean, partitions: Set[TopicPartition]): (Int, AddPartitionsToTxnTransaction) = { + val findCoordinatorRequest = new FindCoordinatorRequest.Builder(new FindCoordinatorRequestData().setKey(transactionalId).setKeyType(CoordinatorType.TRANSACTION.id)).build() + // First find coordinator request creates the state topic, then wait for transactional topics to be created. + connectAndReceive[FindCoordinatorResponse](findCoordinatorRequest, brokerSocketServer(brokers.head.config.brokerId)) + TestUtils.waitForAllPartitionsMetadata(brokers, "__transaction_state", 50) + val findCoordinatorResponse = connectAndReceive[FindCoordinatorResponse](findCoordinatorRequest, brokerSocketServer(brokers.head.config.brokerId)) + val coordinatorId = findCoordinatorResponse.data().coordinators().get(0).nodeId() + + val initPidRequest = new InitProducerIdRequest.Builder(new InitProducerIdRequestData().setTransactionalId(transactionalId).setTransactionTimeoutMs(10000)).build() + val initPidResponse = connectAndReceive[InitProducerIdResponse](initPidRequest, brokerSocketServer(coordinatorId)) + + val producerId1 = initPidResponse.data().producerId() + val producerEpoch1 = initPidResponse.data().producerEpoch() + + val txn1Topics = new AddPartitionsToTxnTopicCollection() + partitions.foreach(tp => + txn1Topics.add(new AddPartitionsToTxnTopic() + .setName(tp.topic()) + .setPartitions(Collections.singletonList(tp.partition()))) + ) + + (coordinatorId, new AddPartitionsToTxnTransaction() 
+ .setTransactionalId(transactionalId) + .setProducerId(producerId1) + .setProducerEpoch(producerEpoch1) + .setVerifyOnly(verifyOnly) + .setTopics(txn1Topics)) + } +} - assertTrue(response.errors.containsKey(nonExistentTopic)) - assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, response.errors.get(nonExistentTopic)) +object AddPartitionsToTxnRequestServerTest { + def parameters: JStream[Arguments] = { + val arguments = mutable.ListBuffer[Arguments]() + ApiKeys.ADD_PARTITIONS_TO_TXN.allVersions().forEach( version => + Array("kraft", "zk").foreach( quorum => Review Comment: nit: ditto. ########## core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala: ########## @@ -55,22 +65,149 @@ class AddPartitionsToTxnRequestServerTest extends BaseRequestTest { val producerId = 1000L val producerEpoch: Short = 0 - val request = new AddPartitionsToTxnRequest.Builder( - transactionalId, - producerId, - producerEpoch, - List(createdTopicPartition, nonExistentTopic).asJava) - .build() + val request = + if (version < 4) { + AddPartitionsToTxnRequest.Builder.forClient( + transactionalId, + producerId, + producerEpoch, + List(createdTopicPartition, nonExistentTopic).asJava) + .build() + } else { + val topics = new AddPartitionsToTxnTopicCollection() + topics.add(new AddPartitionsToTxnTopic() + .setName(createdTopicPartition.topic()) + .setPartitions(Collections.singletonList(createdTopicPartition.partition()))) + topics.add(new AddPartitionsToTxnTopic() + .setName(nonExistentTopic.topic()) + .setPartitions(Collections.singletonList(nonExistentTopic.partition()))) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId) + .setProducerId(producerId) + .setProducerEpoch(producerEpoch) + .setVerifyOnly(false) + .setTopics(topics)) + AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + } val leaderId = brokers.head.config.brokerId val response = 
connectAndReceive[AddPartitionsToTxnResponse](request, brokerSocketServer(leaderId)) + + val errors = + if (version < 4) + response.errors.get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID) + else + response.errorsForTransaction(response.getTransactionTopicResults(transactionalId)) + + assertEquals(2, errors.size) + + assertTrue(errors.containsKey(createdTopicPartition)) + assertEquals(Errors.OPERATION_NOT_ATTEMPTED, errors.get(createdTopicPartition)) + + assertTrue(errors.containsKey(nonExistentTopic)) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, errors.get(nonExistentTopic)) + } + + @Test + def testOneSuccessOneErrorInBatchedRequest(): Unit = { + val tp0 = new TopicPartition(topic1, 0) + val transactionalId1 = "foobar" + val transactionalId2 = "barfoo" // "barfoo" maps to the same transaction coordinator + val producerId2 = 1000L + val producerEpoch2: Short = 0 + + val txn2Topics = new AddPartitionsToTxnTopicCollection() + txn2Topics.add(new AddPartitionsToTxnTopic() + .setName(tp0.topic()) + .setPartitions(Collections.singletonList(tp0.partition()))) + + val (coordinatorId, txn1) = setUpTransactions(transactionalId1, false, Set(tp0)) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(txn1) + transactions.add(new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId2) + .setProducerId(producerId2) + .setProducerEpoch(producerEpoch2) + .setVerifyOnly(false) + .setTopics(txn2Topics)) + + val request = AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + + val response = connectAndReceive[AddPartitionsToTxnResponse](request, brokerSocketServer(coordinatorId)) + + val errors = response.errors() + + assertTrue(errors.containsKey(transactionalId1)) + assertTrue(errors.get(transactionalId1).containsKey(tp0)) + assertEquals(Errors.NONE, errors.get(transactionalId1).get(tp0)) + + assertTrue(errors.containsKey(transactionalId2)) + assertTrue(errors.get(transactionalId1).containsKey(tp0)) + 
assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, errors.get(transactionalId2).get(tp0)) + } - assertEquals(2, response.errors.size) + @Test + def testVerifyOnly(): Unit = { + val tp0 = new TopicPartition(topic1, 0) - assertTrue(response.errors.containsKey(createdTopicPartition)) - assertEquals(Errors.OPERATION_NOT_ATTEMPTED, response.errors.get(createdTopicPartition)) + val transactionalId = "foobar" + val (coordinatorId, txn) = setUpTransactions(transactionalId, true, Set(tp0)) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(txn) + + val verifyRequest = AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + + val verifyResponse = connectAndReceive[AddPartitionsToTxnResponse](verifyRequest, brokerSocketServer(coordinatorId)) + + val verifyErrors = verifyResponse.errors() + + assertTrue(verifyErrors.containsKey(transactionalId)) + assertTrue(verifyErrors.get(transactionalId).containsKey(tp0)) + assertEquals(Errors.INVALID_TXN_STATE, verifyErrors.get(transactionalId).get(tp0)) Review Comment: ditto. 
########## core/src/main/scala/kafka/server/KafkaApis.scala: ########## @@ -2384,68 +2385,116 @@ class KafkaApis(val requestChannel: RequestChannel, if (config.interBrokerProtocolVersion.isLessThan(version)) throw new UnsupportedVersionException(s"inter.broker.protocol.version: ${config.interBrokerProtocolVersion.version} is less than the required version: ${version.version}") } - - def handleAddPartitionToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + def handleAddPartitionsToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { ensureInterBrokerVersion(IBP_0_11_0_IV0) - val addPartitionsToTxnRequest = request.body[AddPartitionsToTxnRequest] - val transactionalId = addPartitionsToTxnRequest.data.transactionalId - val partitionsToAdd = addPartitionsToTxnRequest.partitions.asScala - if (!authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - addPartitionsToTxnRequest.getErrorResponse(requestThrottleMs, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception)) - else { - val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]() - val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]() - val authorizedPartitions = mutable.Set[TopicPartition]() - - val authorizedTopics = authHelper.filterByAuthorized(request.context, WRITE, TOPIC, - partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic) - for (topicPartition <- partitionsToAdd) { - if (!authorizedTopics.contains(topicPartition.topic)) - unauthorizedTopicErrors += topicPartition -> Errors.TOPIC_AUTHORIZATION_FAILED - else if (!metadataCache.contains(topicPartition)) - nonExistingTopicErrors += topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION - else - authorizedPartitions.add(topicPartition) + val addPartitionsToTxnRequest = + if (request.context.apiVersion() < 4) + 
request.body[AddPartitionsToTxnRequest].normalizeRequest() + else + request.body[AddPartitionsToTxnRequest] + val version = addPartitionsToTxnRequest.version + val responses = new AddPartitionsToTxnResultCollection() + val partitionsByTransaction = addPartitionsToTxnRequest.partitionsByTransaction() + + // Newer versions of the request should only come from other brokers. + if (version >= 4) authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + + // V4 requests introduced batches of transactions. We need all transactions to be handled before sending the + // response so there are a few differences in handling errors and sending responses. + def createResponse(requestThrottleMs: Int): AbstractResponse = { + if (version < 4) { + // There will only be one response in data. Add it to the response data object. + val data = new AddPartitionsToTxnResponseData() + responses.forEach(result => { + data.setResultsByTopicV3AndBelow(result.topicResults()) + data.setThrottleTimeMs(requestThrottleMs) + }) + new AddPartitionsToTxnResponse(data) + } else { + new AddPartitionsToTxnResponse(new AddPartitionsToTxnResponseData().setThrottleTimeMs(requestThrottleMs).setResultsByTransaction(responses)) } + } - if (unauthorizedTopicErrors.nonEmpty || nonExistingTopicErrors.nonEmpty) { - // Any failed partition check causes the entire request to fail. We send the appropriate error codes for the - // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error code for the partitions which succeeded - // the authorization check to indicate that they were not added to the transaction. 
- val partitionErrors = unauthorizedTopicErrors ++ nonExistingTopicErrors ++ - authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED) - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - new AddPartitionsToTxnResponse(requestThrottleMs, partitionErrors.asJava)) + val txns = addPartitionsToTxnRequest.data.transactions + def maybeSendResponse(): Unit = { + var canSend = false + responses.synchronized { + if (responses.size() == txns.size()) { + canSend = true + } + } + if (canSend) { + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + } + + txns.forEach( transaction => { + val transactionalId = transaction.transactionalId + val partitionsToAdd = partitionsByTransaction.get(transactionalId).asScala + + // Versions < 4 come from clients and must be authorized to write for the given transaction and for the given topics. + if (version < 4 && !authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) { + responses.add(addPartitionsToTxnRequest.errorResponseForTransaction(transactionalId, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED)) + maybeSendResponse() } else { - def sendResponseCallback(error: Errors): Unit = { - def createResponse(requestThrottleMs: Int): AbstractResponse = { - val finalError = - if (addPartitionsToTxnRequest.version < 2 && error == Errors.PRODUCER_FENCED) { + val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]() + val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]() + val authorizedPartitions = mutable.Set[TopicPartition]() + + val authorizedTopics = if (version < 4) authHelper.filterByAuthorized(request.context, WRITE, TOPIC, + partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic) else partitionsToAdd.map(_.topic).toSet + for (topicPartition <- partitionsToAdd) { + if (!authorizedTopics.contains(topicPartition.topic)) + unauthorizedTopicErrors += topicPartition -> Errors.TOPIC_AUTHORIZATION_FAILED + else if 
(!metadataCache.contains(topicPartition)) + nonExistingTopicErrors += topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION + else + authorizedPartitions.add(topicPartition) + } + + if (unauthorizedTopicErrors.nonEmpty || nonExistingTopicErrors.nonEmpty) { + // Any failed partition check causes the entire transaction to fail. We send the appropriate error codes for the + // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error code for the partitions which succeeded + // the authorization check to indicate that they were not added to the transaction. + val partitionErrors = unauthorizedTopicErrors ++ nonExistingTopicErrors ++ + authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED) + responses.add(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, partitionErrors.asJava)) + maybeSendResponse() + } else { + def sendResponseCallback(error: Errors): Unit = { + val finalError = { + if (version < 2 && error == Errors.PRODUCER_FENCED) { // For older clients, they could not understand the new PRODUCER_FENCED error code, // so we need to return the old INVALID_PRODUCER_EPOCH to have the same client handling logic. 
Errors.INVALID_PRODUCER_EPOCH } else { error } - - val responseBody: AddPartitionsToTxnResponse = new AddPartitionsToTxnResponse(requestThrottleMs, - partitionsToAdd.map{tp => (tp, finalError)}.toMap.asJava) - trace(s"Completed $transactionalId's AddPartitionsToTxnRequest with partitions $partitionsToAdd: errors: $error from client ${request.header.clientId}") - responseBody + } + responses.synchronized { + responses.add(addPartitionsToTxnRequest.errorResponseForTransaction(transactionalId, finalError)) + } + maybeSendResponse() + } + + def sendVerifyResponseCallback(errors: Map[TopicPartition, Errors]): Unit = { + responses.synchronized { + responses.add(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, errors.asJava)) + } + maybeSendResponse() } - requestHelper.sendResponseMaybeThrottle(request, createResponse) + txnCoordinator.handleAddPartitionsToTransaction(transactionalId, + transaction.producerId, + transaction.producerEpoch, + authorizedPartitions, + transaction.verifyOnly, + sendResponseCallback, + sendVerifyResponseCallback, Review Comment: I am not sure I understand why we need two callbacks. I find it weird in the first place. My understanding is that in both cases, we end up with an error per partition in the response, so it seems to me that we could unify them, no? 
########## core/src/main/scala/kafka/server/KafkaApis.scala: ########## @@ -2384,68 +2385,116 @@ class KafkaApis(val requestChannel: RequestChannel, if (config.interBrokerProtocolVersion.isLessThan(version)) throw new UnsupportedVersionException(s"inter.broker.protocol.version: ${config.interBrokerProtocolVersion.version} is less than the required version: ${version.version}") } - - def handleAddPartitionToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + def handleAddPartitionsToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { ensureInterBrokerVersion(IBP_0_11_0_IV0) - val addPartitionsToTxnRequest = request.body[AddPartitionsToTxnRequest] - val transactionalId = addPartitionsToTxnRequest.data.transactionalId - val partitionsToAdd = addPartitionsToTxnRequest.partitions.asScala - if (!authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - addPartitionsToTxnRequest.getErrorResponse(requestThrottleMs, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception)) - else { - val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]() - val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]() - val authorizedPartitions = mutable.Set[TopicPartition]() - - val authorizedTopics = authHelper.filterByAuthorized(request.context, WRITE, TOPIC, - partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic) - for (topicPartition <- partitionsToAdd) { - if (!authorizedTopics.contains(topicPartition.topic)) - unauthorizedTopicErrors += topicPartition -> Errors.TOPIC_AUTHORIZATION_FAILED - else if (!metadataCache.contains(topicPartition)) - nonExistingTopicErrors += topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION - else - authorizedPartitions.add(topicPartition) + val addPartitionsToTxnRequest = + if (request.context.apiVersion() < 4) + 
request.body[AddPartitionsToTxnRequest].normalizeRequest() + else + request.body[AddPartitionsToTxnRequest] + val version = addPartitionsToTxnRequest.version + val responses = new AddPartitionsToTxnResultCollection() + val partitionsByTransaction = addPartitionsToTxnRequest.partitionsByTransaction() + + // Newer versions of the request should only come from other brokers. + if (version >= 4) authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + + // V4 requests introduced batches of transactions. We need all transactions to be handled before sending the + // response so there are a few differences in handling errors and sending responses. + def createResponse(requestThrottleMs: Int): AbstractResponse = { + if (version < 4) { + // There will only be one response in data. Add it to the response data object. + val data = new AddPartitionsToTxnResponseData() + responses.forEach(result => { + data.setResultsByTopicV3AndBelow(result.topicResults()) + data.setThrottleTimeMs(requestThrottleMs) + }) + new AddPartitionsToTxnResponse(data) + } else { + new AddPartitionsToTxnResponse(new AddPartitionsToTxnResponseData().setThrottleTimeMs(requestThrottleMs).setResultsByTransaction(responses)) } + } - if (unauthorizedTopicErrors.nonEmpty || nonExistingTopicErrors.nonEmpty) { - // Any failed partition check causes the entire request to fail. We send the appropriate error codes for the - // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error code for the partitions which succeeded - // the authorization check to indicate that they were not added to the transaction. 
- val partitionErrors = unauthorizedTopicErrors ++ nonExistingTopicErrors ++ - authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED) - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - new AddPartitionsToTxnResponse(requestThrottleMs, partitionErrors.asJava)) + val txns = addPartitionsToTxnRequest.data.transactions + def maybeSendResponse(): Unit = { + var canSend = false + responses.synchronized { + if (responses.size() == txns.size()) { + canSend = true + } + } Review Comment: nit: A more scala-ish version of this: ``` val canSend = responses.synchronized { responses.size == txns.size } ``` ########## core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala: ########## @@ -55,22 +65,149 @@ class AddPartitionsToTxnRequestServerTest extends BaseRequestTest { val producerId = 1000L val producerEpoch: Short = 0 - val request = new AddPartitionsToTxnRequest.Builder( - transactionalId, - producerId, - producerEpoch, - List(createdTopicPartition, nonExistentTopic).asJava) - .build() + val request = + if (version < 4) { + AddPartitionsToTxnRequest.Builder.forClient( + transactionalId, + producerId, + producerEpoch, + List(createdTopicPartition, nonExistentTopic).asJava) + .build() Review Comment: nit: `build()` looks weird here. Should we put the closing parenthesis of `forClient` on a new line: `).build()`? 
########## core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala: ########## @@ -55,22 +65,149 @@ class AddPartitionsToTxnRequestServerTest extends BaseRequestTest { val producerId = 1000L val producerEpoch: Short = 0 - val request = new AddPartitionsToTxnRequest.Builder( - transactionalId, - producerId, - producerEpoch, - List(createdTopicPartition, nonExistentTopic).asJava) - .build() + val request = + if (version < 4) { + AddPartitionsToTxnRequest.Builder.forClient( + transactionalId, + producerId, + producerEpoch, + List(createdTopicPartition, nonExistentTopic).asJava) + .build() + } else { + val topics = new AddPartitionsToTxnTopicCollection() + topics.add(new AddPartitionsToTxnTopic() + .setName(createdTopicPartition.topic()) + .setPartitions(Collections.singletonList(createdTopicPartition.partition()))) + topics.add(new AddPartitionsToTxnTopic() + .setName(nonExistentTopic.topic()) + .setPartitions(Collections.singletonList(nonExistentTopic.partition()))) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId) + .setProducerId(producerId) + .setProducerEpoch(producerEpoch) + .setVerifyOnly(false) + .setTopics(topics)) + AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + } val leaderId = brokers.head.config.brokerId val response = connectAndReceive[AddPartitionsToTxnResponse](request, brokerSocketServer(leaderId)) + + val errors = + if (version < 4) + response.errors.get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID) + else + response.errorsForTransaction(response.getTransactionTopicResults(transactionalId)) + + assertEquals(2, errors.size) + + assertTrue(errors.containsKey(createdTopicPartition)) + assertEquals(Errors.OPERATION_NOT_ATTEMPTED, errors.get(createdTopicPartition)) + + assertTrue(errors.containsKey(nonExistentTopic)) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, errors.get(nonExistentTopic)) + 
} + + @Test + def testOneSuccessOneErrorInBatchedRequest(): Unit = { + val tp0 = new TopicPartition(topic1, 0) + val transactionalId1 = "foobar" + val transactionalId2 = "barfoo" // "barfoo" maps to the same transaction coordinator + val producerId2 = 1000L + val producerEpoch2: Short = 0 + + val txn2Topics = new AddPartitionsToTxnTopicCollection() + txn2Topics.add(new AddPartitionsToTxnTopic() + .setName(tp0.topic()) + .setPartitions(Collections.singletonList(tp0.partition()))) + + val (coordinatorId, txn1) = setUpTransactions(transactionalId1, false, Set(tp0)) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(txn1) + transactions.add(new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId2) + .setProducerId(producerId2) + .setProducerEpoch(producerEpoch2) + .setVerifyOnly(false) + .setTopics(txn2Topics)) + + val request = AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + + val response = connectAndReceive[AddPartitionsToTxnResponse](request, brokerSocketServer(coordinatorId)) + + val errors = response.errors() + + assertTrue(errors.containsKey(transactionalId1)) + assertTrue(errors.get(transactionalId1).containsKey(tp0)) + assertEquals(Errors.NONE, errors.get(transactionalId1).get(tp0)) + + assertTrue(errors.containsKey(transactionalId2)) + assertTrue(errors.get(transactionalId1).containsKey(tp0)) + assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, errors.get(transactionalId2).get(tp0)) + } - assertEquals(2, response.errors.size) + @Test + def testVerifyOnly(): Unit = { + val tp0 = new TopicPartition(topic1, 0) - assertTrue(response.errors.containsKey(createdTopicPartition)) - assertEquals(Errors.OPERATION_NOT_ATTEMPTED, response.errors.get(createdTopicPartition)) + val transactionalId = "foobar" + val (coordinatorId, txn) = setUpTransactions(transactionalId, true, Set(tp0)) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(txn) + + val verifyRequest = 
AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + + val verifyResponse = connectAndReceive[AddPartitionsToTxnResponse](verifyRequest, brokerSocketServer(coordinatorId)) + + val verifyErrors = verifyResponse.errors() + + assertTrue(verifyErrors.containsKey(transactionalId)) + assertTrue(verifyErrors.get(transactionalId).containsKey(tp0)) + assertEquals(Errors.INVALID_TXN_STATE, verifyErrors.get(transactionalId).get(tp0)) + } + + private def setUpTransactions(transactionalId: String, verifyOnly: Boolean, partitions: Set[TopicPartition]): (Int, AddPartitionsToTxnTransaction) = { + val findCoordinatorRequest = new FindCoordinatorRequest.Builder(new FindCoordinatorRequestData().setKey(transactionalId).setKeyType(CoordinatorType.TRANSACTION.id)).build() + // First find coordinator request creates the state topic, then wait for transactional topics to be created. + connectAndReceive[FindCoordinatorResponse](findCoordinatorRequest, brokerSocketServer(brokers.head.config.brokerId)) + TestUtils.waitForAllPartitionsMetadata(brokers, "__transaction_state", 50) + val findCoordinatorResponse = connectAndReceive[FindCoordinatorResponse](findCoordinatorRequest, brokerSocketServer(brokers.head.config.brokerId)) + val coordinatorId = findCoordinatorResponse.data().coordinators().get(0).nodeId() + + val initPidRequest = new InitProducerIdRequest.Builder(new InitProducerIdRequestData().setTransactionalId(transactionalId).setTransactionTimeoutMs(10000)).build() + val initPidResponse = connectAndReceive[InitProducerIdResponse](initPidRequest, brokerSocketServer(coordinatorId)) + + val producerId1 = initPidResponse.data().producerId() + val producerEpoch1 = initPidResponse.data().producerEpoch() + + val txn1Topics = new AddPartitionsToTxnTopicCollection() + partitions.foreach(tp => + txn1Topics.add(new AddPartitionsToTxnTopic() + .setName(tp.topic()) + .setPartitions(Collections.singletonList(tp.partition()))) + ) + + (coordinatorId, new AddPartitionsToTxnTransaction() 
+ .setTransactionalId(transactionalId) + .setProducerId(producerId1) + .setProducerEpoch(producerEpoch1) + .setVerifyOnly(verifyOnly) + .setTopics(txn1Topics)) + } +} - assertTrue(response.errors.containsKey(nonExistentTopic)) - assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, response.errors.get(nonExistentTopic)) +object AddPartitionsToTxnRequestServerTest { + def parameters: JStream[Arguments] = { + val arguments = mutable.ListBuffer[Arguments]() + ApiKeys.ADD_PARTITIONS_TO_TXN.allVersions().forEach( version => Review Comment: nit: `forEach { `. ########## core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala: ########## @@ -55,22 +65,149 @@ class AddPartitionsToTxnRequestServerTest extends BaseRequestTest { val producerId = 1000L val producerEpoch: Short = 0 - val request = new AddPartitionsToTxnRequest.Builder( - transactionalId, - producerId, - producerEpoch, - List(createdTopicPartition, nonExistentTopic).asJava) - .build() + val request = + if (version < 4) { + AddPartitionsToTxnRequest.Builder.forClient( + transactionalId, + producerId, + producerEpoch, + List(createdTopicPartition, nonExistentTopic).asJava) + .build() + } else { + val topics = new AddPartitionsToTxnTopicCollection() + topics.add(new AddPartitionsToTxnTopic() + .setName(createdTopicPartition.topic()) + .setPartitions(Collections.singletonList(createdTopicPartition.partition()))) + topics.add(new AddPartitionsToTxnTopic() + .setName(nonExistentTopic.topic()) + .setPartitions(Collections.singletonList(nonExistentTopic.partition()))) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId) + .setProducerId(producerId) + .setProducerEpoch(producerEpoch) + .setVerifyOnly(false) + .setTopics(topics)) + AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + } val leaderId = brokers.head.config.brokerId val response = 
connectAndReceive[AddPartitionsToTxnResponse](request, brokerSocketServer(leaderId)) + + val errors = + if (version < 4) + response.errors.get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID) + else + response.errorsForTransaction(response.getTransactionTopicResults(transactionalId)) + + assertEquals(2, errors.size) + + assertTrue(errors.containsKey(createdTopicPartition)) + assertEquals(Errors.OPERATION_NOT_ATTEMPTED, errors.get(createdTopicPartition)) + + assertTrue(errors.containsKey(nonExistentTopic)) + assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, errors.get(nonExistentTopic)) + } + + @Test + def testOneSuccessOneErrorInBatchedRequest(): Unit = { + val tp0 = new TopicPartition(topic1, 0) + val transactionalId1 = "foobar" + val transactionalId2 = "barfoo" // "barfoo" maps to the same transaction coordinator + val producerId2 = 1000L + val producerEpoch2: Short = 0 + + val txn2Topics = new AddPartitionsToTxnTopicCollection() + txn2Topics.add(new AddPartitionsToTxnTopic() + .setName(tp0.topic()) + .setPartitions(Collections.singletonList(tp0.partition()))) + + val (coordinatorId, txn1) = setUpTransactions(transactionalId1, false, Set(tp0)) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(txn1) + transactions.add(new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId2) + .setProducerId(producerId2) + .setProducerEpoch(producerEpoch2) + .setVerifyOnly(false) + .setTopics(txn2Topics)) + + val request = AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + + val response = connectAndReceive[AddPartitionsToTxnResponse](request, brokerSocketServer(coordinatorId)) + + val errors = response.errors() + + assertTrue(errors.containsKey(transactionalId1)) + assertTrue(errors.get(transactionalId1).containsKey(tp0)) + assertEquals(Errors.NONE, errors.get(transactionalId1).get(tp0)) + + assertTrue(errors.containsKey(transactionalId2)) + assertTrue(errors.get(transactionalId1).containsKey(tp0)) + 
assertEquals(Errors.INVALID_PRODUCER_ID_MAPPING, errors.get(transactionalId2).get(tp0)) + } - assertEquals(2, response.errors.size) + @Test + def testVerifyOnly(): Unit = { + val tp0 = new TopicPartition(topic1, 0) - assertTrue(response.errors.containsKey(createdTopicPartition)) - assertEquals(Errors.OPERATION_NOT_ATTEMPTED, response.errors.get(createdTopicPartition)) + val transactionalId = "foobar" + val (coordinatorId, txn) = setUpTransactions(transactionalId, true, Set(tp0)) + + val transactions = new AddPartitionsToTxnTransactionCollection() + transactions.add(txn) + + val verifyRequest = AddPartitionsToTxnRequest.Builder.forBroker(transactions).build() + + val verifyResponse = connectAndReceive[AddPartitionsToTxnResponse](verifyRequest, brokerSocketServer(coordinatorId)) + + val verifyErrors = verifyResponse.errors() + + assertTrue(verifyErrors.containsKey(transactionalId)) + assertTrue(verifyErrors.get(transactionalId).containsKey(tp0)) + assertEquals(Errors.INVALID_TXN_STATE, verifyErrors.get(transactionalId).get(tp0)) + } + + private def setUpTransactions(transactionalId: String, verifyOnly: Boolean, partitions: Set[TopicPartition]): (Int, AddPartitionsToTxnTransaction) = { + val findCoordinatorRequest = new FindCoordinatorRequest.Builder(new FindCoordinatorRequestData().setKey(transactionalId).setKeyType(CoordinatorType.TRANSACTION.id)).build() + // First find coordinator request creates the state topic, then wait for transactional topics to be created. 
+ connectAndReceive[FindCoordinatorResponse](findCoordinatorRequest, brokerSocketServer(brokers.head.config.brokerId)) + TestUtils.waitForAllPartitionsMetadata(brokers, "__transaction_state", 50) + val findCoordinatorResponse = connectAndReceive[FindCoordinatorResponse](findCoordinatorRequest, brokerSocketServer(brokers.head.config.brokerId)) + val coordinatorId = findCoordinatorResponse.data().coordinators().get(0).nodeId() + + val initPidRequest = new InitProducerIdRequest.Builder(new InitProducerIdRequestData().setTransactionalId(transactionalId).setTransactionTimeoutMs(10000)).build() + val initPidResponse = connectAndReceive[InitProducerIdResponse](initPidRequest, brokerSocketServer(coordinatorId)) + + val producerId1 = initPidResponse.data().producerId() + val producerEpoch1 = initPidResponse.data().producerEpoch() + + val txn1Topics = new AddPartitionsToTxnTopicCollection() + partitions.foreach(tp => Review Comment: nit: `partitions.foreach {` ########## clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponse.java: ########## @@ -48,29 +51,50 @@ public class AddPartitionsToTxnResponse extends AbstractResponse { private final AddPartitionsToTxnResponseData data; - private Map<TopicPartition, Errors> cachedErrorsMap = null; + public static final String V3_AND_BELOW_TXN_ID = ""; public AddPartitionsToTxnResponse(AddPartitionsToTxnResponseData data) { super(ApiKeys.ADD_PARTITIONS_TO_TXN); this.data = data; } - public AddPartitionsToTxnResponse(int throttleTimeMs, Map<TopicPartition, Errors> errors) { - super(ApiKeys.ADD_PARTITIONS_TO_TXN); + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + @Override + public void maybeSetThrottleTimeMs(int throttleTimeMs) { + data.setThrottleTimeMs(throttleTimeMs); + } + + public Map<String, Map<TopicPartition, Errors>> errors() { + Map<String, Map<TopicPartition, Errors>> errorsMap = new HashMap<>(); + + errorsMap.put(V3_AND_BELOW_TXN_ID, 
errorsForTransaction(this.data.resultsByTopicV3AndBelow())); + + for (AddPartitionsToTxnResult result : this.data.resultsByTransaction()) { + String transactionalId = result.transactionalId(); + errorsMap.put(transactionalId, errorsForTransaction(data().resultsByTransaction().find(transactionalId).topicResults())); Review Comment: Can't you reuse `result` instead of calling `data().resultsByTransaction().find(transactionalId)`? ########## core/src/test/scala/unit/kafka/server/KafkaApisTest.scala: ########## @@ -1962,7 +1962,9 @@ class KafkaApisTest { ArgumentMatchers.eq(producerId), ArgumentMatchers.eq(epoch), ArgumentMatchers.eq(Set(new TopicPartition(Topic.GROUP_METADATA_TOPIC_NAME, partition))), + ArgumentMatchers.eq(false), Review Comment: It would be great if we could also add new unit tests to cover the batch mode in `KafkaApisTest`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org