This is an automated email from the ASF dual-hosted git repository.

jolshan pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new 29a1a16668d KAFKA-14402: Update AddPartitionsToTxn protocol to batch 
and handle verifyOnly requests (#13231)
29a1a16668d is described below

commit 29a1a16668d76a1cc04ec9e39ea13026f2dce1de
Author: Justine Olshan <[email protected]>
AuthorDate: Tue Mar 7 09:20:16 2023 -0800

    KAFKA-14402: Update AddPartitionsToTxn protocol to batch and handle 
verifyOnly requests (#13231)
    
    Part 1 of KIP-890
    
    I've updated the API spec and related classes.
    
    Clients can only send requests up to version 3; this is enforced by the 
client-side builder (Builder.forClient).
    
    Requests of version 4 and above only require cluster permissions, as they 
are initiated by other brokers. API version 4 is marked as unstable for now.
    
    I've added tests for the batched requests and for the verifyOnly mode.
    
    Also -- minor change to the KafkaApis method to properly match the request 
name.
    
    Reviewers: Jason Gustafson <[email protected]>, Jeff Kim 
<[email protected]>, Guozhang Wang <[email protected]>, David Jacot 
<[email protected]>
---
 .../producer/internals/TransactionManager.java     |   4 +-
 .../common/requests/AddPartitionsToTxnRequest.java | 142 ++++++++++++-----
 .../requests/AddPartitionsToTxnResponse.java       |  90 +++++++----
 .../common/message/AddPartitionsToTxnRequest.json  |  37 ++++-
 .../common/message/AddPartitionsToTxnResponse.json |  35 +++--
 .../clients/producer/internals/SenderTest.java     |  26 +--
 .../producer/internals/TransactionManagerTest.java |  38 +++--
 .../apache/kafka/common/message/MessageTest.java   |  31 +++-
 .../requests/AddPartitionsToTxnRequestTest.java    | 128 +++++++++++++--
 .../requests/AddPartitionsToTxnResponseTest.java   |  91 ++++++++---
 .../kafka/common/requests/RequestResponseTest.java |  45 +++++-
 .../transaction/TransactionCoordinator.scala       |  99 ++++++++----
 core/src/main/scala/kafka/server/KafkaApis.scala   | 143 +++++++++++------
 .../kafka/api/AuthorizerIntegrationTest.scala      |   4 +-
 .../TransactionCoordinatorConcurrencyTest.scala    |   3 +
 .../transaction/TransactionCoordinatorTest.scala   |  38 ++++-
 .../AddPartitionsToTxnRequestServerTest.scala      | 174 ++++++++++++++++++---
 .../scala/unit/kafka/server/KafkaApisTest.scala    |  96 ++++++++++--
 .../scala/unit/kafka/server/RequestQuotaTest.scala |   2 +-
 .../test/scala/unit/kafka/utils/TestUtils.scala    |   1 -
 20 files changed, 955 insertions(+), 272 deletions(-)

diff --git 
a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
 
b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
index de5a6ced41c..a41792ada07 100644
--- 
a/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
+++ 
b/clients/src/main/java/org/apache/kafka/clients/producer/internals/TransactionManager.java
@@ -1052,7 +1052,7 @@ public class TransactionManager {
         pendingPartitionsInTransaction.addAll(newPartitionsInTransaction);
         newPartitionsInTransaction.clear();
         AddPartitionsToTxnRequest.Builder builder =
-            new AddPartitionsToTxnRequest.Builder(transactionalId,
+            AddPartitionsToTxnRequest.Builder.forClient(transactionalId,
                 producerIdAndEpoch.producerId,
                 producerIdAndEpoch.epoch,
                 new ArrayList<>(pendingPartitionsInTransaction));
@@ -1328,7 +1328,7 @@ public class TransactionManager {
         @Override
         public void handleResponse(AbstractResponse response) {
             AddPartitionsToTxnResponse addPartitionsToTxnResponse = 
(AddPartitionsToTxnResponse) response;
-            Map<TopicPartition, Errors> errors = 
addPartitionsToTxnResponse.errors();
+            Map<TopicPartition, Errors> errors = 
addPartitionsToTxnResponse.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID);
             boolean hasPartitionErrors = false;
             Set<String> unauthorizedTopics = new HashSet<>();
             retryBackoffMs = TransactionManager.this.retryBackoffMs;
diff --git 
a/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java
 
b/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java
index 1034c0f7adc..c91374fc507 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java
@@ -19,7 +19,15 @@ package org.apache.kafka.common.requests;
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.message.AddPartitionsToTxnRequestData;
 import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTopic;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTransaction;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTransactionCollection;
 import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTopicCollection;
+import org.apache.kafka.common.message.AddPartitionsToTxnResponseData;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnPartitionResult;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnPartitionResultCollection;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnResult;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnTopicResult;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnTopicResultCollection;
 import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.ByteBufferAccessor;
 import org.apache.kafka.common.protocol.Errors;
@@ -34,22 +42,37 @@ public class AddPartitionsToTxnRequest extends 
AbstractRequest {
 
     private final AddPartitionsToTxnRequestData data;
 
-    private List<TopicPartition> cachedPartitions = null;
-
     public static class Builder extends 
AbstractRequest.Builder<AddPartitionsToTxnRequest> {
         public final AddPartitionsToTxnRequestData data;
+        
+        public static Builder forClient(String transactionalId,
+                                        long producerId,
+                                        short producerEpoch,
+                                        List<TopicPartition> partitions) {
+
+            AddPartitionsToTxnTopicCollection topics = 
buildTxnTopicCollection(partitions);
+            
+            return new Builder(ApiKeys.ADD_PARTITIONS_TO_TXN.oldestVersion(), 
(short) 3,
+                new AddPartitionsToTxnRequestData()
+                    .setV3AndBelowTransactionalId(transactionalId)
+                    .setV3AndBelowProducerId(producerId)
+                    .setV3AndBelowProducerEpoch(producerEpoch)
+                    .setV3AndBelowTopics(topics));
+        }
+        
+        public static Builder 
forBroker(AddPartitionsToTxnTransactionCollection transactions) {
+            return new Builder((short) 4, 
ApiKeys.ADD_PARTITIONS_TO_TXN.latestVersion(),
+                new AddPartitionsToTxnRequestData()
+                    .setTransactions(transactions));
+        }
+        
+        private Builder(short minVersion, short maxVersion, 
AddPartitionsToTxnRequestData data) {
+            super(ApiKeys.ADD_PARTITIONS_TO_TXN, minVersion, maxVersion);
 
-        public Builder(final AddPartitionsToTxnRequestData data) {
-            super(ApiKeys.ADD_PARTITIONS_TO_TXN);
             this.data = data;
         }
 
-        public Builder(final String transactionalId,
-                       final long producerId,
-                       final short producerEpoch,
-                       final List<TopicPartition> partitions) {
-            super(ApiKeys.ADD_PARTITIONS_TO_TXN);
-
+        private static AddPartitionsToTxnTopicCollection 
buildTxnTopicCollection(final List<TopicPartition> partitions) {
             Map<String, List<Integer>> partitionMap = new HashMap<>();
             for (TopicPartition topicPartition : partitions) {
                 String topicName = topicPartition.topic();
@@ -66,15 +89,10 @@ public class AddPartitionsToTxnRequest extends 
AbstractRequest {
             AddPartitionsToTxnTopicCollection topics = new 
AddPartitionsToTxnTopicCollection();
             for (Map.Entry<String, List<Integer>> partitionEntry : 
partitionMap.entrySet()) {
                 topics.add(new AddPartitionsToTxnTopic()
-                               .setName(partitionEntry.getKey())
-                               .setPartitions(partitionEntry.getValue()));
+                    .setName(partitionEntry.getKey())
+                    .setPartitions(partitionEntry.getValue()));
             }
-
-            this.data = new AddPartitionsToTxnRequestData()
-                            .setTransactionalId(transactionalId)
-                            .setProducerId(producerId)
-                            .setProducerEpoch(producerEpoch)
-                            .setTopics(topics);
+            return topics;
         }
 
         @Override
@@ -82,16 +100,6 @@ public class AddPartitionsToTxnRequest extends 
AbstractRequest {
             return new AddPartitionsToTxnRequest(data, version);
         }
 
-        static List<TopicPartition> 
getPartitions(AddPartitionsToTxnRequestData data) {
-            List<TopicPartition> partitions = new ArrayList<>();
-            for (AddPartitionsToTxnTopic topicCollection : data.topics()) {
-                for (Integer partition : topicCollection.partitions()) {
-                    partitions.add(new TopicPartition(topicCollection.name(), 
partition));
-                }
-            }
-            return partitions;
-        }
-
         @Override
         public String toString() {
             return data.toString();
@@ -103,14 +111,6 @@ public class AddPartitionsToTxnRequest extends 
AbstractRequest {
         this.data = data;
     }
 
-    public List<TopicPartition> partitions() {
-        if (cachedPartitions != null) {
-            return cachedPartitions;
-        }
-        cachedPartitions = Builder.getPartitions(data);
-        return cachedPartitions;
-    }
-
     @Override
     public AddPartitionsToTxnRequestData data() {
         return data;
@@ -118,11 +118,73 @@ public class AddPartitionsToTxnRequest extends 
AbstractRequest {
 
     @Override
     public AddPartitionsToTxnResponse getErrorResponse(int throttleTimeMs, 
Throwable e) {
-        final HashMap<TopicPartition, Errors> errors = new HashMap<>();
-        for (TopicPartition partition : partitions()) {
-            errors.put(partition, Errors.forException(e));
+        Errors error = Errors.forException(e);
+        AddPartitionsToTxnResponseData response = new 
AddPartitionsToTxnResponseData();
+        if (version() < 4) {
+            
response.setResultsByTopicV3AndBelow(errorResponseForTopics(data.v3AndBelowTopics(),
 error));
+        } else {
+            response.setErrorCode(error.code());
+        }
+        response.setThrottleTimeMs(throttleTimeMs);
+        return new AddPartitionsToTxnResponse(response);
+    }
+
+    public static List<TopicPartition> 
getPartitions(AddPartitionsToTxnTopicCollection topics) {
+        List<TopicPartition> partitions = new ArrayList<>();
+
+        for (AddPartitionsToTxnTopic topicCollection : topics) {
+            for (Integer partition : topicCollection.partitions()) {
+                partitions.add(new TopicPartition(topicCollection.name(), 
partition));
+            }
+        }
+        return partitions;
+    }
+
+    public Map<String, List<TopicPartition>> partitionsByTransaction() {
+        Map<String, List<TopicPartition>> partitionsByTransaction = new 
HashMap<>();
+        for (AddPartitionsToTxnTransaction transaction : data.transactions()) {
+            List<TopicPartition> partitions = 
getPartitions(transaction.topics());
+            partitionsByTransaction.put(transaction.transactionalId(), 
partitions);
+        }
+        return partitionsByTransaction;
+    }
+
+    // Takes a version 3 or below request and returns a v4+ singleton (one 
transaction ID) request.
+    public AddPartitionsToTxnRequest normalizeRequest() {
+        return new AddPartitionsToTxnRequest(new 
AddPartitionsToTxnRequestData().setTransactions(singletonTransaction()), 
version());
+    }
+
+    private AddPartitionsToTxnTransactionCollection singletonTransaction() {
+        AddPartitionsToTxnTransactionCollection singleTxn = new 
AddPartitionsToTxnTransactionCollection();
+        singleTxn.add(new AddPartitionsToTxnTransaction()
+            .setTransactionalId(data.v3AndBelowTransactionalId())
+            .setProducerId(data.v3AndBelowProducerId())
+            .setProducerEpoch(data.v3AndBelowProducerEpoch())
+            .setTopics(data.v3AndBelowTopics()));
+        return singleTxn;
+    }
+    
+    public AddPartitionsToTxnResult errorResponseForTransaction(String 
transactionalId, Errors e) {
+        AddPartitionsToTxnResult txnResult = new 
AddPartitionsToTxnResult().setTransactionalId(transactionalId);
+        AddPartitionsToTxnTopicResultCollection topicResults = 
errorResponseForTopics(data.transactions().find(transactionalId).topics(), e);
+        txnResult.setTopicResults(topicResults);
+        return txnResult;
+    }
+    
+    private AddPartitionsToTxnTopicResultCollection 
errorResponseForTopics(AddPartitionsToTxnTopicCollection topics, Errors e) {
+        AddPartitionsToTxnTopicResultCollection topicResults = new 
AddPartitionsToTxnTopicResultCollection();
+        for (AddPartitionsToTxnTopic topic : topics) {
+            AddPartitionsToTxnTopicResult topicResult = new 
AddPartitionsToTxnTopicResult().setName(topic.name());
+            AddPartitionsToTxnPartitionResultCollection partitionResult = new 
AddPartitionsToTxnPartitionResultCollection();
+            for (Integer partition : topic.partitions()) {
+                partitionResult.add(new AddPartitionsToTxnPartitionResult()
+                    .setPartitionIndex(partition)
+                    .setPartitionErrorCode(e.code()));
+            }
+            topicResult.setResultsByPartition(partitionResult);
+            topicResults.add(topicResult);
         }
-        return new AddPartitionsToTxnResponse(throttleTimeMs, errors);
+        return topicResults;
     }
 
     public static AddPartitionsToTxnRequest parse(ByteBuffer buffer, short 
version) {
diff --git 
a/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponse.java
 
b/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponse.java
index 8038f4b8fc6..645a03038a8 100644
--- 
a/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponse.java
+++ 
b/clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponse.java
@@ -18,6 +18,7 @@ package org.apache.kafka.common.requests;
 
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.message.AddPartitionsToTxnResponseData;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnResult;
 import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnPartitionResult;
 import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnPartitionResultCollection;
 import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnTopicResult;
@@ -27,7 +28,9 @@ import org.apache.kafka.common.protocol.ByteBufferAccessor;
 import org.apache.kafka.common.protocol.Errors;
 
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 /**
@@ -48,16 +51,38 @@ public class AddPartitionsToTxnResponse extends 
AbstractResponse {
 
     private final AddPartitionsToTxnResponseData data;
 
-    private Map<TopicPartition, Errors> cachedErrorsMap = null;
+    public static final String V3_AND_BELOW_TXN_ID = "";
 
     public AddPartitionsToTxnResponse(AddPartitionsToTxnResponseData data) {
         super(ApiKeys.ADD_PARTITIONS_TO_TXN);
         this.data = data;
     }
 
-    public AddPartitionsToTxnResponse(int throttleTimeMs, Map<TopicPartition, 
Errors> errors) {
-        super(ApiKeys.ADD_PARTITIONS_TO_TXN);
+    @Override
+    public int throttleTimeMs() {
+        return data.throttleTimeMs();
+    }
+
+    @Override
+    public void maybeSetThrottleTimeMs(int throttleTimeMs) {
+        data.setThrottleTimeMs(throttleTimeMs);
+    }
+
+    public Map<String, Map<TopicPartition, Errors>> errors() {
+        Map<String, Map<TopicPartition, Errors>> errorsMap = new HashMap<>();
+
+        if (!this.data.resultsByTopicV3AndBelow().isEmpty()) {
+            errorsMap.put(V3_AND_BELOW_TXN_ID, 
errorsForTransaction(this.data.resultsByTopicV3AndBelow()));
+        }
+
+        for (AddPartitionsToTxnResult result : 
this.data.resultsByTransaction()) {
+            errorsMap.put(result.transactionalId(), 
errorsForTransaction(result.topicResults()));
+        }
+        
+        return errorsMap;
+    }
 
+    private static AddPartitionsToTxnTopicResultCollection 
topicCollectionForErrors(Map<TopicPartition, Errors> errors) {
         Map<String, AddPartitionsToTxnPartitionResultCollection> resultMap = 
new HashMap<>();
 
         for (Map.Entry<TopicPartition, Errors> entry : errors.entrySet()) {
@@ -65,12 +90,12 @@ public class AddPartitionsToTxnResponse extends 
AbstractResponse {
             String topicName = topicPartition.topic();
 
             AddPartitionsToTxnPartitionResult partitionResult =
-                new AddPartitionsToTxnPartitionResult()
-                    .setErrorCode(entry.getValue().code())
-                    .setPartitionIndex(topicPartition.partition());
+                    new AddPartitionsToTxnPartitionResult()
+                        .setPartitionErrorCode(entry.getValue().code())
+                        .setPartitionIndex(topicPartition.partition());
 
             AddPartitionsToTxnPartitionResultCollection 
partitionResultCollection = resultMap.getOrDefault(
-                topicName, new AddPartitionsToTxnPartitionResultCollection()
+                    topicName, new 
AddPartitionsToTxnPartitionResultCollection()
             );
 
             partitionResultCollection.add(partitionResult);
@@ -80,45 +105,44 @@ public class AddPartitionsToTxnResponse extends 
AbstractResponse {
         AddPartitionsToTxnTopicResultCollection topicCollection = new 
AddPartitionsToTxnTopicResultCollection();
         for (Map.Entry<String, AddPartitionsToTxnPartitionResultCollection> 
entry : resultMap.entrySet()) {
             topicCollection.add(new AddPartitionsToTxnTopicResult()
-                                    .setName(entry.getKey())
-                                    .setResults(entry.getValue()));
+                .setName(entry.getKey())
+                .setResultsByPartition(entry.getValue()));
         }
-
-        this.data = new AddPartitionsToTxnResponseData()
-                        .setThrottleTimeMs(throttleTimeMs)
-                        .setResults(topicCollection);
+        return topicCollection;
     }
 
-    @Override
-    public int throttleTimeMs() {
-        return data.throttleTimeMs();
+    public static AddPartitionsToTxnResult resultForTransaction(String 
transactionalId, Map<TopicPartition, Errors> errors) {
+        return new 
AddPartitionsToTxnResult().setTransactionalId(transactionalId).setTopicResults(topicCollectionForErrors(errors));
     }
 
-    @Override
-    public void maybeSetThrottleTimeMs(int throttleTimeMs) {
-        data.setThrottleTimeMs(throttleTimeMs);
+    public AddPartitionsToTxnTopicResultCollection 
getTransactionTopicResults(String transactionalId) {
+        return 
data.resultsByTransaction().find(transactionalId).topicResults();
     }
 
-    public Map<TopicPartition, Errors> errors() {
-        if (cachedErrorsMap != null) {
-            return cachedErrorsMap;
-        }
-
-        cachedErrorsMap = new HashMap<>();
-
-        for (AddPartitionsToTxnTopicResult topicResult : this.data.results()) {
-            for (AddPartitionsToTxnPartitionResult partitionResult : 
topicResult.results()) {
-                cachedErrorsMap.put(new TopicPartition(
-                        topicResult.name(), partitionResult.partitionIndex()),
-                    Errors.forCode(partitionResult.errorCode()));
+    public static Map<TopicPartition, Errors> 
errorsForTransaction(AddPartitionsToTxnTopicResultCollection topicCollection) {
+        Map<TopicPartition, Errors> topicResults = new HashMap<>();
+        for (AddPartitionsToTxnTopicResult topicResult : topicCollection) {
+            for (AddPartitionsToTxnPartitionResult partitionResult : 
topicResult.resultsByPartition()) {
+                topicResults.put(
+                    new TopicPartition(topicResult.name(), 
partitionResult.partitionIndex()), 
Errors.forCode(partitionResult.partitionErrorCode()));
             }
         }
-        return cachedErrorsMap;
+        return topicResults;
     }
 
     @Override
     public Map<Errors, Integer> errorCounts() {
-        return errorCounts(errors().values());
+        List<Errors> allErrors = new ArrayList<>();
+
+        // If we are not using this field, we have request 4 or later
+        if (this.data.resultsByTopicV3AndBelow().isEmpty()) {
+            allErrors.add(Errors.forCode(data.errorCode()));
+        }
+        
+        errors().forEach((txnId, errors) -> 
+            allErrors.addAll(errors.values())
+        );
+        return errorCounts(allErrors);
     }
 
     @Override
diff --git 
a/clients/src/main/resources/common/message/AddPartitionsToTxnRequest.json 
b/clients/src/main/resources/common/message/AddPartitionsToTxnRequest.json
index 4920da176c7..32bb9b8d1f7 100644
--- a/clients/src/main/resources/common/message/AddPartitionsToTxnRequest.json
+++ b/clients/src/main/resources/common/message/AddPartitionsToTxnRequest.json
@@ -23,17 +23,40 @@
   // Version 2 adds the support for new error code PRODUCER_FENCED.
   //
   // Version 3 enables flexible versions.
-  "validVersions": "0-3",
+  //
+  // Version 4 adds VerifyOnly field to check if partitions are already in 
transaction and adds support to batch multiple transactions.
+  // Versions 3 and below will be exclusively used by clients and versions 4 
and above will be used by brokers.
+  // The AddPartitionsToTxnRequest version 4 API is added as part of KIP-890 
and is still
+  // under development. Hence, the API is not exposed by default by brokers
+  // unless explicitly enabled.
+  "latestVersionUnstable": true,
+  "validVersions": "0-4",
   "flexibleVersions": "3+",
   "fields": [
-    { "name": "TransactionalId", "type": "string", "versions": "0+", 
"entityType": "transactionalId",
-      "about": "The transactional id corresponding to the transaction."},
-    { "name": "ProducerId", "type": "int64", "versions": "0+", "entityType": 
"producerId",
+    { "name": "Transactions", "type": "[]AddPartitionsToTxnTransaction", 
"versions":  "4+",
+      "about": "List of transactions to add partitions to.", "fields": [
+      { "name": "TransactionalId", "type": "string", "versions": "4+", 
"mapKey": true, "entityType": "transactionalId",
+        "about": "The transactional id corresponding to the transaction." },
+      { "name": "ProducerId", "type": "int64", "versions": "4+", "entityType": 
"producerId",
+        "about": "Current producer id in use by the transactional id." },
+      { "name": "ProducerEpoch", "type": "int16", "versions": "4+",
+        "about": "Current epoch associated with the producer id." },
+      { "name": "VerifyOnly", "type": "bool", "versions": "4+", "default": 
false,
+        "about": "Boolean to signify if we want to check if the partition is 
in the transaction rather than add it." },
+      { "name": "Topics", "type": "[]AddPartitionsToTxnTopic", "versions": 
"4+",
+        "about": "The partitions to add to the transaction." }
+    ]},
+    { "name": "V3AndBelowTransactionalId", "type": "string", "versions": 
"0-3", "entityType": "transactionalId",
+      "about": "The transactional id corresponding to the transaction." },
+    { "name": "V3AndBelowProducerId", "type": "int64", "versions": "0-3", 
"entityType": "producerId",
       "about": "Current producer id in use by the transactional id." },
-    { "name": "ProducerEpoch", "type": "int16", "versions": "0+",
+    { "name": "V3AndBelowProducerEpoch", "type": "int16", "versions": "0-3",
       "about": "Current epoch associated with the producer id." },
-    { "name": "Topics", "type": "[]AddPartitionsToTxnTopic", "versions": "0+",
-      "about": "The partitions to add to the transaction.", "fields": [
+    { "name": "V3AndBelowTopics", "type": "[]AddPartitionsToTxnTopic", 
"versions": "0-3",
+      "about": "The partitions to add to the transaction." }
+  ],
+  "commonStructs": [
+    { "name": "AddPartitionsToTxnTopic", "versions": "0+", "fields": [
       { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, 
"entityType": "topicName",
         "about": "The name of the topic." },
       { "name": "Partitions", "type": "[]int32", "versions": "0+",
diff --git 
a/clients/src/main/resources/common/message/AddPartitionsToTxnResponse.json 
b/clients/src/main/resources/common/message/AddPartitionsToTxnResponse.json
index 4241dc77b4a..326b4acdb44 100644
--- a/clients/src/main/resources/common/message/AddPartitionsToTxnResponse.json
+++ b/clients/src/main/resources/common/message/AddPartitionsToTxnResponse.json
@@ -22,22 +22,37 @@
   // Version 2 adds the support for new error code PRODUCER_FENCED.
   //
   // Version 3 enables flexible versions.
-  "validVersions": "0-3",
+  //
+  // Version 4 adds support to batch multiple transactions and a top level 
error code.
+  "validVersions": "0-4",
   "flexibleVersions": "3+",
   "fields": [
     { "name": "ThrottleTimeMs", "type": "int32", "versions": "0+",
       "about": "Duration in milliseconds for which the request was throttled 
due to a quota violation, or zero if the request did not violate any quota." },
-    { "name": "Results", "type": "[]AddPartitionsToTxnTopicResult", 
"versions": "0+",
-      "about": "The results for each topic.", "fields": [
+    { "name": "ErrorCode", "type": "int16", "versions": "4+", "ignorable": 
true,
+      "about": "The response top level error code." },
+    { "name": "ResultsByTransaction", "type": "[]AddPartitionsToTxnResult", 
"versions": "4+",
+      "about": "Results categorized by transactional ID.", "fields": [
+      { "name": "TransactionalId", "type": "string", "versions": "4+", 
"mapKey": true, "entityType": "transactionalId",
+        "about": "The transactional id corresponding to the transaction." },
+      { "name": "TopicResults", "type": "[]AddPartitionsToTxnTopicResult", 
"versions": "4+",
+        "about": "The results for each topic." }
+    ]},
+    { "name": "ResultsByTopicV3AndBelow", "type": 
"[]AddPartitionsToTxnTopicResult", "versions": "0-3",
+      "about": "The results for each topic." }
+  ],
+  "commonStructs": [
+    { "name": "AddPartitionsToTxnTopicResult", "versions": "0+", "fields": [
       { "name": "Name", "type": "string", "versions": "0+", "mapKey": true, 
"entityType": "topicName",
         "about": "The topic name." },
-      { "name": "Results", "type": "[]AddPartitionsToTxnPartitionResult", 
"versions": "0+", 
-        "about": "The results for each partition", "fields": [
-        { "name": "PartitionIndex", "type": "int32", "versions": "0+", 
"mapKey": true,
-          "about": "The partition indexes." },
-        { "name": "ErrorCode", "type": "int16", "versions": "0+",
-          "about": "The response error code."}
-      ]}
+      { "name": "ResultsByPartition", "type": 
"[]AddPartitionsToTxnPartitionResult", "versions": "0+",
+        "about": "The results for each partition" }
+    ]},
+    { "name": "AddPartitionsToTxnPartitionResult", "versions": "0+", "fields": 
[
+      { "name": "PartitionIndex", "type": "int32", "versions": "0+", "mapKey": 
true,
+        "about": "The partition indexes." },
+      { "name": "PartitionErrorCode", "type": "int16", "versions": "0+",
+        "about": "The response error code." }
     ]}
   ]
 }
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
index bdbc1bd92e9..adee050da42 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/producer/internals/SenderTest.java
@@ -40,6 +40,7 @@ import 
org.apache.kafka.common.errors.TransactionAbortedException;
 import org.apache.kafka.common.errors.UnsupportedForMessageFormatException;
 import org.apache.kafka.common.errors.UnsupportedVersionException;
 import org.apache.kafka.common.internals.ClusterResourceListeners;
+import org.apache.kafka.common.message.AddPartitionsToTxnResponseData;
 import org.apache.kafka.common.message.ApiMessageType;
 import org.apache.kafka.common.message.EndTxnResponseData;
 import org.apache.kafka.common.message.InitProducerIdResponseData;
@@ -1542,7 +1543,7 @@ public class SenderTest {
 
         txnManager.beginTransaction();
         txnManager.maybeAddPartition(tp0);
-        client.prepareResponse(new AddPartitionsToTxnResponse(0, 
Collections.singletonMap(tp0, Errors.NONE)));
+        client.prepareResponse(buildAddPartitionsToTxnResponseData(0, 
Collections.singletonMap(tp0, Errors.NONE)));
         sender.runOnce();
 
         // Send first ProduceRequest
@@ -1828,7 +1829,7 @@ public class SenderTest {
 
         transactionManager.beginTransaction();
         transactionManager.maybeAddPartition(tp0);
-        client.prepareResponse(new AddPartitionsToTxnResponse(0, 
Collections.singletonMap(tp0, Errors.NONE)));
+        client.prepareResponse(buildAddPartitionsToTxnResponseData(0, 
Collections.singletonMap(tp0, Errors.NONE)));
         sender.runOnce(); // Receive AddPartitions response
 
         assertEquals(0, transactionManager.sequenceNumber(tp0).longValue());
@@ -2384,7 +2385,7 @@ public class SenderTest {
 
         txnManager.beginTransaction();
         txnManager.maybeAddPartition(tp);
-        client.prepareResponse(new AddPartitionsToTxnResponse(0, 
Collections.singletonMap(tp, Errors.NONE)));
+        client.prepareResponse(buildAddPartitionsToTxnResponseData(0, 
Collections.singletonMap(tp, Errors.NONE)));
         sender.runOnce();
 
         testSplitBatchAndSend(txnManager, producerIdAndEpoch, tp);
@@ -2731,7 +2732,7 @@ public class SenderTest {
 
             txnManager.beginTransaction();
             txnManager.maybeAddPartition(tp);
-            client.prepareResponse(new AddPartitionsToTxnResponse(0, 
Collections.singletonMap(tp, Errors.NONE)));
+            client.prepareResponse(buildAddPartitionsToTxnResponseData(0, 
Collections.singletonMap(tp, Errors.NONE)));
             sender.runOnce();
             sender.initiateClose();
             txnManager.beginCommit();
@@ -2851,7 +2852,7 @@ public class SenderTest {
 
     private void addPartitionToTxn(Sender sender, TransactionManager 
txnManager, TopicPartition tp) {
         txnManager.maybeAddPartition(tp);
-        client.prepareResponse(new AddPartitionsToTxnResponse(0, 
Collections.singletonMap(tp, Errors.NONE)));
+        client.prepareResponse(buildAddPartitionsToTxnResponseData(0, 
Collections.singletonMap(tp, Errors.NONE)));
         runUntil(sender, () -> txnManager.isPartitionAdded(tp));
         assertFalse(txnManager.hasInFlightRequest());
     }
@@ -2892,7 +2893,7 @@ public class SenderTest {
 
             txnManager.beginTransaction();
             txnManager.maybeAddPartition(tp);
-            client.prepareResponse(new AddPartitionsToTxnResponse(0, 
Collections.singletonMap(tp, Errors.NONE)));
+            client.prepareResponse(buildAddPartitionsToTxnResponseData(0, 
Collections.singletonMap(tp, Errors.NONE)));
             sender.runOnce();
             sender.initiateClose();
             AssertEndTxnRequestMatcher endTxnMatcher = new 
AssertEndTxnRequestMatcher(TransactionResult.ABORT);
@@ -2926,7 +2927,7 @@ public class SenderTest {
 
             txnManager.beginTransaction();
             txnManager.maybeAddPartition(tp);
-            client.prepareResponse(new AddPartitionsToTxnResponse(0, 
Collections.singletonMap(tp, Errors.NONE)));
+            client.prepareResponse(buildAddPartitionsToTxnResponseData(0, 
Collections.singletonMap(tp, Errors.NONE)));
             sender.runOnce();
 
             // Try to commit the transaction but it won't happen as we'll 
forcefully close the sender
@@ -2951,7 +2952,7 @@ public class SenderTest {
         // Begin the transaction
         txnManager.beginTransaction();
         txnManager.maybeAddPartition(tp0);
-        client.prepareResponse(new AddPartitionsToTxnResponse(0, 
Collections.singletonMap(tp0, Errors.NONE)));
+        client.prepareResponse(buildAddPartitionsToTxnResponseData(0, 
Collections.singletonMap(tp0, Errors.NONE)));
         // Run it once so that the partition is added to the transaction.
         sender.runOnce();
         // Append a record to the accumulator.
@@ -2989,7 +2990,7 @@ public class SenderTest {
 
         txnManager.beginTransaction();
         txnManager.maybeAddPartition(tp0);
-        client.prepareResponse(new AddPartitionsToTxnResponse(0, 
Collections.singletonMap(tp0, Errors.NONE)));
+        client.prepareResponse(buildAddPartitionsToTxnResponseData(0, 
Collections.singletonMap(tp0, Errors.NONE)));
         sender.runOnce();
 
         // create a producer batch with more than one record so it is eligible 
for splitting
@@ -3338,4 +3339,11 @@ public class SenderTest {
         assertTrue(transactionManager.hasProducerId());
         assertEquals(producerIdAndEpoch, 
transactionManager.producerIdAndEpoch());
     }
+    
+    private AddPartitionsToTxnResponse buildAddPartitionsToTxnResponseData(int 
throttleMs, Map<TopicPartition, Errors> errors) {
+        AddPartitionsToTxnResponseData.AddPartitionsToTxnResult result = 
AddPartitionsToTxnResponse.resultForTransaction(
+                AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID, errors);
+        AddPartitionsToTxnResponseData data = new 
AddPartitionsToTxnResponseData().setResultsByTopicV3AndBelow(result.topicResults()).setThrottleTimeMs(throttleMs);
+        return new AddPartitionsToTxnResponse(data);
+    }
 }
diff --git 
a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
 
b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
index ce9b8052207..06bad272205 100644
--- 
a/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/clients/producer/internals/TransactionManagerTest.java
@@ -40,6 +40,8 @@ import 
org.apache.kafka.common.errors.UnsupportedVersionException;
 import org.apache.kafka.common.header.Header;
 import org.apache.kafka.common.internals.ClusterResourceListeners;
 import org.apache.kafka.common.message.AddOffsetsToTxnResponseData;
+import org.apache.kafka.common.message.AddPartitionsToTxnResponseData;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnResult;
 import org.apache.kafka.common.message.ApiVersionsResponseData.ApiVersion;
 import org.apache.kafka.common.message.EndTxnResponseData;
 import org.apache.kafka.common.message.InitProducerIdResponseData;
@@ -1303,11 +1305,13 @@ public class TransactionManagerTest {
         Map<TopicPartition, Errors> errors = new HashMap<>();
         errors.put(tp0, Errors.TOPIC_AUTHORIZATION_FAILED);
         errors.put(tp1, Errors.OPERATION_NOT_ATTEMPTED);
+        AddPartitionsToTxnResult result = 
AddPartitionsToTxnResponse.resultForTransaction(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID,
 errors);
+        AddPartitionsToTxnResponseData data = new 
AddPartitionsToTxnResponseData().setResultsByTopicV3AndBelow(result.topicResults()).setThrottleTimeMs(0);
         client.respond(body -> {
             AddPartitionsToTxnRequest request = (AddPartitionsToTxnRequest) 
body;
-            assertEquals(new HashSet<>(request.partitions()), new 
HashSet<>(errors.keySet()));
+            assertEquals(new HashSet<>(getPartitionsFromV3Request(request)), 
new HashSet<>(errors.keySet()));
             return true;
-        }, new AddPartitionsToTxnResponse(0, errors));
+        }, new AddPartitionsToTxnResponse(data));
 
         sender.runOnce();
         assertTrue(transactionManager.hasError());
@@ -3439,11 +3443,13 @@ public class TransactionManagerTest {
     }
 
     private void prepareAddPartitionsToTxn(final Map<TopicPartition, Errors> 
errors) {
+        AddPartitionsToTxnResult result = 
AddPartitionsToTxnResponse.resultForTransaction(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID,
 errors);
+        AddPartitionsToTxnResponseData data = new 
AddPartitionsToTxnResponseData().setResultsByTopicV3AndBelow(result.topicResults()).setThrottleTimeMs(0);
         client.prepareResponse(body -> {
             AddPartitionsToTxnRequest request = (AddPartitionsToTxnRequest) 
body;
-            assertEquals(new HashSet<>(request.partitions()), new 
HashSet<>(errors.keySet()));
+            assertEquals(new HashSet<>(getPartitionsFromV3Request(request)), 
new HashSet<>(errors.keySet()));
             return true;
-        }, new AddPartitionsToTxnResponse(0, errors));
+        }, new AddPartitionsToTxnResponse(data));
     }
 
     private void prepareAddPartitionsToTxn(final TopicPartition tp, final 
Errors error) {
@@ -3522,27 +3528,39 @@ public class TransactionManagerTest {
 
     private void prepareAddPartitionsToTxnResponse(Errors error, final 
TopicPartition topicPartition,
                                                    final short epoch, final 
long producerId) {
+        AddPartitionsToTxnResult result = 
AddPartitionsToTxnResponse.resultForTransaction(
+                AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID, 
singletonMap(topicPartition, error));
         client.prepareResponse(addPartitionsRequestMatcher(topicPartition, 
epoch, producerId),
-                new AddPartitionsToTxnResponse(0, singletonMap(topicPartition, 
error)));
+                new AddPartitionsToTxnResponse(new 
AddPartitionsToTxnResponseData()
+                        .setThrottleTimeMs(0)
+                        .setResultsByTopicV3AndBelow(result.topicResults())));
     }
 
     private void sendAddPartitionsToTxnResponse(Errors error, final 
TopicPartition topicPartition,
                                                 final short epoch, final long 
producerId) {
+        AddPartitionsToTxnResult result = 
AddPartitionsToTxnResponse.resultForTransaction(
+                AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID, 
singletonMap(topicPartition, error));
         client.respond(addPartitionsRequestMatcher(topicPartition, epoch, 
producerId),
-                new AddPartitionsToTxnResponse(0, singletonMap(topicPartition, 
error)));
+                new AddPartitionsToTxnResponse(new 
AddPartitionsToTxnResponseData()
+                        .setThrottleTimeMs(0)
+                        .setResultsByTopicV3AndBelow(result.topicResults())));
     }
 
     private MockClient.RequestMatcher addPartitionsRequestMatcher(final 
TopicPartition topicPartition,
                                                                   final short 
epoch, final long producerId) {
         return body -> {
             AddPartitionsToTxnRequest addPartitionsToTxnRequest = 
(AddPartitionsToTxnRequest) body;
-            assertEquals(producerId, 
addPartitionsToTxnRequest.data().producerId());
-            assertEquals(epoch, 
addPartitionsToTxnRequest.data().producerEpoch());
-            assertEquals(singletonList(topicPartition), 
addPartitionsToTxnRequest.partitions());
-            assertEquals(transactionalId, 
addPartitionsToTxnRequest.data().transactionalId());
+            assertEquals(producerId, 
addPartitionsToTxnRequest.data().v3AndBelowProducerId());
+            assertEquals(epoch, 
addPartitionsToTxnRequest.data().v3AndBelowProducerEpoch());
+            assertEquals(singletonList(topicPartition), 
getPartitionsFromV3Request(addPartitionsToTxnRequest));
+            assertEquals(transactionalId, 
addPartitionsToTxnRequest.data().v3AndBelowTransactionalId());
             return true;
         };
     }
+    
+    private List<TopicPartition> 
getPartitionsFromV3Request(AddPartitionsToTxnRequest request) {
+        return 
AddPartitionsToTxnRequest.getPartitions(request.data().v3AndBelowTopics());
+    }
 
     private void prepareEndTxnResponse(Errors error, final TransactionResult 
result, final long producerId, final short epoch) {
         this.prepareEndTxnResponse(error, result, producerId, epoch, false);
diff --git 
a/clients/src/test/java/org/apache/kafka/common/message/MessageTest.java 
b/clients/src/test/java/org/apache/kafka/common/message/MessageTest.java
index 3fcd0071c49..5762e84f60b 100644
--- a/clients/src/test/java/org/apache/kafka/common/message/MessageTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/message/MessageTest.java
@@ -24,6 +24,7 @@ import org.apache.kafka.common.Uuid;
 import org.apache.kafka.common.errors.UnsupportedVersionException;
 import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTopic;
 import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTopicCollection;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTransactionCollection;
 import 
org.apache.kafka.common.message.DescribeClusterResponseData.DescribeClusterBroker;
 import 
org.apache.kafka.common.message.DescribeClusterResponseData.DescribeClusterBrokerCollection;
 import 
org.apache.kafka.common.message.DescribeGroupsResponseData.DescribedGroup;
@@ -96,14 +97,26 @@ public final class MessageTest {
 
     @Test
     public void testAddPartitionsToTxnVersions() throws Exception {
-        testAllMessageRoundTrips(new AddPartitionsToTxnRequestData().
-                setTransactionalId("blah").
-                setProducerId(0xbadcafebadcafeL).
-                setProducerEpoch((short) 30000).
-                setTopics(new AddPartitionsToTxnTopicCollection(singletonList(
+        AddPartitionsToTxnRequestData v3AndBelowData = new 
AddPartitionsToTxnRequestData().
+                setV3AndBelowTransactionalId("blah").
+                setV3AndBelowProducerId(0xbadcafebadcafeL).
+                setV3AndBelowProducerEpoch((short) 30000).
+                setV3AndBelowTopics(new 
AddPartitionsToTxnTopicCollection(singletonList(
                         new AddPartitionsToTxnTopic().
                                 setName("Topic").
-                                setPartitions(singletonList(1))).iterator())));
+                                setPartitions(singletonList(1))).iterator()));
+        testDuplication(v3AndBelowData);
+        testAllMessageRoundTripsUntilVersion((short) 3, v3AndBelowData);
+
+        AddPartitionsToTxnRequestData data = new 
AddPartitionsToTxnRequestData().
+                setTransactions(new 
AddPartitionsToTxnTransactionCollection(singletonList(
+                       new 
AddPartitionsToTxnRequestData.AddPartitionsToTxnTransaction().
+                              setTransactionalId("blah").
+                              setProducerId(0xbadcafebadcafeL).
+                              setProducerEpoch((short) 30000).
+                              
setTopics(v3AndBelowData.v3AndBelowTopics())).iterator()));
+        testDuplication(data);
+        testAllMessageRoundTripsFromVersion((short) 4, data);
     }
 
     @Test
@@ -1032,6 +1045,12 @@ public final class MessageTest {
         }
     }
 
+    private void testAllMessageRoundTripsUntilVersion(short untilVersion, 
Message message) throws Exception {
+        for (short version = message.lowestSupportedVersion(); version <= 
untilVersion; version++) {
+            testEquivalentMessageRoundTrip(version, message);
+        }
+    }
+
     private void testMessageRoundTrip(short version, Message message, Message 
expected) throws Exception {
         testByteBufferRoundTrip(version, message, expected);
     }
diff --git 
a/clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequestTest.java
 
b/clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequestTest.java
index 04bde4ae61b..92bb8741be0 100644
--- 
a/clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequestTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequestTest.java
@@ -17,43 +17,145 @@
 package org.apache.kafka.common.requests;
 
 import org.apache.kafka.common.TopicPartition;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTopic;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTransaction;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTransactionCollection;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTopicCollection;
+import org.apache.kafka.common.message.AddPartitionsToTxnResponseData;
 import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource;
 import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.Errors;
 
 import java.util.ArrayList;
-
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
+
+import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 
+import static 
org.apache.kafka.common.requests.AddPartitionsToTxnResponse.errorsForTransaction;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 
 public class AddPartitionsToTxnRequestTest {
-
-    private static String transactionalId = "transactionalId";
+    private final String transactionalId1 = "transaction1";
+    private final String transactionalId2 = "transaction2";
     private static int producerId = 10;
     private static short producerEpoch = 1;
     private static int throttleTimeMs = 10;
+    private static TopicPartition tp0 = new TopicPartition("topic", 0);
+    private static TopicPartition tp1 = new TopicPartition("topic", 1);
 
     @ParameterizedTest
     @ApiKeyVersionsSource(apiKey = ApiKeys.ADD_PARTITIONS_TO_TXN)
     public void testConstructor(short version) {
-        List<TopicPartition> partitions = new ArrayList<>();
-        partitions.add(new TopicPartition("topic", 0));
-        partitions.add(new TopicPartition("topic", 1));
+        
+        AddPartitionsToTxnRequest request;
 
-        AddPartitionsToTxnRequest.Builder builder = new 
AddPartitionsToTxnRequest.Builder(transactionalId, producerId, producerEpoch, 
partitions);
-        AddPartitionsToTxnRequest request = builder.build(version);
+        if (version < 4) {
+            List<TopicPartition> partitions = new ArrayList<>();
+            partitions.add(tp0);
+            partitions.add(tp1);
 
-        assertEquals(transactionalId, request.data().transactionalId());
-        assertEquals(producerId, request.data().producerId());
-        assertEquals(producerEpoch, request.data().producerEpoch());
-        assertEquals(partitions, request.partitions());
+            AddPartitionsToTxnRequest.Builder builder = 
AddPartitionsToTxnRequest.Builder.forClient(transactionalId1, producerId, 
producerEpoch, partitions);
+            request = builder.build(version);
 
+            assertEquals(transactionalId1, 
request.data().v3AndBelowTransactionalId());
+            assertEquals(producerId, request.data().v3AndBelowProducerId());
+            assertEquals(producerEpoch, 
request.data().v3AndBelowProducerEpoch());
+            assertEquals(partitions, 
AddPartitionsToTxnRequest.getPartitions(request.data().v3AndBelowTopics()));
+        } else {
+            AddPartitionsToTxnTransactionCollection transactions = 
createTwoTransactionCollection();
+
+            AddPartitionsToTxnRequest.Builder builder = 
AddPartitionsToTxnRequest.Builder.forBroker(transactions);
+            request = builder.build(version);
+            
+            AddPartitionsToTxnTransaction reqTxn1 = 
request.data().transactions().find(transactionalId1);
+            AddPartitionsToTxnTransaction reqTxn2 = 
request.data().transactions().find(transactionalId2);
+
+            assertEquals(transactions.find(transactionalId1), reqTxn1);
+            assertEquals(transactions.find(transactionalId2), reqTxn2);
+        }
         AddPartitionsToTxnResponse response = 
request.getErrorResponse(throttleTimeMs, 
Errors.UNKNOWN_TOPIC_OR_PARTITION.exception());
 
-        
assertEquals(Collections.singletonMap(Errors.UNKNOWN_TOPIC_OR_PARTITION, 2), 
response.errorCounts());
         assertEquals(throttleTimeMs, response.throttleTimeMs());
+        
+        if (version >= 4) {
+            assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION.code(), 
response.data().errorCode());
+            // Since the error is top level, we count it as one error in the 
counts.
+            
assertEquals(Collections.singletonMap(Errors.UNKNOWN_TOPIC_OR_PARTITION, 1), 
response.errorCounts());
+        } else {
+            
assertEquals(Collections.singletonMap(Errors.UNKNOWN_TOPIC_OR_PARTITION, 2), 
response.errorCounts());  
+        }
+    }
+    
+    @Test
+    public void testBatchedRequests() {
+        AddPartitionsToTxnTransactionCollection transactions = 
createTwoTransactionCollection();
+
+        AddPartitionsToTxnRequest.Builder builder = 
AddPartitionsToTxnRequest.Builder.forBroker(transactions);
+        AddPartitionsToTxnRequest request = 
builder.build(ApiKeys.ADD_PARTITIONS_TO_TXN.latestVersion());
+        
+        Map<String, List<TopicPartition>> expectedMap = new HashMap<>();
+        expectedMap.put(transactionalId1, Collections.singletonList(tp0));
+        expectedMap.put(transactionalId2, Collections.singletonList(tp1));
+        
+        assertEquals(expectedMap, request.partitionsByTransaction());
+
+        AddPartitionsToTxnResponseData.AddPartitionsToTxnResultCollection 
results = new 
AddPartitionsToTxnResponseData.AddPartitionsToTxnResultCollection();
+        
+        results.add(request.errorResponseForTransaction(transactionalId1, 
Errors.UNKNOWN_TOPIC_OR_PARTITION));
+        results.add(request.errorResponseForTransaction(transactionalId2, 
Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED));
+        
+        AddPartitionsToTxnResponse response = new 
AddPartitionsToTxnResponse(new AddPartitionsToTxnResponseData()
+                .setResultsByTransaction(results)
+                .setThrottleTimeMs(throttleTimeMs));
+        
+        assertEquals(Collections.singletonMap(tp0, 
Errors.UNKNOWN_TOPIC_OR_PARTITION), 
errorsForTransaction(response.getTransactionTopicResults(transactionalId1)));
+        assertEquals(Collections.singletonMap(tp1, 
Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED), 
errorsForTransaction(response.getTransactionTopicResults(transactionalId2)));
+    }
+    
+    @Test
+    public void testNormalizeRequest() {
+        List<TopicPartition> partitions = new ArrayList<>();
+        partitions.add(tp0);
+        partitions.add(tp1);
+
+        AddPartitionsToTxnRequest.Builder builder = 
AddPartitionsToTxnRequest.Builder.forClient(transactionalId1, producerId, 
producerEpoch, partitions);
+        AddPartitionsToTxnRequest request = builder.build((short) 3);
+
+        AddPartitionsToTxnRequest singleton = request.normalizeRequest();
+        assertEquals(partitions, 
singleton.partitionsByTransaction().get(transactionalId1));
+        
+        AddPartitionsToTxnTransaction transaction = 
singleton.data().transactions().find(transactionalId1);
+        assertEquals(producerId, transaction.producerId());
+        assertEquals(producerEpoch, transaction.producerEpoch());
+    }
+    
+    private AddPartitionsToTxnTransactionCollection 
createTwoTransactionCollection() {
+        AddPartitionsToTxnTopicCollection topics0 = new 
AddPartitionsToTxnTopicCollection();
+        topics0.add(new AddPartitionsToTxnTopic()
+                .setName(tp0.topic())
+                .setPartitions(Collections.singletonList(tp0.partition())));
+        AddPartitionsToTxnTopicCollection topics1 = new 
AddPartitionsToTxnTopicCollection();
+        topics1.add(new AddPartitionsToTxnTopic()
+                .setName(tp1.topic())
+                .setPartitions(Collections.singletonList(tp1.partition())));
+
+        AddPartitionsToTxnTransactionCollection transactions = new 
AddPartitionsToTxnTransactionCollection();
+        transactions.add(new AddPartitionsToTxnTransaction()
+                .setTransactionalId(transactionalId1)
+                .setProducerId(producerId)
+                .setProducerEpoch(producerEpoch)
+                .setVerifyOnly(true)
+                .setTopics(topics0));
+        transactions.add(new AddPartitionsToTxnTransaction()
+                .setTransactionalId(transactionalId2)
+                .setProducerId(producerId + 1)
+                .setProducerEpoch((short) (producerEpoch + 1))
+                .setVerifyOnly(false)
+                .setTopics(topics1));
+        return transactions;
     }
 }
diff --git 
a/clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponseTest.java
 
b/clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponseTest.java
index 5b67bd47a01..3b2dbee332c 100644
--- 
a/clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponseTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponseTest.java
@@ -18,18 +18,25 @@ package org.apache.kafka.common.requests;
 
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.message.AddPartitionsToTxnResponseData;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnResult;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnResultCollection;
 import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnPartitionResult;
 import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnTopicResult;
 import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnTopicResultCollection;
 import org.apache.kafka.common.protocol.ApiKeys;
 import org.apache.kafka.common.protocol.Errors;
+import org.apache.kafka.common.utils.annotation.ApiKeyVersionsSource;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
 
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
+import static 
org.apache.kafka.common.requests.AddPartitionsToTxnResponse.errorsForTransaction;
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 
 public class AddPartitionsToTxnResponseTest {
 
@@ -41,9 +48,9 @@ public class AddPartitionsToTxnResponseTest {
     protected final Errors errorTwo = Errors.NOT_COORDINATOR;
     protected final String topicTwo = "topic2";
     protected final int partitionTwo = 2;
+    protected final TopicPartition tp1 = new TopicPartition(topicOne, 
partitionOne);
+    protected final TopicPartition tp2 = new TopicPartition(topicTwo, 
partitionTwo);
 
-    protected TopicPartition tp1 = new TopicPartition(topicOne, partitionOne);
-    protected TopicPartition tp2 = new TopicPartition(topicTwo, partitionTwo);
     protected Map<Errors, Integer> expectedErrorCounts;
     protected Map<TopicPartition, Errors> errorsMap;
 
@@ -58,42 +65,80 @@ public class AddPartitionsToTxnResponseTest {
         errorsMap.put(tp2, errorTwo);
     }
 
-    @Test
-    public void testConstructorWithErrorResponse() {
-        AddPartitionsToTxnResponse response = new 
AddPartitionsToTxnResponse(throttleTimeMs, errorsMap);
-
-        assertEquals(expectedErrorCounts, response.errorCounts());
-        assertEquals(throttleTimeMs, response.throttleTimeMs());
-    }
-
-    @Test
-    public void testParse() {
-
+    @ParameterizedTest
+    @ApiKeyVersionsSource(apiKey = ApiKeys.ADD_PARTITIONS_TO_TXN)
+    public void testParse(short version) {
         AddPartitionsToTxnTopicResultCollection topicCollection = new 
AddPartitionsToTxnTopicResultCollection();
 
         AddPartitionsToTxnTopicResult topicResult = new 
AddPartitionsToTxnTopicResult();
         topicResult.setName(topicOne);
 
-        topicResult.results().add(new AddPartitionsToTxnPartitionResult()
-                                      .setErrorCode(errorOne.code())
+        topicResult.resultsByPartition().add(new 
AddPartitionsToTxnPartitionResult()
+                                      .setPartitionErrorCode(errorOne.code())
                                       .setPartitionIndex(partitionOne));
 
-        topicResult.results().add(new AddPartitionsToTxnPartitionResult()
-                                      .setErrorCode(errorTwo.code())
+        topicResult.resultsByPartition().add(new 
AddPartitionsToTxnPartitionResult()
+                                      .setPartitionErrorCode(errorTwo.code())
                                       .setPartitionIndex(partitionTwo));
 
         topicCollection.add(topicResult);
+            
+        if (version < 4) {
+            AddPartitionsToTxnResponseData data = new 
AddPartitionsToTxnResponseData()
+                    .setResultsByTopicV3AndBelow(topicCollection)
+                    .setThrottleTimeMs(throttleTimeMs);
+            AddPartitionsToTxnResponse response = new 
AddPartitionsToTxnResponse(data);
 
-        AddPartitionsToTxnResponseData data = new 
AddPartitionsToTxnResponseData()
-                                                  .setResults(topicCollection)
-                                                  
.setThrottleTimeMs(throttleTimeMs);
-        AddPartitionsToTxnResponse response = new 
AddPartitionsToTxnResponse(data);
-
-        for (short version : ApiKeys.ADD_PARTITIONS_TO_TXN.allVersions()) {
             AddPartitionsToTxnResponse parsedResponse = 
AddPartitionsToTxnResponse.parse(response.serialize(version), version);
             assertEquals(expectedErrorCounts, parsedResponse.errorCounts());
             assertEquals(throttleTimeMs, parsedResponse.throttleTimeMs());
             assertEquals(version >= 1, 
parsedResponse.shouldClientThrottle(version));
+        } else {
+            AddPartitionsToTxnResultCollection results = new 
AddPartitionsToTxnResultCollection();
+            results.add(new 
AddPartitionsToTxnResult().setTransactionalId("txn1").setTopicResults(topicCollection));
+            
+            // Create another transaction with new name and errorOne for a 
single partition.
+            Map<TopicPartition, Errors> txnTwoExpectedErrors = 
Collections.singletonMap(tp2, errorOne);
+            
results.add(AddPartitionsToTxnResponse.resultForTransaction("txn2", 
txnTwoExpectedErrors));
+
+            AddPartitionsToTxnResponseData data = new 
AddPartitionsToTxnResponseData()
+                    .setResultsByTransaction(results)
+                    .setThrottleTimeMs(throttleTimeMs);
+            AddPartitionsToTxnResponse response = new 
AddPartitionsToTxnResponse(data);
+
+            Map<Errors, Integer> newExpectedErrorCounts = new HashMap<>();
+            newExpectedErrorCounts.put(Errors.NONE, 1); // top level error
+            newExpectedErrorCounts.put(errorOne, 2);
+            newExpectedErrorCounts.put(errorTwo, 1);
+            
+            AddPartitionsToTxnResponse parsedResponse = 
AddPartitionsToTxnResponse.parse(response.serialize(version), version);
+            assertEquals(txnTwoExpectedErrors, 
errorsForTransaction(response.getTransactionTopicResults("txn2")));
+            assertEquals(newExpectedErrorCounts, parsedResponse.errorCounts());
+            assertEquals(throttleTimeMs, parsedResponse.throttleTimeMs());
+            assertTrue(parsedResponse.shouldClientThrottle(version));
         }
     }
+    
+    @Test
+    public void testBatchedErrors() {
+        Map<TopicPartition, Errors> txn1Errors = Collections.singletonMap(tp1, 
errorOne);
+        Map<TopicPartition, Errors> txn2Errors = Collections.singletonMap(tp1, 
errorOne);
+        
+        AddPartitionsToTxnResult transaction1 = 
AddPartitionsToTxnResponse.resultForTransaction("txn1", txn1Errors);
+        AddPartitionsToTxnResult transaction2 = 
AddPartitionsToTxnResponse.resultForTransaction("txn2", txn2Errors);
+        
+        AddPartitionsToTxnResultCollection results = new 
AddPartitionsToTxnResultCollection();
+        results.add(transaction1);
+        results.add(transaction2);
+        
+        AddPartitionsToTxnResponse response = new 
AddPartitionsToTxnResponse(new 
AddPartitionsToTxnResponseData().setResultsByTransaction(results));
+        
+        assertEquals(txn1Errors, 
errorsForTransaction(response.getTransactionTopicResults("txn1")));
+        assertEquals(txn2Errors, 
errorsForTransaction(response.getTransactionTopicResults("txn2")));
+        
+        Map<String, Map<TopicPartition, Errors>> expectedErrors = new 
HashMap<>();
+        expectedErrors.put("txn1", txn1Errors);
+        expectedErrors.put("txn2", txn2Errors);
+        assertEquals(expectedErrors, response.errors());
+    }
 }
diff --git 
a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
 
b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
index a1b399a2d35..7b0ca0d233e 100644
--- 
a/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
+++ 
b/clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java
@@ -35,8 +35,13 @@ import 
org.apache.kafka.common.errors.NotEnoughReplicasException;
 import org.apache.kafka.common.errors.SecurityDisabledException;
 import org.apache.kafka.common.errors.UnknownServerException;
 import org.apache.kafka.common.errors.UnsupportedVersionException;
+import org.apache.kafka.common.message.AddPartitionsToTxnResponseData;
 import org.apache.kafka.common.message.AddOffsetsToTxnRequestData;
 import org.apache.kafka.common.message.AddOffsetsToTxnResponseData;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTopic;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTopicCollection;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTransaction;
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTransactionCollection;
 import org.apache.kafka.common.message.AllocateProducerIdsRequestData;
 import org.apache.kafka.common.message.AllocateProducerIdsResponseData;
 import org.apache.kafka.common.message.AlterClientQuotasResponseData;
@@ -926,7 +931,8 @@ public class RequestResponseTest {
     @Test
     public void testErrorCountsIncludesNone() {
         assertEquals(1, 
createAddOffsetsToTxnResponse().errorCounts().get(Errors.NONE));
-        assertEquals(1, 
createAddPartitionsToTxnResponse().errorCounts().get(Errors.NONE));
+        assertEquals(1, createAddPartitionsToTxnResponse((short) 
3).errorCounts().get(Errors.NONE));
+        assertEquals(2, createAddPartitionsToTxnResponse((short) 
4).errorCounts().get(Errors.NONE));
         assertEquals(1, 
createAlterClientQuotasResponse().errorCounts().get(Errors.NONE));
         assertEquals(1, 
createAlterConfigsResponse().errorCounts().get(Errors.NONE));
         assertEquals(2, 
createAlterPartitionReassignmentsResponse().errorCounts().get(Errors.NONE));
@@ -1080,7 +1086,7 @@ public class RequestResponseTest {
             case DELETE_RECORDS: return createDeleteRecordsResponse();
             case INIT_PRODUCER_ID: return createInitPidResponse();
             case OFFSET_FOR_LEADER_EPOCH: return createLeaderEpochResponse();
-            case ADD_PARTITIONS_TO_TXN: return 
createAddPartitionsToTxnResponse();
+            case ADD_PARTITIONS_TO_TXN: return 
createAddPartitionsToTxnResponse(version);
             case ADD_OFFSETS_TO_TXN: return createAddOffsetsToTxnResponse();
             case END_TXN: return createEndTxnResponse();
             case WRITE_TXN_MARKERS: return createWriteTxnMarkersResponse();
@@ -1611,7 +1617,7 @@ public class RequestResponseTest {
             serializedBytes.rewind();
             assertEquals(serializedBytes, serializedBytes2, "Response " + 
response + "failed equality test");
         } catch (Exception e) {
-            throw new RuntimeException("Failed to deserialize response " + 
response + " with type " + response.getClass(), e);
+            throw new RuntimeException("Failed to deserialize version " + 
version + " response " + response + " with type " + response.getClass(), e);
         }
     }
 
@@ -2598,12 +2604,37 @@ public class RequestResponseTest {
     }
 
     private AddPartitionsToTxnRequest createAddPartitionsToTxnRequest(short 
version) {
-        return new AddPartitionsToTxnRequest.Builder("tid", 21L, (short) 42,
-            singletonList(new TopicPartition("topic", 73))).build(version);
+        if (version < 4) {
+            return AddPartitionsToTxnRequest.Builder.forClient("tid", 21L, 
(short) 42,
+                    singletonList(new TopicPartition("topic", 
73))).build(version);
+        } else {
+            AddPartitionsToTxnTransactionCollection transactions = new 
AddPartitionsToTxnTransactionCollection(
+                singletonList(new AddPartitionsToTxnTransaction()
+                    .setTransactionalId("tid")
+                    .setProducerId(21L)
+                    .setProducerEpoch((short) 42)
+                    .setVerifyOnly(false)
+                    .setTopics(new AddPartitionsToTxnTopicCollection(
+                        singletonList(new AddPartitionsToTxnTopic()
+                            .setName("topic")
+                            
.setPartitions(Collections.singletonList(73))).iterator())))
+                    .iterator());
+            return 
AddPartitionsToTxnRequest.Builder.forBroker(transactions).build(version);  
+        }
     }
 
-    private AddPartitionsToTxnResponse createAddPartitionsToTxnResponse() {
-        return new AddPartitionsToTxnResponse(0, Collections.singletonMap(new 
TopicPartition("t", 0), Errors.NONE));
+    private AddPartitionsToTxnResponse createAddPartitionsToTxnResponse(short 
version) {
+        String txnId = version < 4 ? 
AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID : "tid";
+        AddPartitionsToTxnResponseData.AddPartitionsToTxnResult result = 
AddPartitionsToTxnResponse.resultForTransaction(
+                txnId, Collections.singletonMap(new TopicPartition("t", 0), 
Errors.NONE));
+        AddPartitionsToTxnResponseData data = new 
AddPartitionsToTxnResponseData().setThrottleTimeMs(0);
+        
+        if (version < 4) {
+            data.setResultsByTopicV3AndBelow(result.topicResults());
+        } else {
+            data.setResultsByTransaction(new 
AddPartitionsToTxnResponseData.AddPartitionsToTxnResultCollection(singletonList(result).iterator()));
+        }
+        return new AddPartitionsToTxnResponse(data);
     }
 
     private AddOffsetsToTxnRequest createAddOffsetsToTxnRequest(short version) 
{
diff --git 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
index 1ec906cd223..02142f938a8 100644
--- 
a/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
+++ 
b/core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala
@@ -22,14 +22,18 @@ import kafka.server.{KafkaConfig, MetadataCache, 
ReplicaManager, RequestLocal}
 import kafka.utils.Logging
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.internals.Topic
+import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnResult
 import org.apache.kafka.common.message.{DescribeTransactionsResponseData, 
ListTransactionsResponseData}
 import org.apache.kafka.common.metrics.Metrics
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record.RecordBatch
-import org.apache.kafka.common.requests.TransactionResult
+import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, 
TransactionResult}
 import org.apache.kafka.common.utils.{LogContext, ProducerIdAndEpoch, Time}
 import org.apache.kafka.server.util.Scheduler
 
+import scala.collection.mutable
+import scala.jdk.CollectionConverters._
+
 object TransactionCoordinator {
 
   def apply(config: KafkaConfig,
@@ -92,6 +96,7 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
 
   type InitProducerIdCallback = InitProducerIdResult => Unit
   type AddPartitionsCallback = Errors => Unit
+  type VerifyPartitionsCallback = AddPartitionsToTxnResult => Unit
   type EndTxnCallback = Errors => Unit
   type ApiResult[T] = Either[Errors, T]
 
@@ -317,6 +322,35 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
       }
     }
   }
+  
+  def handleVerifyPartitionsInTransaction(transactionalId: String,
+                                          producerId: Long,
+                                          producerEpoch: Short,
+                                          partitions: 
collection.Set[TopicPartition],
+                                          responseCallback: 
VerifyPartitionsCallback): Unit = {
+    if (transactionalId == null || transactionalId.isEmpty) {
+      debug(s"Returning ${Errors.INVALID_REQUEST} error code to client for 
$transactionalId's AddPartitions request for verification")
+      
responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId,
 partitions.map(_ -> Errors.INVALID_REQUEST).toMap.asJava))
+    } else {
+      val result: ApiResult[(Int, TransactionMetadata)] = 
getTransactionMetadata(transactionalId, producerId, producerEpoch, partitions)
+      
+      result match {
+        case Left(err) =>
+          debug(s"Returning $err error code to client for $transactionalId's 
AddPartitions request for verification")
+          
responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId,
 partitions.map(_ -> err).toMap.asJava))
+
+        case Right((_, txnMetadata)) =>
+          val errors = mutable.Map[TopicPartition, Errors]()
+          partitions.foreach { tp => 
+            if (txnMetadata.topicPartitions.contains(tp))
+              errors.put(tp, Errors.NONE)
+            else
+              errors.put(tp, Errors.INVALID_TXN_STATE)  
+          }
+          
responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId,
 errors.asJava))
+      }
+    }
+  }
 
   def handleAddPartitionsToTransaction(transactionalId: String,
                                        producerId: Long,
@@ -330,44 +364,51 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
     } else {
       // try to update the transaction metadata and append the updated 
metadata to txn log;
       // if there is no such metadata treat it as invalid producerId mapping 
error.
-      val result: ApiResult[(Int, TxnTransitMetadata)] = 
txnManager.getTransactionState(transactionalId).flatMap {
-        case None => Left(Errors.INVALID_PRODUCER_ID_MAPPING)
-
-        case Some(epochAndMetadata) =>
-          val coordinatorEpoch = epochAndMetadata.coordinatorEpoch
-          val txnMetadata = epochAndMetadata.transactionMetadata
-
-          // generate the new transaction metadata with added partitions
-          txnMetadata.inLock {
-            if (txnMetadata.producerId != producerId) {
-              Left(Errors.INVALID_PRODUCER_ID_MAPPING)
-            } else if (txnMetadata.producerEpoch != producerEpoch) {
-              Left(Errors.PRODUCER_FENCED)
-            } else if (txnMetadata.pendingTransitionInProgress) {
-              // return a retriable exception to let the client backoff and 
retry
-              Left(Errors.CONCURRENT_TRANSACTIONS)
-            } else if (txnMetadata.state == PrepareCommit || txnMetadata.state 
== PrepareAbort) {
-              Left(Errors.CONCURRENT_TRANSACTIONS)
-            } else if (txnMetadata.state == Ongoing && 
partitions.subsetOf(txnMetadata.topicPartitions)) {
-              // this is an optimization: if the partitions are already in the 
metadata reply OK immediately
-              Left(Errors.NONE)
-            } else {
-              Right(coordinatorEpoch, 
txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds()))
-            }
-          }
-      }
+      val result: ApiResult[(Int, TransactionMetadata)] = 
getTransactionMetadata(transactionalId, producerId, producerEpoch, partitions)
 
       result match {
         case Left(err) =>
           debug(s"Returning $err error code to client for $transactionalId's 
AddPartitions request")
           responseCallback(err)
 
-        case Right((coordinatorEpoch, newMetadata)) =>
-          txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch, 
newMetadata,
+        case Right((coordinatorEpoch, txnMetadata)) =>
+          txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch, 
txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds()),
             responseCallback, requestLocal = requestLocal)
       }
     }
   }
+  
+  private def getTransactionMetadata(transactionalId: String,
+                                     producerId: Long,
+                                     producerEpoch: Short,
+                                     partitions: 
collection.Set[TopicPartition]): ApiResult[(Int, TransactionMetadata)] = {
+    txnManager.getTransactionState(transactionalId).flatMap {
+      case None => Left(Errors.INVALID_PRODUCER_ID_MAPPING)
+
+      case Some(epochAndMetadata) =>
+        val coordinatorEpoch = epochAndMetadata.coordinatorEpoch
+        val txnMetadata = epochAndMetadata.transactionMetadata
+
+        // validate producer id/epoch and txn state under the lock; this helper does not modify the metadata (callers add or verify partitions)
+        txnMetadata.inLock {
+          if (txnMetadata.producerId != producerId) {
+            Left(Errors.INVALID_PRODUCER_ID_MAPPING)
+          } else if (txnMetadata.producerEpoch != producerEpoch) {
+            Left(Errors.PRODUCER_FENCED)
+          } else if (txnMetadata.pendingTransitionInProgress) {
+            // return a retriable exception to let the client backoff and retry
+            Left(Errors.CONCURRENT_TRANSACTIONS)
+          } else if (txnMetadata.state == PrepareCommit || txnMetadata.state 
== PrepareAbort) {
+            Left(Errors.CONCURRENT_TRANSACTIONS)
+          } else if (txnMetadata.state == Ongoing && 
partitions.subsetOf(txnMetadata.topicPartitions)) {
+            // this is an optimization: if the partitions are already in the 
metadata reply OK immediately
+            Left(Errors.NONE) // Left here means "respond now": callers pass it straight to their response callback
+          } else {
+            Right(coordinatorEpoch, txnMetadata) // the existing, unmodified metadata
+          }
+        }
+    }
+  }
 
   /**
    * Load state from the given partition and begin handling requests for 
groups which map to this partition.
diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala 
b/core/src/main/scala/kafka/server/KafkaApis.scala
index ad0f1d1b784..bde98b484f0 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -34,6 +34,8 @@ import org.apache.kafka.common.config.ConfigResource
 import org.apache.kafka.common.errors._
 import org.apache.kafka.common.internals.Topic.{GROUP_METADATA_TOPIC_NAME, 
TRANSACTION_STATE_TOPIC_NAME, isInternal}
 import org.apache.kafka.common.internals.{FatalExitError, Topic}
+import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnResult
+import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnResultCollection
 import 
org.apache.kafka.common.message.AlterConfigsResponseData.AlterConfigsResourceResponse
 import 
org.apache.kafka.common.message.AlterPartitionReassignmentsResponseData.{ReassignablePartitionResponse,
 ReassignableTopicResponse}
 import 
org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult
@@ -199,7 +201,7 @@ class KafkaApis(val requestChannel: RequestChannel,
         case ApiKeys.DELETE_RECORDS => handleDeleteRecordsRequest(request)
         case ApiKeys.INIT_PRODUCER_ID => handleInitProducerIdRequest(request, 
requestLocal)
         case ApiKeys.OFFSET_FOR_LEADER_EPOCH => 
handleOffsetForLeaderEpochRequest(request)
-        case ApiKeys.ADD_PARTITIONS_TO_TXN => 
handleAddPartitionToTxnRequest(request, requestLocal)
+        case ApiKeys.ADD_PARTITIONS_TO_TXN => 
handleAddPartitionsToTxnRequest(request, requestLocal)
         case ApiKeys.ADD_OFFSETS_TO_TXN => 
handleAddOffsetsToTxnRequest(request, requestLocal)
         case ApiKeys.END_TXN => handleEndTxnRequest(request, requestLocal)
         case ApiKeys.WRITE_TXN_MARKERS => 
handleWriteTxnMarkersRequest(request, requestLocal)
@@ -2386,66 +2388,111 @@ class KafkaApis(val requestChannel: RequestChannel,
     if (config.interBrokerProtocolVersion.isLessThan(version))
       throw new UnsupportedVersionException(s"inter.broker.protocol.version: 
${config.interBrokerProtocolVersion.version} is less than the required version: 
${version.version}")
   }
-
-  def handleAddPartitionToTxnRequest(request: RequestChannel.Request, 
requestLocal: RequestLocal): Unit = {
+  def handleAddPartitionsToTxnRequest(request: RequestChannel.Request, 
requestLocal: RequestLocal): Unit = {
     ensureInterBrokerVersion(IBP_0_11_0_IV0)
-    val addPartitionsToTxnRequest = request.body[AddPartitionsToTxnRequest]
-    val transactionalId = addPartitionsToTxnRequest.data.transactionalId
-    val partitionsToAdd = addPartitionsToTxnRequest.partitions.asScala
-    if (!authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID, 
transactionalId))
-      requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs =>
-        addPartitionsToTxnRequest.getErrorResponse(requestThrottleMs, 
Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception))
-    else {
-      val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]()
-      val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]()
-      val authorizedPartitions = mutable.Set[TopicPartition]()
-
-      val authorizedTopics = authHelper.filterByAuthorized(request.context, 
WRITE, TOPIC,
-        partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic)
-      for (topicPartition <- partitionsToAdd) {
-        if (!authorizedTopics.contains(topicPartition.topic))
-          unauthorizedTopicErrors += topicPartition -> 
Errors.TOPIC_AUTHORIZATION_FAILED
-        else if (!metadataCache.contains(topicPartition))
-          nonExistingTopicErrors += topicPartition -> 
Errors.UNKNOWN_TOPIC_OR_PARTITION
-        else
-          authorizedPartitions.add(topicPartition)
+    val addPartitionsToTxnRequest =
+      if (request.context.apiVersion() < 4) 
+        request.body[AddPartitionsToTxnRequest].normalizeRequest() 
+      else 
+        request.body[AddPartitionsToTxnRequest]
+    val version = addPartitionsToTxnRequest.version
+    val responses = new AddPartitionsToTxnResultCollection()
+    val partitionsByTransaction = 
addPartitionsToTxnRequest.partitionsByTransaction()
+    
+    // Newer versions of the request should only come from other brokers.
+    if (version >= 4) authHelper.authorizeClusterOperation(request, 
CLUSTER_ACTION)
+
+    // V4 requests introduced batches of transactions. We need all 
transactions to be handled before sending the 
+    // response so there are a few differences in handling errors and sending 
responses.
+    def createResponse(requestThrottleMs: Int): AbstractResponse = {
+      if (version < 4) {
+        // There will only be one response in data. Add it to the response 
data object.
+        val data = new AddPartitionsToTxnResponseData()
+        responses.forEach { result => 
+          data.setResultsByTopicV3AndBelow(result.topicResults())
+          data.setThrottleTimeMs(requestThrottleMs)
+        }
+        new AddPartitionsToTxnResponse(data)
+      } else {
+        new AddPartitionsToTxnResponse(new 
AddPartitionsToTxnResponseData().setThrottleTimeMs(requestThrottleMs).setResultsByTransaction(responses))
       }
+    }
 
-      if (unauthorizedTopicErrors.nonEmpty || nonExistingTopicErrors.nonEmpty) 
{
-        // Any failed partition check causes the entire request to fail. We 
send the appropriate error codes for the
-        // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error 
code for the partitions which succeeded
-        // the authorization check to indicate that they were not added to the 
transaction.
-        val partitionErrors = unauthorizedTopicErrors ++ 
nonExistingTopicErrors ++
-          authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED)
-        requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs =>
-          new AddPartitionsToTxnResponse(requestThrottleMs, 
partitionErrors.asJava))
+    val txns = addPartitionsToTxnRequest.data.transactions
+    def addResultAndMaybeSendResponse(result: AddPartitionsToTxnResult): Unit 
= {
+      val canSend = responses.synchronized { // results may be added from coordinator callbacks, hence the lock
+        responses.add(result)
+        responses.size == txns.size // respond only once every transaction in the batch has a result
+      }
+      if (canSend) {
+        requestHelper.sendResponseMaybeThrottle(request, createResponse)
+      }
+    }
+    if (txns.isEmpty) requestHelper.sendResponseMaybeThrottle(request, createResponse) // an empty batch triggers no callback, so respond now instead of hanging
+    txns.forEach { transaction => 
+      val transactionalId = transaction.transactionalId
+      val partitionsToAdd = 
partitionsByTransaction.get(transactionalId).asScala
+
+      // Versions < 4 come from clients and must be authorized to write for 
the given transaction and for the given topics.
+      if (version < 4 && !authHelper.authorize(request.context, WRITE, 
TRANSACTIONAL_ID, transactionalId)) {
+        
addResultAndMaybeSendResponse(addPartitionsToTxnRequest.errorResponseForTransaction(transactionalId,
 Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED))
       } else {
-        def sendResponseCallback(error: Errors): Unit = {
-          def createResponse(requestThrottleMs: Int): AbstractResponse = {
-            val finalError =
-              if (addPartitionsToTxnRequest.version < 2 && error == 
Errors.PRODUCER_FENCED) {
+        val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]()
+        val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]()
+        val authorizedPartitions = mutable.Set[TopicPartition]()
+
+        // Only request versions less than 4 need write authorization since 
they come from clients.
+        val authorizedTopics = 
+          if (version < 4) 
+            authHelper.filterByAuthorized(request.context, WRITE, TOPIC, 
partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic) 
+          else 
+            partitionsToAdd.map(_.topic).toSet
+        for (topicPartition <- partitionsToAdd) {
+          if (!authorizedTopics.contains(topicPartition.topic))
+            unauthorizedTopicErrors += topicPartition -> 
Errors.TOPIC_AUTHORIZATION_FAILED
+          else if (!metadataCache.contains(topicPartition))
+            nonExistingTopicErrors += topicPartition -> 
Errors.UNKNOWN_TOPIC_OR_PARTITION
+          else
+            authorizedPartitions.add(topicPartition)
+        }
+
+        if (unauthorizedTopicErrors.nonEmpty || 
nonExistingTopicErrors.nonEmpty) {
+          // Any failed partition check causes the entire transaction to fail. 
We send the appropriate error codes for the
+          // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error 
code for the partitions which succeeded
+          // the authorization check to indicate that they were not added to 
the transaction.
+          val partitionErrors = unauthorizedTopicErrors ++ 
nonExistingTopicErrors ++
+            authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED)
+          
addResultAndMaybeSendResponse(AddPartitionsToTxnResponse.resultForTransaction(transactionalId,
 partitionErrors.asJava))
+        } else {
+          def sendResponseCallback(error: Errors): Unit = {
+            val finalError = {
+              if (version < 2 && error == Errors.PRODUCER_FENCED) {
                 // For older clients, they could not understand the new 
PRODUCER_FENCED error code,
                 // so we need to return the old INVALID_PRODUCER_EPOCH to have 
the same client handling logic.
                 Errors.INVALID_PRODUCER_EPOCH
               } else {
                 error
               }
-
-            val responseBody: AddPartitionsToTxnResponse = new 
AddPartitionsToTxnResponse(requestThrottleMs,
-              partitionsToAdd.map{tp => (tp, finalError)}.toMap.asJava)
-            trace(s"Completed $transactionalId's AddPartitionsToTxnRequest 
with partitions $partitionsToAdd: errors: $error from client 
${request.header.clientId}")
-            responseBody
+            }
+            
addResultAndMaybeSendResponse(addPartitionsToTxnRequest.errorResponseForTransaction(transactionalId,
 finalError))
           }
 
-          requestHelper.sendResponseMaybeThrottle(request, createResponse)
-        }
 
-        txnCoordinator.handleAddPartitionsToTransaction(transactionalId,
-          addPartitionsToTxnRequest.data.producerId,
-          addPartitionsToTxnRequest.data.producerEpoch,
-          authorizedPartitions,
-          sendResponseCallback,
-          requestLocal)
+          if (!transaction.verifyOnly) {
+            txnCoordinator.handleAddPartitionsToTransaction(transactionalId,
+              transaction.producerId,
+              transaction.producerEpoch,
+              authorizedPartitions,
+              sendResponseCallback,
+              requestLocal)
+          } else {
+            txnCoordinator.handleVerifyPartitionsInTransaction(transactionalId,
+              transaction.producerId,
+              transaction.producerEpoch,
+              authorizedPartitions,
+              addResultAndMaybeSendResponse)
+          }
+        }
       }
     }
   }
diff --git 
a/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala 
b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
index 3b4f893eb55..e8f2ea88c49 100644
--- a/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
+++ b/core/src/test/scala/integration/kafka/api/AuthorizerIntegrationTest.scala
@@ -231,7 +231,7 @@ class AuthorizerIntegrationTest extends BaseRequestTest {
       resp.errors.get(new ConfigResource(ConfigResource.Type.TOPIC, 
tp.topic)).error),
     ApiKeys.INIT_PRODUCER_ID -> ((resp: InitProducerIdResponse) => resp.error),
     ApiKeys.WRITE_TXN_MARKERS -> ((resp: WriteTxnMarkersResponse) => 
resp.errorsByProducerId.get(producerId).get(tp)),
-    ApiKeys.ADD_PARTITIONS_TO_TXN -> ((resp: AddPartitionsToTxnResponse) => 
resp.errors.get(tp)),
+    ApiKeys.ADD_PARTITIONS_TO_TXN -> ((resp: AddPartitionsToTxnResponse) => 
resp.errors.get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID).get(tp)),
     ApiKeys.ADD_OFFSETS_TO_TXN -> ((resp: AddOffsetsToTxnResponse) => 
Errors.forCode(resp.data.errorCode)),
     ApiKeys.END_TXN -> ((resp: EndTxnResponse) => resp.error),
     ApiKeys.TXN_OFFSET_COMMIT -> ((resp: TxnOffsetCommitResponse) => 
resp.errors.get(tp)),
@@ -672,7 +672,7 @@ class AuthorizerIntegrationTest extends BaseRequestTest {
   private def describeLogDirsRequest = new DescribeLogDirsRequest.Builder(new 
DescribeLogDirsRequestData().setTopics(new 
DescribeLogDirsRequestData.DescribableLogDirTopicCollection(Collections.singleton(
     new 
DescribeLogDirsRequestData.DescribableLogDirTopic().setTopic(tp.topic).setPartitions(Collections.singletonList(tp.partition))).iterator()))).build()
 
-  private def addPartitionsToTxnRequest = new 
AddPartitionsToTxnRequest.Builder(transactionalId, 1, 1, 
Collections.singletonList(tp)).build()
+  private def addPartitionsToTxnRequest = 
AddPartitionsToTxnRequest.Builder.forClient(transactionalId, 1, 1, 
Collections.singletonList(tp)).build()
 
   private def addOffsetsToTxnRequest = new AddOffsetsToTxnRequest.Builder(
     new AddOffsetsToTxnRequestData()
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
index 501355d862b..c458ac191c0 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala
@@ -498,7 +498,10 @@ class TransactionCoordinatorConcurrencyTest extends 
AbstractCoordinatorConcurren
 
   abstract class TxnOperation[R] extends Operation {
     @volatile var result: Option[R] = None
+    @volatile var results: Map[TopicPartition, R] = _
+
     def resultCallback(r: R): Unit = this.result = Some(r)
+    
   }
 
   class InitProducerIdOperation(val producerIdAndEpoch: 
Option[ProducerIdAndEpoch] = None) extends TxnOperation[InitProducerIdResult] {
diff --git 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
index 1c8e0fcdc1b..fc84244cf21 100644
--- 
a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
+++ 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala
@@ -17,9 +17,10 @@
 package kafka.coordinator.transaction
 
 import org.apache.kafka.common.TopicPartition
+import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnResult
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.record.RecordBatch
-import org.apache.kafka.common.requests.TransactionResult
+import org.apache.kafka.common.requests.{AddPartitionsToTxnResponse, 
TransactionResult}
 import org.apache.kafka.common.utils.{LogContext, MockTime, ProducerIdAndEpoch}
 import org.apache.kafka.server.util.MockScheduler
 import org.junit.jupiter.api.Assertions._
@@ -314,6 +315,41 @@ class TransactionCoordinatorTest {
     
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
   }
 
+  @Test
+  def 
shouldRespondWithErrorsNoneOnAddPartitionWhenOngoingVerifyOnlyAndPartitionsTheSame():
 Unit = {
+    var errors: Map[TopicPartition, Errors] = Map.empty
+    def verifyPartitionsInTxnCallback(result: AddPartitionsToTxnResult): Unit 
= {
+      errors = 
AddPartitionsToTxnResponse.errorsForTransaction(result.topicResults()).asScala.toMap
+    }
+    
+    
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
+      .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
+        new TransactionMetadata(transactionalId, 0, 0, 0, 
RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, partitions, 0, 0)))))
+
+    coordinator.handleVerifyPartitionsInTransaction(transactionalId, 0L, 0, 
partitions, verifyPartitionsInTxnCallback)
+    assertEquals(partitions.map(_ -> Errors.NONE).toMap, errors) // assert on the callback's map (not the `error` field) so an empty or wrong result fails
+    
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
+  }
+  
+  @Test
+  def shouldRespondWithInvalidTxnStateWhenVerifyOnlyAndPartitionNotPresent(): 
Unit = {
+    var errors: Map[TopicPartition, Errors] = Map.empty
+    def verifyPartitionsInTxnCallback(result: AddPartitionsToTxnResult): Unit 
= {
+      errors = 
AddPartitionsToTxnResponse.errorsForTransaction(result.topicResults()).asScala.toMap
+    }
+    
+    
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
+      .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
+        new TransactionMetadata(transactionalId, 0, 0, 0, 
RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0)))))
+    
+    val extraPartitions = partitions ++ Set(new TopicPartition("topic2", 0))
+    
+    coordinator.handleVerifyPartitionsInTransaction(transactionalId, 0L, 0, 
extraPartitions, verifyPartitionsInTxnCallback)
+    assertEquals(Errors.INVALID_TXN_STATE, errors(new TopicPartition("topic2", 
0)))
+    assertEquals(Errors.NONE, errors(new TopicPartition("topic1", 0)))
+    
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
+  }
+
   @Test
   def shouldReplyWithInvalidPidMappingOnEndTxnWhenTxnIdDoesntExist(): Unit = {
     
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
diff --git 
a/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala
 
b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala
index 74320e62b49..5673315cf31 100644
--- 
a/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala
+++ 
b/core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala
@@ -17,25 +17,35 @@
 
 package kafka.server
 
-import kafka.utils.TestInfoUtils
+import kafka.utils.{TestInfoUtils, TestUtils}
 
-import java.util.Properties
+import java.util.{Collections, Properties}
+import java.util.stream.{Stream => JStream}
 import org.apache.kafka.common.TopicPartition
-import org.apache.kafka.common.protocol.Errors
-import org.apache.kafka.common.requests.{AddPartitionsToTxnRequest, 
AddPartitionsToTxnResponse}
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTopic
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTransaction
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTransactionCollection
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.AddPartitionsToTxnTopicCollection
+import org.apache.kafka.common.message.{FindCoordinatorRequestData, 
InitProducerIdRequestData}
+import org.apache.kafka.common.protocol.{ApiKeys, Errors}
+import org.apache.kafka.common.requests.FindCoordinatorRequest.CoordinatorType
+import org.apache.kafka.common.requests.{AddPartitionsToTxnRequest, 
AddPartitionsToTxnResponse, FindCoordinatorRequest, FindCoordinatorResponse, 
InitProducerIdRequest, InitProducerIdResponse}
 import org.junit.jupiter.api.Assertions._
-import org.junit.jupiter.api.{BeforeEach, TestInfo}
+import org.junit.jupiter.api.{BeforeEach, Test, TestInfo}
 import org.junit.jupiter.params.ParameterizedTest
-import org.junit.jupiter.params.provider.ValueSource
+import org.junit.jupiter.params.provider.{Arguments, MethodSource}
 
+import scala.collection.mutable
 import scala.jdk.CollectionConverters._
 
 class AddPartitionsToTxnRequestServerTest extends BaseRequestTest {
   private val topic1 = "topic1"
   val numPartitions = 1
 
-  override def brokerPropertyOverrides(properties: Properties): Unit =
+  override def brokerPropertyOverrides(properties: Properties): Unit = {
+    properties.put(KafkaConfig.UnstableApiVersionsEnableProp, "true")
     properties.put(KafkaConfig.AutoCreateTopicsEnableProp, false.toString)
+  }
 
   @BeforeEach
   override def setUp(testInfo: TestInfo): Unit = {
@@ -44,8 +54,8 @@ class AddPartitionsToTxnRequestServerTest extends 
BaseRequestTest {
   }
 
   @ParameterizedTest(name = TestInfoUtils.TestWithParameterizedQuorumName)
-  @ValueSource(strings = Array("zk", "kraft"))
-  def shouldReceiveOperationNotAttemptedWhenOtherPartitionHasError(quorum: 
String): Unit = {
+  @MethodSource(value = Array("parameters"))
+  def shouldReceiveOperationNotAttemptedWhenOtherPartitionHasError(quorum: 
String, version: Short): Unit = {
     // The basic idea is that we have one unknown topic and one created topic. 
We should get the 'UNKNOWN_TOPIC_OR_PARTITION'
     // error for the unknown topic and the 'OPERATION_NOT_ATTEMPTED' error for 
the known and authorized topic.
     val nonExistentTopic = new TopicPartition("unknownTopic", 0)
@@ -55,22 +65,146 @@ class AddPartitionsToTxnRequestServerTest extends 
BaseRequestTest {
     val producerId = 1000L
     val producerEpoch: Short = 0
 
-    val request = new AddPartitionsToTxnRequest.Builder(
-      transactionalId,
-      producerId,
-      producerEpoch,
-      List(createdTopicPartition, nonExistentTopic).asJava)
-      .build()
+    val request =
+      if (version < 4) {
+        AddPartitionsToTxnRequest.Builder.forClient(
+          transactionalId,
+          producerId,
+          producerEpoch,
+          List(createdTopicPartition, nonExistentTopic).asJava
+        ).build(version)
+      } else {
+        val topics = new AddPartitionsToTxnTopicCollection()
+        topics.add(new AddPartitionsToTxnTopic()
+          .setName(createdTopicPartition.topic)
+          
.setPartitions(Collections.singletonList(createdTopicPartition.partition)))
+        topics.add(new AddPartitionsToTxnTopic()
+          .setName(nonExistentTopic.topic)
+          
.setPartitions(Collections.singletonList(nonExistentTopic.partition)))
+
+        val transactions = new AddPartitionsToTxnTransactionCollection()
+        transactions.add(new AddPartitionsToTxnTransaction()
+          .setTransactionalId(transactionalId)
+          .setProducerId(producerId)
+          .setProducerEpoch(producerEpoch)
+          .setVerifyOnly(false)
+          .setTopics(topics))
+        
AddPartitionsToTxnRequest.Builder.forBroker(transactions).build(version)
+      }
 
     val leaderId = brokers.head.config.brokerId
     val response = connectAndReceive[AddPartitionsToTxnResponse](request, 
brokerSocketServer(leaderId))
+    
+    val errors = 
+      if (version < 4) 
+        response.errors.get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID) 
+      else 
+        response.errors.get(transactionalId)
+    
+    assertEquals(2, errors.size)
+
+    assertTrue(errors.containsKey(createdTopicPartition))
+    assertEquals(Errors.OPERATION_NOT_ATTEMPTED, 
errors.get(createdTopicPartition))
+
+    assertTrue(errors.containsKey(nonExistentTopic))
+    assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, 
errors.get(nonExistentTopic))
+  }
+  
+  @Test
+  def testOneSuccessOneErrorInBatchedRequest(): Unit = {
+    val tp0 = new TopicPartition(topic1, 0)
+    val transactionalId1 = "foobar"
+    val transactionalId2 = "barfoo" // "barfoo" maps to the same transaction 
coordinator
+    val producerId2 = 1000L
+    val producerEpoch2: Short = 0
+    
+    val txn2Topics = new AddPartitionsToTxnTopicCollection()
+    txn2Topics.add(new AddPartitionsToTxnTopic()
+      .setName(tp0.topic)
+      .setPartitions(Collections.singletonList(tp0.partition)))
+
+    val (coordinatorId, txn1) = setUpTransactions(transactionalId1, false, 
Set(tp0))
+
+    val transactions = new AddPartitionsToTxnTransactionCollection()
+    transactions.add(txn1)
+    transactions.add(new AddPartitionsToTxnTransaction()
+      .setTransactionalId(transactionalId2)
+      .setProducerId(producerId2)
+      .setProducerEpoch(producerEpoch2)
+      .setVerifyOnly(false)
+      .setTopics(txn2Topics))
+
+    val request = 
AddPartitionsToTxnRequest.Builder.forBroker(transactions).build()
+
+    val response = connectAndReceive[AddPartitionsToTxnResponse](request, 
brokerSocketServer(coordinatorId))
+
+    val errors = response.errors()
+    
+    val expectedErrors = Map(
+      transactionalId1 -> Collections.singletonMap(tp0, Errors.NONE),
+      transactionalId2 -> Collections.singletonMap(tp0, 
Errors.INVALID_PRODUCER_ID_MAPPING)
+    ).asJava
+
+    assertEquals(expectedErrors, errors)
+  }
 
-    assertEquals(2, response.errors.size)
+  @Test
+  def testVerifyOnly(): Unit = {
+    val tp0 = new TopicPartition(topic1, 0)
 
-    assertTrue(response.errors.containsKey(createdTopicPartition))
-    assertEquals(Errors.OPERATION_NOT_ATTEMPTED, 
response.errors.get(createdTopicPartition))
+    val transactionalId = "foobar"
+    val (coordinatorId, txn) = setUpTransactions(transactionalId, true, 
Set(tp0))
+
+    val transactions = new AddPartitionsToTxnTransactionCollection()
+    transactions.add(txn)
+    
+    val verifyRequest = 
AddPartitionsToTxnRequest.Builder.forBroker(transactions).build()
+
+    val verifyResponse = 
connectAndReceive[AddPartitionsToTxnResponse](verifyRequest, 
brokerSocketServer(coordinatorId))
+
+    val verifyErrors = verifyResponse.errors()
+
+    assertEquals(Collections.singletonMap(transactionalId, 
Collections.singletonMap(tp0, Errors.INVALID_TXN_STATE)), verifyErrors)
+  }
+  
+  private def setUpTransactions(transactionalId: String, verifyOnly: Boolean, 
partitions: Set[TopicPartition]): (Int, AddPartitionsToTxnTransaction) = {
+    val findCoordinatorRequest = new FindCoordinatorRequest.Builder(new 
FindCoordinatorRequestData().setKey(transactionalId).setKeyType(CoordinatorType.TRANSACTION.id)).build()
+    // First find coordinator request creates the state topic, then wait for 
transactional topics to be created.
+    connectAndReceive[FindCoordinatorResponse](findCoordinatorRequest, 
brokerSocketServer(brokers.head.config.brokerId))
+    TestUtils.waitForAllPartitionsMetadata(brokers, "__transaction_state", 50)
+    val findCoordinatorResponse = 
connectAndReceive[FindCoordinatorResponse](findCoordinatorRequest, 
brokerSocketServer(brokers.head.config.brokerId))
+    val coordinatorId = 
findCoordinatorResponse.data().coordinators().get(0).nodeId()
+
+    val initPidRequest = new InitProducerIdRequest.Builder(new 
InitProducerIdRequestData().setTransactionalId(transactionalId).setTransactionTimeoutMs(10000)).build()
+    val initPidResponse = 
connectAndReceive[InitProducerIdResponse](initPidRequest, 
brokerSocketServer(coordinatorId))
+
+    val producerId1 = initPidResponse.data().producerId()
+    val producerEpoch1 = initPidResponse.data().producerEpoch()
+
+    val txn1Topics = new AddPartitionsToTxnTopicCollection()
+    partitions.foreach { tp => 
+    txn1Topics.add(new AddPartitionsToTxnTopic()
+      .setName(tp.topic)
+      .setPartitions(Collections.singletonList(tp.partition)))
+    }
+
+    (coordinatorId, new AddPartitionsToTxnTransaction()
+      .setTransactionalId(transactionalId)
+      .setProducerId(producerId1)
+      .setProducerEpoch(producerEpoch1)
+      .setVerifyOnly(verifyOnly)
+      .setTopics(txn1Topics))
+  }
+}
 
-    assertTrue(response.errors.containsKey(nonExistentTopic))
-    assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, 
response.errors.get(nonExistentTopic))
+object AddPartitionsToTxnRequestServerTest {
+   def parameters: JStream[Arguments] = {
+    val arguments = mutable.ListBuffer[Arguments]()
+    ApiKeys.ADD_PARTITIONS_TO_TXN.allVersions().forEach { version =>
+      Array("kraft", "zk").foreach { quorum =>
+        arguments += Arguments.of(quorum, version)
+      }
+    }
+    arguments.asJava.stream()
   }
 }
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala 
b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index 0638bf36323..63fd69b61b3 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -40,6 +40,8 @@ import 
org.apache.kafka.common.errors.UnsupportedVersionException
 import org.apache.kafka.common.internals.{KafkaFutureImpl, Topic}
 import org.apache.kafka.common.memory.MemoryPool
 import org.apache.kafka.common.config.ConfigResource.Type.{BROKER, 
BROKER_LOGGER}
+import 
org.apache.kafka.common.message.AddPartitionsToTxnRequestData.{AddPartitionsToTxnTopic,
 AddPartitionsToTxnTopicCollection, AddPartitionsToTxnTransaction, 
AddPartitionsToTxnTransactionCollection}
+import 
org.apache.kafka.common.message.AddPartitionsToTxnResponseData.AddPartitionsToTxnResult
 import 
org.apache.kafka.common.message.AlterConfigsRequestData.{AlterConfigsResourceCollection
 => LAlterConfigsResourceCollection}
 import 
org.apache.kafka.common.message.AlterConfigsRequestData.{AlterConfigsResource 
=> LAlterConfigsResource}
 import 
org.apache.kafka.common.message.AlterConfigsRequestData.{AlterableConfigCollection
 => LAlterableConfigCollection}
@@ -1988,7 +1990,7 @@ class KafkaApisTest {
     val topic = "topic"
     addTopicToMetadataCache(topic, numPartitions = 2)
 
-    for (version <- ApiKeys.ADD_PARTITIONS_TO_TXN.oldestVersion to 
ApiKeys.ADD_PARTITIONS_TO_TXN.latestVersion) {
+    for (version <- ApiKeys.ADD_PARTITIONS_TO_TXN.oldestVersion to 3) {
 
       reset(replicaManager, clientRequestQuotaManager, requestChannel, 
txnCoordinator)
 
@@ -2002,7 +2004,7 @@ class KafkaApisTest {
       val partition = 1
       val topicPartition = new TopicPartition(topic, partition)
 
-      val addPartitionsToTxnRequest = new AddPartitionsToTxnRequest.Builder(
+      val addPartitionsToTxnRequest = 
AddPartitionsToTxnRequest.Builder.forClient(
         transactionalId,
         producerId,
         epoch,
@@ -2020,7 +2022,7 @@ class KafkaApisTest {
         ArgumentMatchers.eq(requestLocal)
       )).thenAnswer(_ => 
responseCallback.getValue.apply(Errors.PRODUCER_FENCED))
 
-      createKafkaApis().handleAddPartitionToTxnRequest(request, requestLocal)
+      createKafkaApis().handleAddPartitionsToTxnRequest(request, requestLocal)
 
       verify(requestChannel).sendResponse(
         ArgumentMatchers.eq(request),
@@ -2030,13 +2032,87 @@ class KafkaApisTest {
       val response = capturedResponse.getValue
 
       if (version < 2) {
-        assertEquals(Collections.singletonMap(topicPartition, 
Errors.INVALID_PRODUCER_EPOCH), response.errors())
+        assertEquals(Collections.singletonMap(topicPartition, 
Errors.INVALID_PRODUCER_EPOCH), 
response.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID))
       } else {
-        assertEquals(Collections.singletonMap(topicPartition, 
Errors.PRODUCER_FENCED), response.errors())
+        assertEquals(Collections.singletonMap(topicPartition, 
Errors.PRODUCER_FENCED), 
response.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID))
       }
     }
   }
 
+  @Test
+  def testBatchedAddPartitionsToTxnRequest(): Unit = {
+    val topic = "topic"
+    addTopicToMetadataCache(topic, numPartitions = 2)
+
+    val responseCallback: ArgumentCaptor[Errors => Unit] = 
ArgumentCaptor.forClass(classOf[Errors => Unit])
+    val verifyPartitionsCallback: ArgumentCaptor[AddPartitionsToTxnResult => 
Unit] = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnResult => Unit])
+
+    val transactionalId1 = "txnId1"
+    val transactionalId2 = "txnId2"
+    val producerId = 15L
+    val epoch = 0.toShort
+    
+    val tp0 = new TopicPartition(topic, 0)
+    val tp1 = new TopicPartition(topic, 1)
+
+    val addPartitionsToTxnRequest = 
AddPartitionsToTxnRequest.Builder.forBroker(
+      new AddPartitionsToTxnTransactionCollection(
+        List(new AddPartitionsToTxnTransaction()
+          .setTransactionalId(transactionalId1)
+          .setProducerId(producerId)
+          .setProducerEpoch(epoch)
+          .setVerifyOnly(false)
+          .setTopics(new AddPartitionsToTxnTopicCollection(
+            Collections.singletonList(new AddPartitionsToTxnTopic()
+              .setName(tp0.topic)
+              .setPartitions(Collections.singletonList(tp0.partition))
+            ).iterator())
+          ), new AddPartitionsToTxnTransaction()
+          .setTransactionalId(transactionalId2)
+          .setProducerId(producerId)
+          .setProducerEpoch(epoch)
+          .setVerifyOnly(true)
+          .setTopics(new AddPartitionsToTxnTopicCollection(
+            Collections.singletonList(new AddPartitionsToTxnTopic()
+              .setName(tp1.topic)
+              .setPartitions(Collections.singletonList(tp1.partition))
+            ).iterator())
+          )
+        ).asJava.iterator()
+      )
+    ).build(4.toShort)
+    val request = buildRequest(addPartitionsToTxnRequest)
+
+    val requestLocal = RequestLocal.withThreadConfinedCaching
+    when(txnCoordinator.handleAddPartitionsToTransaction(
+      ArgumentMatchers.eq(transactionalId1),
+      ArgumentMatchers.eq(producerId),
+      ArgumentMatchers.eq(epoch),
+      ArgumentMatchers.eq(Set(tp0)),
+      responseCallback.capture(),
+      ArgumentMatchers.eq(requestLocal)
+    )).thenAnswer(_ => responseCallback.getValue.apply(Errors.NONE))
+
+    when(txnCoordinator.handleVerifyPartitionsInTransaction(
+      ArgumentMatchers.eq(transactionalId2),
+      ArgumentMatchers.eq(producerId),
+      ArgumentMatchers.eq(epoch),
+      ArgumentMatchers.eq(Set(tp1)),
+      verifyPartitionsCallback.capture(),
+    )).thenAnswer(_ => 
verifyPartitionsCallback.getValue.apply(AddPartitionsToTxnResponse.resultForTransaction(transactionalId2,
 Map(tp1 -> Errors.PRODUCER_FENCED).asJava)))
+
+    createKafkaApis().handleAddPartitionsToTxnRequest(request, requestLocal)
+
+    val response = verifyNoThrottling[AddPartitionsToTxnResponse](request)
+    
+    val expectedErrors = Map(
+      transactionalId1 -> Collections.singletonMap(tp0, Errors.NONE),
+      transactionalId2 -> Collections.singletonMap(tp1, Errors.PRODUCER_FENCED)
+    ).asJava
+    
+    assertEquals(expectedErrors, response.errors())
+  }
+
   @Test
   def 
shouldReplaceProducerFencedWithInvalidProducerEpochInEndTxnWithOlderClient(): 
Unit = {
     val topic = "topic"
@@ -2151,17 +2227,17 @@ class KafkaApisTest {
       reset(replicaManager, clientRequestQuotaManager, requestChannel)
 
       val invalidTopicPartition = new TopicPartition(topic, invalidPartitionId)
-      val addPartitionsToTxnRequest = new AddPartitionsToTxnRequest.Builder(
+      val addPartitionsToTxnRequest = 
AddPartitionsToTxnRequest.Builder.forClient(
         "txnlId", 15L, 0.toShort, List(invalidTopicPartition).asJava
       ).build()
       val request = buildRequest(addPartitionsToTxnRequest)
 
       
when(clientRequestQuotaManager.maybeRecordAndGetThrottleTimeMs(any[RequestChannel.Request](),
         any[Long])).thenReturn(0)
-      createKafkaApis().handleAddPartitionToTxnRequest(request, 
RequestLocal.withThreadConfinedCaching)
+      createKafkaApis().handleAddPartitionsToTxnRequest(request, 
RequestLocal.withThreadConfinedCaching)
 
       val response = verifyNoThrottling[AddPartitionsToTxnResponse](request)
-      assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, 
response.errors().get(invalidTopicPartition))
+      assertEquals(Errors.UNKNOWN_TOPIC_OR_PARTITION, 
response.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID).get(invalidTopicPartition))
     }
 
     checkInvalidPartition(-1)
@@ -2177,13 +2253,13 @@ class KafkaApisTest {
   @Test
   def 
shouldThrowUnsupportedVersionExceptionOnHandleAddPartitionsToTxnRequestWhenInterBrokerProtocolNotSupported():
 Unit = {
     assertThrows(classOf[UnsupportedVersionException],
-      () => 
createKafkaApis(IBP_0_10_2_IV0).handleAddPartitionToTxnRequest(null, 
RequestLocal.withThreadConfinedCaching))
+      () => 
createKafkaApis(IBP_0_10_2_IV0).handleAddPartitionsToTxnRequest(null, 
RequestLocal.withThreadConfinedCaching))
   }
 
   @Test
   def 
shouldThrowUnsupportedVersionExceptionOnHandleTxnOffsetCommitRequestWhenInterBrokerProtocolNotSupported():
 Unit = {
     assertThrows(classOf[UnsupportedVersionException],
-      () => 
createKafkaApis(IBP_0_10_2_IV0).handleAddPartitionToTxnRequest(null, 
RequestLocal.withThreadConfinedCaching))
+      () => 
createKafkaApis(IBP_0_10_2_IV0).handleAddPartitionsToTxnRequest(null, 
RequestLocal.withThreadConfinedCaching))
   }
 
   @Test
diff --git a/core/src/test/scala/unit/kafka/server/RequestQuotaTest.scala 
b/core/src/test/scala/unit/kafka/server/RequestQuotaTest.scala
index 4d638e138cd..9bb6e3cee9a 100644
--- a/core/src/test/scala/unit/kafka/server/RequestQuotaTest.scala
+++ b/core/src/test/scala/unit/kafka/server/RequestQuotaTest.scala
@@ -427,7 +427,7 @@ class RequestQuotaTest extends BaseRequestTest {
           OffsetsForLeaderEpochRequest.Builder.forConsumer(epochs)
 
         case ApiKeys.ADD_PARTITIONS_TO_TXN =>
-          new AddPartitionsToTxnRequest.Builder("test-transactional-id", 1, 0, 
List(tp).asJava)
+          AddPartitionsToTxnRequest.Builder.forClient("test-transactional-id", 
1, 0, List(tp).asJava)
 
         case ApiKeys.ADD_OFFSETS_TO_TXN =>
           new AddOffsetsToTxnRequest.Builder(new AddOffsetsToTxnRequestData()
diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala 
b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
index dae338a3b08..fd463420d17 100755
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
@@ -374,7 +374,6 @@ object TestUtils extends Logging {
       props.put(KafkaConfig.RackProp, nodeId.toString)
       props.put(KafkaConfig.ReplicaSelectorClassProp, 
"org.apache.kafka.common.replica.RackAwareReplicaSelector")
     }
-
     props
   }
 

Reply via email to