dajac commented on code in PR #13231: URL: https://github.com/apache/kafka/pull/13231#discussion_r1124792127
########## clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java: ########## @@ -34,22 +43,38 @@ public class AddPartitionsToTxnRequest extends AbstractRequest { private final AddPartitionsToTxnRequestData data; - private List<TopicPartition> cachedPartitions = null; - public static class Builder extends AbstractRequest.Builder<AddPartitionsToTxnRequest> { public final AddPartitionsToTxnRequestData data; + + public static Builder forClient(String transactionalId, + long producerId, + short producerEpoch, + List<TopicPartition> partitions) { + + AddPartitionsToTxnTopicCollection topics = buildTxnTopicCollection(partitions); + + return new Builder(ApiKeys.ADD_PARTITIONS_TO_TXN.oldestVersion(), + (short) 3, Review Comment: nit: Should we put `(short) 3` on the previous line to be consistent with how you did it at L66? ########## clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java: ########## @@ -34,22 +43,38 @@ public class AddPartitionsToTxnRequest extends AbstractRequest { private final AddPartitionsToTxnRequestData data; - private List<TopicPartition> cachedPartitions = null; - public static class Builder extends AbstractRequest.Builder<AddPartitionsToTxnRequest> { public final AddPartitionsToTxnRequestData data; + + public static Builder forClient(String transactionalId, + long producerId, + short producerEpoch, + List<TopicPartition> partitions) { + + AddPartitionsToTxnTopicCollection topics = buildTxnTopicCollection(partitions); + + return new Builder(ApiKeys.ADD_PARTITIONS_TO_TXN.oldestVersion(), + (short) 3, + new AddPartitionsToTxnRequestData() + .setV3AndBelowTransactionalId(transactionalId) + .setV3AndBelowProducerId(producerId) + .setV3AndBelowProducerEpoch(producerEpoch) + .setV3AndBelowTopics(topics)); + } + + public static Builder forBroker(AddPartitionsToTxnTransactionCollection transactions) { + return new Builder((short) 4, ApiKeys.ADD_PARTITIONS_TO_TXN.latestVersion(), + new 
AddPartitionsToTxnRequestData() + .setTransactions(transactions)); + } + + public Builder(short minVersion, short maxVersion, AddPartitionsToTxnRequestData data) { Review Comment: nit: Do we still use this constructor anywhere? It may be good to make it private or package private to ensure that `forClient` or `forBroker` is used. ########## clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponse.java: ########## @@ -48,29 +51,51 @@ public class AddPartitionsToTxnResponse extends AbstractResponse { private final AddPartitionsToTxnResponseData data; - private Map<TopicPartition, Errors> cachedErrorsMap = null; + public static final String V3_AND_BELOW_TXN_ID = ""; public AddPartitionsToTxnResponse(AddPartitionsToTxnResponseData data) { super(ApiKeys.ADD_PARTITIONS_TO_TXN); this.data = data; } - public AddPartitionsToTxnResponse(int throttleTimeMs, Map<TopicPartition, Errors> errors) { - super(ApiKeys.ADD_PARTITIONS_TO_TXN); + @Override + public int throttleTimeMs() { + return data.throttleTimeMs(); + } + + @Override + public void maybeSetThrottleTimeMs(int throttleTimeMs) { + data.setThrottleTimeMs(throttleTimeMs); + } + + public Map<String, Map<TopicPartition, Errors>> errors() { + Map<String, Map<TopicPartition, Errors>> errorsMap = new HashMap<>(); + + if (this.data.resultsByTopicV3AndBelow().size() != 0) { Review Comment: nit: I think that we usually prefer using `isEmpty()`. 
########## clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponseTest.java: ########## @@ -58,42 +65,75 @@ public void setUp() { errorsMap.put(tp2, errorTwo); } - @Test - public void testConstructorWithErrorResponse() { - AddPartitionsToTxnResponse response = new AddPartitionsToTxnResponse(throttleTimeMs, errorsMap); - - assertEquals(expectedErrorCounts, response.errorCounts()); - assertEquals(throttleTimeMs, response.throttleTimeMs()); - } - - @Test - public void testParse() { - + @ParameterizedTest + @ApiKeyVersionsSource(apiKey = ApiKeys.ADD_PARTITIONS_TO_TXN) + public void testParse(short version) { AddPartitionsToTxnTopicResultCollection topicCollection = new AddPartitionsToTxnTopicResultCollection(); AddPartitionsToTxnTopicResult topicResult = new AddPartitionsToTxnTopicResult(); topicResult.setName(topicOne); - topicResult.results().add(new AddPartitionsToTxnPartitionResult() - .setErrorCode(errorOne.code()) + topicResult.resultsByPartition().add(new AddPartitionsToTxnPartitionResult() + .setPartitionErrorCode(errorOne.code()) .setPartitionIndex(partitionOne)); - topicResult.results().add(new AddPartitionsToTxnPartitionResult() - .setErrorCode(errorTwo.code()) + topicResult.resultsByPartition().add(new AddPartitionsToTxnPartitionResult() + .setPartitionErrorCode(errorTwo.code()) .setPartitionIndex(partitionTwo)); topicCollection.add(topicResult); + + if (version < 4) { + AddPartitionsToTxnResponseData data = new AddPartitionsToTxnResponseData() + .setResultsByTopicV3AndBelow(topicCollection) + .setThrottleTimeMs(throttleTimeMs); + AddPartitionsToTxnResponse response = new AddPartitionsToTxnResponse(data); - AddPartitionsToTxnResponseData data = new AddPartitionsToTxnResponseData() - .setResults(topicCollection) - .setThrottleTimeMs(throttleTimeMs); - AddPartitionsToTxnResponse response = new AddPartitionsToTxnResponse(data); - - for (short version : ApiKeys.ADD_PARTITIONS_TO_TXN.allVersions()) { AddPartitionsToTxnResponse 
parsedResponse = AddPartitionsToTxnResponse.parse(response.serialize(version), version); assertEquals(expectedErrorCounts, parsedResponse.errorCounts()); assertEquals(throttleTimeMs, parsedResponse.throttleTimeMs()); assertEquals(version >= 1, parsedResponse.shouldClientThrottle(version)); + } else { + AddPartitionsToTxnResultCollection results = new AddPartitionsToTxnResultCollection(); + results.add(new AddPartitionsToTxnResult().setTransactionalId("txn1").setTopicResults(topicCollection)); + + // Create another transaction with new name and errorOne for a single partition. + Map<TopicPartition, Errors> txnTwoExpectedErrors = Collections.singletonMap(tp2, errorOne); + results.add(AddPartitionsToTxnResponse.resultForTransaction("txn2", txnTwoExpectedErrors)); + + AddPartitionsToTxnResponseData data = new AddPartitionsToTxnResponseData() + .setResultsByTransaction(results) + .setThrottleTimeMs(throttleTimeMs); + AddPartitionsToTxnResponse response = new AddPartitionsToTxnResponse(data); + + Map<Errors, Integer> newExpectedErrorCounts = new HashMap<>(); + newExpectedErrorCounts.put(Errors.NONE, 1); // top level error + newExpectedErrorCounts.put(errorOne, 2); + newExpectedErrorCounts.put(errorTwo, 1); + + AddPartitionsToTxnResponse parsedResponse = AddPartitionsToTxnResponse.parse(response.serialize(version), version); + assertEquals(txnTwoExpectedErrors, errorsForTransaction(response.getTransactionTopicResults("txn2"))); + assertEquals(newExpectedErrorCounts, parsedResponse.errorCounts()); + assertEquals(throttleTimeMs, parsedResponse.throttleTimeMs()); + assertTrue(parsedResponse.shouldClientThrottle(version)); } } + + @Test + public void testBatchedErrors() { + Map<TopicPartition, Errors> txn1Errors = Collections.singletonMap(tp1, errorOne); + Map<TopicPartition, Errors> txn2Errors = Collections.singletonMap(tp1, errorOne); + + AddPartitionsToTxnResult transaction1 = AddPartitionsToTxnResponse.resultForTransaction("txn1", txn1Errors); + AddPartitionsToTxnResult 
transaction2 = AddPartitionsToTxnResponse.resultForTransaction("txn2", txn2Errors); + + AddPartitionsToTxnResultCollection results = new AddPartitionsToTxnResultCollection(); + results.add(transaction1); + results.add(transaction2); + + AddPartitionsToTxnResponse response = new AddPartitionsToTxnResponse(new AddPartitionsToTxnResponseData().setResultsByTransaction(results)); + + assertEquals(txn1Errors, errorsForTransaction(response.getTransactionTopicResults("txn1"))); + assertEquals(txn2Errors, errorsForTransaction(response.getTransactionTopicResults("txn2"))); + } Review Comment: nit: Should we add a test for `errors()`? ########## core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala: ########## @@ -317,6 +322,34 @@ class TransactionCoordinator(txnConfig: TransactionConfig, } } } + + def handleVerifyPartitionsInTransaction(transactionalId: String, + producerId: Long, + producerEpoch: Short, + partitions: collection.Set[TopicPartition], + responseCallback: VerifyPartitionsCallback): Unit = { + if (transactionalId == null || transactionalId.isEmpty) { + debug(s"Returning ${Errors.INVALID_REQUEST} error code to client for $transactionalId's AddPartitions request") + responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, partitions.map(_ -> Errors.INVALID_REQUEST).toMap.asJava)) + } else { + val result: ApiResult[(Int, TransactionMetadata)] = getTransactionMetadata(transactionalId, producerId, producerEpoch, partitions) + + result match { + case Left(err) => + debug(s"Returning $err error code to client for $transactionalId's AddPartitions request") Review Comment: nit: Should we update this log message to mention that we are only validating here?
########## clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponse.java: ########## @@ -80,45 +105,44 @@ topicName, new AddPartitionsToTxnPartitionResultCollection() AddPartitionsToTxnTopicResultCollection topicCollection = new AddPartitionsToTxnTopicResultCollection(); for (Map.Entry<String, AddPartitionsToTxnPartitionResultCollection> entry : resultMap.entrySet()) { topicCollection.add(new AddPartitionsToTxnTopicResult() - .setName(entry.getKey()) - .setResults(entry.getValue())); + .setName(entry.getKey()) + .setResultsByPartition(entry.getValue())); } - - this.data = new AddPartitionsToTxnResponseData() - .setThrottleTimeMs(throttleTimeMs) - .setResults(topicCollection); + return topicCollection; } - @Override - public int throttleTimeMs() { - return data.throttleTimeMs(); + public static AddPartitionsToTxnResult resultForTransaction(String transactionalId, Map<TopicPartition, Errors> errors) { + return new AddPartitionsToTxnResult().setTransactionalId(transactionalId).setTopicResults(topicCollectionForErrors(errors)); } - @Override - public void maybeSetThrottleTimeMs(int throttleTimeMs) { - data.setThrottleTimeMs(throttleTimeMs); + public AddPartitionsToTxnTopicResultCollection getTransactionTopicResults(String transactionalId) { + return data.resultsByTransaction().find(transactionalId).topicResults(); } - public Map<TopicPartition, Errors> errors() { - if (cachedErrorsMap != null) { - return cachedErrorsMap; - } - - cachedErrorsMap = new HashMap<>(); - - for (AddPartitionsToTxnTopicResult topicResult : this.data.results()) { - for (AddPartitionsToTxnPartitionResult partitionResult : topicResult.results()) { - cachedErrorsMap.put(new TopicPartition( - topicResult.name(), partitionResult.partitionIndex()), - Errors.forCode(partitionResult.errorCode())); + public static Map<TopicPartition, Errors> errorsForTransaction(AddPartitionsToTxnTopicResultCollection topicCollection) { + Map<TopicPartition, Errors> topicResults = new 
HashMap<>(); + for (AddPartitionsToTxnTopicResult topicResult : topicCollection) { + for (AddPartitionsToTxnPartitionResult partitionResult : topicResult.resultsByPartition()) { + topicResults.put( + new TopicPartition(topicResult.name(), partitionResult.partitionIndex()), Errors.forCode(partitionResult.partitionErrorCode())); } } - return cachedErrorsMap; + return topicResults; } @Override public Map<Errors, Integer> errorCounts() { - return errorCounts(errors().values()); + List<Errors> allErrors = new ArrayList<>(); + + // If we are not using this field, we have request 4 or later + if (this.data.resultsByTopicV3AndBelow().size() == 0) { + allErrors.add(Errors.forCode(data.errorCode())); Review Comment: nit: Should we use `updateErrorCounts` from `AbstractResponse` instead of creating `allErrors`? ########## core/src/main/scala/kafka/server/KafkaApis.scala: ########## @@ -2384,66 +2386,111 @@ class KafkaApis(val requestChannel: RequestChannel, if (config.interBrokerProtocolVersion.isLessThan(version)) throw new UnsupportedVersionException(s"inter.broker.protocol.version: ${config.interBrokerProtocolVersion.version} is less than the required version: ${version.version}") } - - def handleAddPartitionToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + def handleAddPartitionsToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { ensureInterBrokerVersion(IBP_0_11_0_IV0) - val addPartitionsToTxnRequest = request.body[AddPartitionsToTxnRequest] - val transactionalId = addPartitionsToTxnRequest.data.transactionalId - val partitionsToAdd = addPartitionsToTxnRequest.partitions.asScala - if (!authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - addPartitionsToTxnRequest.getErrorResponse(requestThrottleMs, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception)) - else { - val unauthorizedTopicErrors = 
mutable.Map[TopicPartition, Errors]() - val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]() - val authorizedPartitions = mutable.Set[TopicPartition]() - - val authorizedTopics = authHelper.filterByAuthorized(request.context, WRITE, TOPIC, - partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic) - for (topicPartition <- partitionsToAdd) { - if (!authorizedTopics.contains(topicPartition.topic)) - unauthorizedTopicErrors += topicPartition -> Errors.TOPIC_AUTHORIZATION_FAILED - else if (!metadataCache.contains(topicPartition)) - nonExistingTopicErrors += topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION - else - authorizedPartitions.add(topicPartition) + val addPartitionsToTxnRequest = + if (request.context.apiVersion() < 4) + request.body[AddPartitionsToTxnRequest].normalizeRequest() + else + request.body[AddPartitionsToTxnRequest] + val version = addPartitionsToTxnRequest.version + val responses = new AddPartitionsToTxnResultCollection() + val partitionsByTransaction = addPartitionsToTxnRequest.partitionsByTransaction() + + // Newer versions of the request should only come from other brokers. + if (version >= 4) authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + + // V4 requests introduced batches of transactions. We need all transactions to be handled before sending the + // response so there are a few differences in handling errors and sending responses. + def createResponse(requestThrottleMs: Int): AbstractResponse = { + if (version < 4) { + // There will only be one response in data. Add it to the response data object. 
+ val data = new AddPartitionsToTxnResponseData() + responses.forEach { result => + data.setResultsByTopicV3AndBelow(result.topicResults()) + data.setThrottleTimeMs(requestThrottleMs) + } + new AddPartitionsToTxnResponse(data) + } else { + new AddPartitionsToTxnResponse(new AddPartitionsToTxnResponseData().setThrottleTimeMs(requestThrottleMs).setResultsByTransaction(responses)) } + } - if (unauthorizedTopicErrors.nonEmpty || nonExistingTopicErrors.nonEmpty) { - // Any failed partition check causes the entire request to fail. We send the appropriate error codes for the - // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error code for the partitions which succeeded - // the authorization check to indicate that they were not added to the transaction. - val partitionErrors = unauthorizedTopicErrors ++ nonExistingTopicErrors ++ - authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED) - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - new AddPartitionsToTxnResponse(requestThrottleMs, partitionErrors.asJava)) + val txns = addPartitionsToTxnRequest.data.transactions + def addResultAndMaybeSendResponse(result: AddPartitionsToTxnResult): Unit = { + val canSend = responses.synchronized { + responses.add(result) + responses.size() == txns.size() + } + if (canSend) { + requestHelper.sendResponseMaybeThrottle(request, createResponse) + } + } + + txns.forEach { transaction => + val transactionalId = transaction.transactionalId + val partitionsToAdd = partitionsByTransaction.get(transactionalId).asScala + + // Versions < 4 come from clients and must be authorized to write for the given transaction and for the given topics. 
+ if (version < 4 && !authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) { + addResultAndMaybeSendResponse(addPartitionsToTxnRequest.errorResponseForTransaction(transactionalId, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED)) } else { - def sendResponseCallback(error: Errors): Unit = { - def createResponse(requestThrottleMs: Int): AbstractResponse = { - val finalError = - if (addPartitionsToTxnRequest.version < 2 && error == Errors.PRODUCER_FENCED) { + val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]() + val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]() + val authorizedPartitions = mutable.Set[TopicPartition]() + + // Only request versions less than 4 need write authorization since they come from clients. + val authorizedTopics = + if (version < 4) + authHelper.filterByAuthorized(request.context, WRITE, TOPIC, partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic) + else + partitionsToAdd.map(_.topic).toSet + for (topicPartition <- partitionsToAdd) { + if (!authorizedTopics.contains(topicPartition.topic)) + unauthorizedTopicErrors += topicPartition -> Errors.TOPIC_AUTHORIZATION_FAILED + else if (!metadataCache.contains(topicPartition)) + nonExistingTopicErrors += topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION + else + authorizedPartitions.add(topicPartition) + } + + if (unauthorizedTopicErrors.nonEmpty || nonExistingTopicErrors.nonEmpty) { + // Any failed partition check causes the entire transaction to fail. We send the appropriate error codes for the + // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error code for the partitions which succeeded + // the authorization check to indicate that they were not added to the transaction. 
+ val partitionErrors = unauthorizedTopicErrors ++ nonExistingTopicErrors ++ + authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED) + addResultAndMaybeSendResponse(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, partitionErrors.asJava)) + } else { + def sendResponseCallback(error: Errors): Unit = { + val finalError = { + if (version < 2 && error == Errors.PRODUCER_FENCED) { // For older clients, they could not understand the new PRODUCER_FENCED error code, // so we need to return the old INVALID_PRODUCER_EPOCH to have the same client handling logic. Errors.INVALID_PRODUCER_EPOCH } else { error } - - val responseBody: AddPartitionsToTxnResponse = new AddPartitionsToTxnResponse(requestThrottleMs, - partitionsToAdd.map{tp => (tp, finalError)}.toMap.asJava) - trace(s"Completed $transactionalId's AddPartitionsToTxnRequest with partitions $partitionsToAdd: errors: $error from client ${request.header.clientId}") - responseBody + } + addResultAndMaybeSendResponse(addPartitionsToTxnRequest.errorResponseForTransaction(transactionalId, finalError)) } - requestHelper.sendResponseMaybeThrottle(request, createResponse) - } - txnCoordinator.handleAddPartitionsToTransaction(transactionalId, - addPartitionsToTxnRequest.data.producerId, - addPartitionsToTxnRequest.data.producerEpoch, - authorizedPartitions, - sendResponseCallback, - requestLocal) + if (!transaction.verifyOnly) { + txnCoordinator.handleAddPartitionsToTransaction(transactionalId, + transaction.producerId, + transaction.producerEpoch, + authorizedPartitions, + sendResponseCallback, + requestLocal) + } else { + txnCoordinator.handleVerifyPartitionsInTransaction(transactionalId, + transaction.producerId, + transaction.producerEpoch, + authorizedPartitions, + addResultAndMaybeSendResponse) + } + } Review Comment: 👍🏻 ########## core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala: ########## @@ -314,6 +316,32 @@ class TransactionCoordinatorTest { 
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) } + @Test + def shouldRespondWithErrorsNoneOnAddPartitionWhenOngoingVerifyOnlyAndPartitionsTheSame(): Unit = { + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, + new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, partitions, 0, 0))))) + + coordinator.handleVerifyPartitionsInTransaction(transactionalId, 0L, 0, partitions, verifyPartitionsInTxnCallback) + assertEquals(Errors.NONE, error) + verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId)) + } + + @Test + def shouldRespondWithInvalidTxnStateWhenVerifyOnlyAndPartitionNotPresent(): Unit = { + when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch, + new TransactionMetadata(transactionalId, 0, 0, 0, RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0))))) + Review Comment: nit: Extra empty line. 
########## clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java: ########## @@ -103,26 +113,82 @@ public AddPartitionsToTxnRequest(final AddPartitionsToTxnRequestData data, short this.data = data; } - public List<TopicPartition> partitions() { - if (cachedPartitions != null) { - return cachedPartitions; - } - cachedPartitions = Builder.getPartitions(data); - return cachedPartitions; - } - @Override public AddPartitionsToTxnRequestData data() { return data; } @Override public AddPartitionsToTxnResponse getErrorResponse(int throttleTimeMs, Throwable e) { - final HashMap<TopicPartition, Errors> errors = new HashMap<>(); - for (TopicPartition partition : partitions()) { - errors.put(partition, Errors.forException(e)); + Errors error = Errors.forException(e); + AddPartitionsToTxnResponseData response = new AddPartitionsToTxnResponseData(); + if (version() < 4) { + response.setResultsByTopicV3AndBelow(errorResponseForTopics(data.v3AndBelowTopics(), error)); + } else { + AddPartitionsToTxnResultCollection results = new AddPartitionsToTxnResultCollection(); + response.setResultsByTransaction(results); Review Comment: nit: I think that you can remove those two lines. The collection should be initialized by default. 
########## core/src/test/scala/unit/kafka/server/KafkaApisTest.scala: ########## @@ -2030,13 +2032,93 @@ class KafkaApisTest { val response = capturedResponse.getValue if (version < 2) { - assertEquals(Collections.singletonMap(topicPartition, Errors.INVALID_PRODUCER_EPOCH), response.errors()) + assertEquals(Collections.singletonMap(topicPartition, Errors.INVALID_PRODUCER_EPOCH), response.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID)) } else { - assertEquals(Collections.singletonMap(topicPartition, Errors.PRODUCER_FENCED), response.errors()) + assertEquals(Collections.singletonMap(topicPartition, Errors.PRODUCER_FENCED), response.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID)) } } } + @Test + def testBatchedRequest(): Unit = { + val topic = "topic" + addTopicToMetadataCache(topic, numPartitions = 2) + + val capturedResponse: ArgumentCaptor[AddPartitionsToTxnResponse] = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnResponse]) + val responseCallback: ArgumentCaptor[Errors => Unit] = ArgumentCaptor.forClass(classOf[Errors => Unit]) + val verifyPartitionsCallback: ArgumentCaptor[AddPartitionsToTxnResult => Unit] = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnResult => Unit]) + + val transactionalId1 = "txnId1" + val transactionalId2 = "txnId2" + val producerId = 15L + val epoch = 0.toShort + + val tp0 = new TopicPartition(topic, 0) + val tp1 = new TopicPartition(topic, 1) + + val addPartitionsToTxnRequest = AddPartitionsToTxnRequest.Builder.forBroker( + new AddPartitionsToTxnTransactionCollection( + List(new AddPartitionsToTxnTransaction() + .setTransactionalId(transactionalId1) + .setProducerId(producerId) + .setProducerEpoch(epoch) + .setVerifyOnly(false) + .setTopics(new AddPartitionsToTxnTopicCollection( + Collections.singletonList(new AddPartitionsToTxnTopic() + .setName(tp0.topic) + .setPartitions(Collections.singletonList(tp0.partition)) + ).iterator()) + ), new AddPartitionsToTxnTransaction() + 
.setTransactionalId(transactionalId2) + .setProducerId(producerId) + .setProducerEpoch(epoch) + .setVerifyOnly(true) + .setTopics(new AddPartitionsToTxnTopicCollection( + Collections.singletonList(new AddPartitionsToTxnTopic() + .setName(tp1.topic) + .setPartitions(Collections.singletonList(tp1.partition)) + ).iterator()) + ) + ).asJava.iterator() + ) + ).build(4.toShort) + val request = buildRequest(addPartitionsToTxnRequest) + + val requestLocal = RequestLocal.withThreadConfinedCaching + when(txnCoordinator.handleAddPartitionsToTransaction( + ArgumentMatchers.eq(transactionalId1), + ArgumentMatchers.eq(producerId), + ArgumentMatchers.eq(epoch), + ArgumentMatchers.eq(Set(tp0)), + responseCallback.capture(), + ArgumentMatchers.eq(requestLocal) + )).thenAnswer(_ => responseCallback.getValue.apply(Errors.NONE)) + + when(txnCoordinator.handleVerifyPartitionsInTransaction( + ArgumentMatchers.eq(transactionalId2), + ArgumentMatchers.eq(producerId), + ArgumentMatchers.eq(epoch), + ArgumentMatchers.eq(Set(tp1)), + verifyPartitionsCallback.capture(), + )).thenAnswer(_ => verifyPartitionsCallback.getValue.apply(AddPartitionsToTxnResponse.resultForTransaction(transactionalId2, Map(tp1 -> Errors.PRODUCER_FENCED).asJava))) + + createKafkaApis().handleAddPartitionsToTxnRequest(request, requestLocal) + + verify(requestChannel).sendResponse( + ArgumentMatchers.eq(request), + capturedResponse.capture(), + ArgumentMatchers.eq(None) + ) + val response = capturedResponse.getValue Review Comment: nit: You can use `verifyNoThrottling`. 
########## clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponse.java: ########## @@ -80,45 +105,44 @@ topicName, new AddPartitionsToTxnPartitionResultCollection() AddPartitionsToTxnTopicResultCollection topicCollection = new AddPartitionsToTxnTopicResultCollection(); for (Map.Entry<String, AddPartitionsToTxnPartitionResultCollection> entry : resultMap.entrySet()) { topicCollection.add(new AddPartitionsToTxnTopicResult() - .setName(entry.getKey()) - .setResults(entry.getValue())); + .setName(entry.getKey()) + .setResultsByPartition(entry.getValue())); } - - this.data = new AddPartitionsToTxnResponseData() - .setThrottleTimeMs(throttleTimeMs) - .setResults(topicCollection); + return topicCollection; } - @Override - public int throttleTimeMs() { - return data.throttleTimeMs(); + public static AddPartitionsToTxnResult resultForTransaction(String transactionalId, Map<TopicPartition, Errors> errors) { + return new AddPartitionsToTxnResult().setTransactionalId(transactionalId).setTopicResults(topicCollectionForErrors(errors)); } - @Override - public void maybeSetThrottleTimeMs(int throttleTimeMs) { - data.setThrottleTimeMs(throttleTimeMs); + public AddPartitionsToTxnTopicResultCollection getTransactionTopicResults(String transactionalId) { + return data.resultsByTransaction().find(transactionalId).topicResults(); } - public Map<TopicPartition, Errors> errors() { - if (cachedErrorsMap != null) { - return cachedErrorsMap; - } - - cachedErrorsMap = new HashMap<>(); - - for (AddPartitionsToTxnTopicResult topicResult : this.data.results()) { - for (AddPartitionsToTxnPartitionResult partitionResult : topicResult.results()) { - cachedErrorsMap.put(new TopicPartition( - topicResult.name(), partitionResult.partitionIndex()), - Errors.forCode(partitionResult.errorCode())); + public static Map<TopicPartition, Errors> errorsForTransaction(AddPartitionsToTxnTopicResultCollection topicCollection) { + Map<TopicPartition, Errors> topicResults = new 
HashMap<>(); + for (AddPartitionsToTxnTopicResult topicResult : topicCollection) { + for (AddPartitionsToTxnPartitionResult partitionResult : topicResult.resultsByPartition()) { + topicResults.put( + new TopicPartition(topicResult.name(), partitionResult.partitionIndex()), Errors.forCode(partitionResult.partitionErrorCode())); } } - return cachedErrorsMap; + return topicResults; } @Override public Map<Errors, Integer> errorCounts() { - return errorCounts(errors().values()); + List<Errors> allErrors = new ArrayList<>(); + + // If we are not using this field, we have request 4 or later + if (this.data.resultsByTopicV3AndBelow().size() == 0) { Review Comment: nit: `isEmpty()`? ########## core/src/test/scala/unit/kafka/server/KafkaApisTest.scala: ########## @@ -2030,13 +2032,93 @@ class KafkaApisTest { val response = capturedResponse.getValue if (version < 2) { - assertEquals(Collections.singletonMap(topicPartition, Errors.INVALID_PRODUCER_EPOCH), response.errors()) + assertEquals(Collections.singletonMap(topicPartition, Errors.INVALID_PRODUCER_EPOCH), response.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID)) } else { - assertEquals(Collections.singletonMap(topicPartition, Errors.PRODUCER_FENCED), response.errors()) + assertEquals(Collections.singletonMap(topicPartition, Errors.PRODUCER_FENCED), response.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID)) } } } + @Test + def testBatchedRequest(): Unit = { Review Comment: nit: Could we include the related api in the name? 
########## core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala: ########## @@ -317,6 +322,34 @@ class TransactionCoordinator(txnConfig: TransactionConfig, } } } + + def handleVerifyPartitionsInTransaction(transactionalId: String, + producerId: Long, + producerEpoch: Short, + partitions: collection.Set[TopicPartition], + responseCallback: VerifyPartitionsCallback): Unit = { + if (transactionalId == null || transactionalId.isEmpty) { + debug(s"Returning ${Errors.INVALID_REQUEST} error code to client for $transactionalId's AddPartitions request") + responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, partitions.map(_ -> Errors.INVALID_REQUEST).toMap.asJava)) + } else { + val result: ApiResult[(Int, TransactionMetadata)] = getTransactionMetadata(transactionalId, producerId, producerEpoch, partitions) + + result match { + case Left(err) => + debug(s"Returning $err error code to client for $transactionalId's AddPartitions request") + responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId, partitions.map(_ -> err).toMap.asJava)) + + case Right((_, txnMetadata)) => + val txnMetadataPartitions = txnMetadata.topicPartitions + val addedPartitions = partitions.intersect(txnMetadataPartitions) + val nonAddedPartitions = partitions.diff(txnMetadataPartitions) + val errors = mutable.Map[TopicPartition, Errors]() + addedPartitions.foreach(errors.put(_, Errors.NONE)) + nonAddedPartitions.foreach(errors.put(_, Errors.INVALID_TXN_STATE)) Review Comment: nit: I am not sure if it makes a real difference, but did you consider doing something like this: ``` partitions.foreach { tp => if (txnMetadata.topicPartitions.contains(tp)) ... else ... } ``` If it works, it would avoid allocating the intermediate collections. I'll leave this up to you.
########## clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java: ########## @@ -2598,12 +2604,37 @@ private OffsetsForLeaderEpochResponse createLeaderEpochResponse() { } private AddPartitionsToTxnRequest createAddPartitionsToTxnRequest(short version) { - return new AddPartitionsToTxnRequest.Builder("tid", 21L, (short) 42, - singletonList(new TopicPartition("topic", 73))).build(version); + if (version < 4) { + return AddPartitionsToTxnRequest.Builder.forClient("tid", 21L, (short) 42, + singletonList(new TopicPartition("topic", 73))).build(version); + } else { + AddPartitionsToTxnTransactionCollection transactions = new AddPartitionsToTxnTransactionCollection( + singletonList(new AddPartitionsToTxnTransaction() Review Comment: nit: Indentation of this line looks weird. Should it be on the previous line? ########## core/src/main/scala/kafka/server/KafkaApis.scala: ########## @@ -2384,66 +2386,111 @@ class KafkaApis(val requestChannel: RequestChannel, if (config.interBrokerProtocolVersion.isLessThan(version)) throw new UnsupportedVersionException(s"inter.broker.protocol.version: ${config.interBrokerProtocolVersion.version} is less than the required version: ${version.version}") } - - def handleAddPartitionToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + def handleAddPartitionsToTxnRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { ensureInterBrokerVersion(IBP_0_11_0_IV0) - val addPartitionsToTxnRequest = request.body[AddPartitionsToTxnRequest] - val transactionalId = addPartitionsToTxnRequest.data.transactionalId - val partitionsToAdd = addPartitionsToTxnRequest.partitions.asScala - if (!authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, transactionalId)) - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - addPartitionsToTxnRequest.getErrorResponse(requestThrottleMs, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception)) - else { - val 
unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]() - val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]() - val authorizedPartitions = mutable.Set[TopicPartition]() - - val authorizedTopics = authHelper.filterByAuthorized(request.context, WRITE, TOPIC, - partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic) - for (topicPartition <- partitionsToAdd) { - if (!authorizedTopics.contains(topicPartition.topic)) - unauthorizedTopicErrors += topicPartition -> Errors.TOPIC_AUTHORIZATION_FAILED - else if (!metadataCache.contains(topicPartition)) - nonExistingTopicErrors += topicPartition -> Errors.UNKNOWN_TOPIC_OR_PARTITION - else - authorizedPartitions.add(topicPartition) + val addPartitionsToTxnRequest = + if (request.context.apiVersion() < 4) + request.body[AddPartitionsToTxnRequest].normalizeRequest() + else + request.body[AddPartitionsToTxnRequest] + val version = addPartitionsToTxnRequest.version + val responses = new AddPartitionsToTxnResultCollection() + val partitionsByTransaction = addPartitionsToTxnRequest.partitionsByTransaction() + + // Newer versions of the request should only come from other brokers. + if (version >= 4) authHelper.authorizeClusterOperation(request, CLUSTER_ACTION) + + // V4 requests introduced batches of transactions. We need all transactions to be handled before sending the + // response so there are a few differences in handling errors and sending responses. + def createResponse(requestThrottleMs: Int): AbstractResponse = { + if (version < 4) { + // There will only be one response in data. Add it to the response data object. 
+ val data = new AddPartitionsToTxnResponseData() + responses.forEach { result => + data.setResultsByTopicV3AndBelow(result.topicResults()) + data.setThrottleTimeMs(requestThrottleMs) + } + new AddPartitionsToTxnResponse(data) + } else { + new AddPartitionsToTxnResponse(new AddPartitionsToTxnResponseData().setThrottleTimeMs(requestThrottleMs).setResultsByTransaction(responses)) } + } - if (unauthorizedTopicErrors.nonEmpty || nonExistingTopicErrors.nonEmpty) { - // Any failed partition check causes the entire request to fail. We send the appropriate error codes for the - // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error code for the partitions which succeeded - // the authorization check to indicate that they were not added to the transaction. - val partitionErrors = unauthorizedTopicErrors ++ nonExistingTopicErrors ++ - authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED) - requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs => - new AddPartitionsToTxnResponse(requestThrottleMs, partitionErrors.asJava)) + val txns = addPartitionsToTxnRequest.data.transactions + def addResultAndMaybeSendResponse(result: AddPartitionsToTxnResult): Unit = { + val canSend = responses.synchronized { + responses.add(result) + responses.size() == txns.size() Review Comment: nit: You can remove the `()`. ########## core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala: ########## @@ -330,44 +363,53 @@ class TransactionCoordinator(txnConfig: TransactionConfig, } else { // try to update the transaction metadata and append the updated metadata to txn log; // if there is no such metadata treat it as invalid producerId mapping error. 
- val result: ApiResult[(Int, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId).flatMap { - case None => Left(Errors.INVALID_PRODUCER_ID_MAPPING) - - case Some(epochAndMetadata) => - val coordinatorEpoch = epochAndMetadata.coordinatorEpoch - val txnMetadata = epochAndMetadata.transactionMetadata - - // generate the new transaction metadata with added partitions - txnMetadata.inLock { - if (txnMetadata.producerId != producerId) { - Left(Errors.INVALID_PRODUCER_ID_MAPPING) - } else if (txnMetadata.producerEpoch != producerEpoch) { - Left(Errors.PRODUCER_FENCED) - } else if (txnMetadata.pendingTransitionInProgress) { - // return a retriable exception to let the client backoff and retry - Left(Errors.CONCURRENT_TRANSACTIONS) - } else if (txnMetadata.state == PrepareCommit || txnMetadata.state == PrepareAbort) { - Left(Errors.CONCURRENT_TRANSACTIONS) - } else if (txnMetadata.state == Ongoing && partitions.subsetOf(txnMetadata.topicPartitions)) { - // this is an optimization: if the partitions are already in the metadata reply OK immediately - Left(Errors.NONE) - } else { - Right(coordinatorEpoch, txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds())) - } - } - } + val result: ApiResult[(Int, TransactionMetadata)] = getTransactionMetadata(transactionalId, producerId, producerEpoch, partitions) result match { case Left(err) => debug(s"Returning $err error code to client for $transactionalId's AddPartitions request") responseCallback(err) - case Right((coordinatorEpoch, newMetadata)) => - txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch, newMetadata, + case Right((coordinatorEpoch, txnMetadata)) => + txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch, txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds()), responseCallback, requestLocal = requestLocal) } } } + + private def getTransactionMetadata(transactionalId: String, + producerId: Long, + producerEpoch: Short, + partitions: 
collection.Set[TopicPartition]): ApiResult[(Int, TransactionMetadata)] = { Review Comment: nit: Indentation is off. ########## core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala: ########## @@ -330,44 +363,53 @@ class TransactionCoordinator(txnConfig: TransactionConfig, } else { // try to update the transaction metadata and append the updated metadata to txn log; // if there is no such metadata treat it as invalid producerId mapping error. - val result: ApiResult[(Int, TxnTransitMetadata)] = txnManager.getTransactionState(transactionalId).flatMap { - case None => Left(Errors.INVALID_PRODUCER_ID_MAPPING) - - case Some(epochAndMetadata) => - val coordinatorEpoch = epochAndMetadata.coordinatorEpoch - val txnMetadata = epochAndMetadata.transactionMetadata - - // generate the new transaction metadata with added partitions - txnMetadata.inLock { - if (txnMetadata.producerId != producerId) { - Left(Errors.INVALID_PRODUCER_ID_MAPPING) - } else if (txnMetadata.producerEpoch != producerEpoch) { - Left(Errors.PRODUCER_FENCED) - } else if (txnMetadata.pendingTransitionInProgress) { - // return a retriable exception to let the client backoff and retry - Left(Errors.CONCURRENT_TRANSACTIONS) - } else if (txnMetadata.state == PrepareCommit || txnMetadata.state == PrepareAbort) { - Left(Errors.CONCURRENT_TRANSACTIONS) - } else if (txnMetadata.state == Ongoing && partitions.subsetOf(txnMetadata.topicPartitions)) { - // this is an optimization: if the partitions are already in the metadata reply OK immediately - Left(Errors.NONE) - } else { - Right(coordinatorEpoch, txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds())) - } - } - } + val result: ApiResult[(Int, TransactionMetadata)] = getTransactionMetadata(transactionalId, producerId, producerEpoch, partitions) result match { case Left(err) => debug(s"Returning $err error code to client for $transactionalId's AddPartitions request") responseCallback(err) - case Right((coordinatorEpoch, 
newMetadata)) => - txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch, newMetadata, + case Right((coordinatorEpoch, txnMetadata)) => + txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch, txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds()), responseCallback, requestLocal = requestLocal) } } } + + private def getTransactionMetadata(transactionalId: String, + producerId: Long, + producerEpoch: Short, + partitions: collection.Set[TopicPartition]): ApiResult[(Int, TransactionMetadata)] = { + // try to update the transaction metadata and append the updated metadata to txn log; + // if there is no such metadata treat it as invalid producerId mapping error. Review Comment: nit: Is this comment relevant here? ########## core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala: ########## @@ -55,22 +65,146 @@ class AddPartitionsToTxnRequestServerTest extends BaseRequestTest { val producerId = 1000L val producerEpoch: Short = 0 - val request = new AddPartitionsToTxnRequest.Builder( - transactionalId, - producerId, - producerEpoch, - List(createdTopicPartition, nonExistentTopic).asJava) - .build() + val request = + if (version < 4) { + AddPartitionsToTxnRequest.Builder.forClient( + transactionalId, + producerId, + producerEpoch, + List(createdTopicPartition, nonExistentTopic).asJava + ).build() Review Comment: Should we set the version here? ########## core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala: ########## @@ -1185,4 +1213,8 @@ class TransactionCoordinatorTest { def errorsCallback(ret: Errors): Unit = { error = ret } + + def verifyPartitionsInTxnCallback(result: AddPartitionsToTxnResult): Unit = { + errors = AddPartitionsToTxnResponse.errorsForTransaction(result.topicResults()).asScala.toMap Review Comment: Relying on a global variable is risky here. It would be much better to define the callback within the test itself. 
-- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org