dajac commented on code in PR #13231:
URL: https://github.com/apache/kafka/pull/13231#discussion_r1124792127
##########
clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java:
##########
@@ -34,22 +43,38 @@ public class AddPartitionsToTxnRequest extends
AbstractRequest {
private final AddPartitionsToTxnRequestData data;
- private List<TopicPartition> cachedPartitions = null;
-
public static class Builder extends
AbstractRequest.Builder<AddPartitionsToTxnRequest> {
public final AddPartitionsToTxnRequestData data;
+
+ public static Builder forClient(String transactionalId,
+ long producerId,
+ short producerEpoch,
+ List<TopicPartition> partitions) {
+
+ AddPartitionsToTxnTopicCollection topics =
buildTxnTopicCollection(partitions);
+
+ return new Builder(ApiKeys.ADD_PARTITIONS_TO_TXN.oldestVersion(),
+ (short) 3,
Review Comment:
nit: Should we put `(short) 3` on the previous line to be consistent with
how you did it at L66?
##########
clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java:
##########
@@ -34,22 +43,38 @@ public class AddPartitionsToTxnRequest extends
AbstractRequest {
private final AddPartitionsToTxnRequestData data;
- private List<TopicPartition> cachedPartitions = null;
-
public static class Builder extends
AbstractRequest.Builder<AddPartitionsToTxnRequest> {
public final AddPartitionsToTxnRequestData data;
+
+ public static Builder forClient(String transactionalId,
+ long producerId,
+ short producerEpoch,
+ List<TopicPartition> partitions) {
+
+ AddPartitionsToTxnTopicCollection topics =
buildTxnTopicCollection(partitions);
+
+ return new Builder(ApiKeys.ADD_PARTITIONS_TO_TXN.oldestVersion(),
+ (short) 3,
+ new AddPartitionsToTxnRequestData()
+ .setV3AndBelowTransactionalId(transactionalId)
+ .setV3AndBelowProducerId(producerId)
+ .setV3AndBelowProducerEpoch(producerEpoch)
+ .setV3AndBelowTopics(topics));
+ }
+
+ public static Builder
forBroker(AddPartitionsToTxnTransactionCollection transactions) {
+ return new Builder((short) 4,
ApiKeys.ADD_PARTITIONS_TO_TXN.latestVersion(),
+ new AddPartitionsToTxnRequestData()
+ .setTransactions(transactions));
+ }
+
+ public Builder(short minVersion, short maxVersion,
AddPartitionsToTxnRequestData data) {
Review Comment:
nit: Do we still use this constructor anywhere? It may be good to make it
private or package-private to ensure that `forClient` or `forBroker` is used.
##########
clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponse.java:
##########
@@ -48,29 +51,51 @@ public class AddPartitionsToTxnResponse extends
AbstractResponse {
private final AddPartitionsToTxnResponseData data;
- private Map<TopicPartition, Errors> cachedErrorsMap = null;
+ public static final String V3_AND_BELOW_TXN_ID = "";
public AddPartitionsToTxnResponse(AddPartitionsToTxnResponseData data) {
super(ApiKeys.ADD_PARTITIONS_TO_TXN);
this.data = data;
}
- public AddPartitionsToTxnResponse(int throttleTimeMs, Map<TopicPartition,
Errors> errors) {
- super(ApiKeys.ADD_PARTITIONS_TO_TXN);
+ @Override
+ public int throttleTimeMs() {
+ return data.throttleTimeMs();
+ }
+
+ @Override
+ public void maybeSetThrottleTimeMs(int throttleTimeMs) {
+ data.setThrottleTimeMs(throttleTimeMs);
+ }
+
+ public Map<String, Map<TopicPartition, Errors>> errors() {
+ Map<String, Map<TopicPartition, Errors>> errorsMap = new HashMap<>();
+
+ if (this.data.resultsByTopicV3AndBelow().size() != 0) {
Review Comment:
nit: I think that we usually prefer using `isEmpty()`.
##########
clients/src/test/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponseTest.java:
##########
@@ -58,42 +65,75 @@ public void setUp() {
errorsMap.put(tp2, errorTwo);
}
- @Test
- public void testConstructorWithErrorResponse() {
- AddPartitionsToTxnResponse response = new
AddPartitionsToTxnResponse(throttleTimeMs, errorsMap);
-
- assertEquals(expectedErrorCounts, response.errorCounts());
- assertEquals(throttleTimeMs, response.throttleTimeMs());
- }
-
- @Test
- public void testParse() {
-
+ @ParameterizedTest
+ @ApiKeyVersionsSource(apiKey = ApiKeys.ADD_PARTITIONS_TO_TXN)
+ public void testParse(short version) {
AddPartitionsToTxnTopicResultCollection topicCollection = new
AddPartitionsToTxnTopicResultCollection();
AddPartitionsToTxnTopicResult topicResult = new
AddPartitionsToTxnTopicResult();
topicResult.setName(topicOne);
- topicResult.results().add(new AddPartitionsToTxnPartitionResult()
- .setErrorCode(errorOne.code())
+ topicResult.resultsByPartition().add(new
AddPartitionsToTxnPartitionResult()
+ .setPartitionErrorCode(errorOne.code())
.setPartitionIndex(partitionOne));
- topicResult.results().add(new AddPartitionsToTxnPartitionResult()
- .setErrorCode(errorTwo.code())
+ topicResult.resultsByPartition().add(new
AddPartitionsToTxnPartitionResult()
+ .setPartitionErrorCode(errorTwo.code())
.setPartitionIndex(partitionTwo));
topicCollection.add(topicResult);
+
+ if (version < 4) {
+ AddPartitionsToTxnResponseData data = new
AddPartitionsToTxnResponseData()
+ .setResultsByTopicV3AndBelow(topicCollection)
+ .setThrottleTimeMs(throttleTimeMs);
+ AddPartitionsToTxnResponse response = new
AddPartitionsToTxnResponse(data);
- AddPartitionsToTxnResponseData data = new
AddPartitionsToTxnResponseData()
- .setResults(topicCollection)
-
.setThrottleTimeMs(throttleTimeMs);
- AddPartitionsToTxnResponse response = new
AddPartitionsToTxnResponse(data);
-
- for (short version : ApiKeys.ADD_PARTITIONS_TO_TXN.allVersions()) {
AddPartitionsToTxnResponse parsedResponse =
AddPartitionsToTxnResponse.parse(response.serialize(version), version);
assertEquals(expectedErrorCounts, parsedResponse.errorCounts());
assertEquals(throttleTimeMs, parsedResponse.throttleTimeMs());
assertEquals(version >= 1,
parsedResponse.shouldClientThrottle(version));
+ } else {
+ AddPartitionsToTxnResultCollection results = new
AddPartitionsToTxnResultCollection();
+ results.add(new
AddPartitionsToTxnResult().setTransactionalId("txn1").setTopicResults(topicCollection));
+
+ // Create another transaction with new name and errorOne for a
single partition.
+ Map<TopicPartition, Errors> txnTwoExpectedErrors =
Collections.singletonMap(tp2, errorOne);
+
results.add(AddPartitionsToTxnResponse.resultForTransaction("txn2",
txnTwoExpectedErrors));
+
+ AddPartitionsToTxnResponseData data = new
AddPartitionsToTxnResponseData()
+ .setResultsByTransaction(results)
+ .setThrottleTimeMs(throttleTimeMs);
+ AddPartitionsToTxnResponse response = new
AddPartitionsToTxnResponse(data);
+
+ Map<Errors, Integer> newExpectedErrorCounts = new HashMap<>();
+ newExpectedErrorCounts.put(Errors.NONE, 1); // top level error
+ newExpectedErrorCounts.put(errorOne, 2);
+ newExpectedErrorCounts.put(errorTwo, 1);
+
+ AddPartitionsToTxnResponse parsedResponse =
AddPartitionsToTxnResponse.parse(response.serialize(version), version);
+ assertEquals(txnTwoExpectedErrors,
errorsForTransaction(response.getTransactionTopicResults("txn2")));
+ assertEquals(newExpectedErrorCounts, parsedResponse.errorCounts());
+ assertEquals(throttleTimeMs, parsedResponse.throttleTimeMs());
+ assertTrue(parsedResponse.shouldClientThrottle(version));
}
}
+
+ @Test
+ public void testBatchedErrors() {
+ Map<TopicPartition, Errors> txn1Errors = Collections.singletonMap(tp1,
errorOne);
+ Map<TopicPartition, Errors> txn2Errors = Collections.singletonMap(tp1,
errorOne);
+
+ AddPartitionsToTxnResult transaction1 =
AddPartitionsToTxnResponse.resultForTransaction("txn1", txn1Errors);
+ AddPartitionsToTxnResult transaction2 =
AddPartitionsToTxnResponse.resultForTransaction("txn2", txn2Errors);
+
+ AddPartitionsToTxnResultCollection results = new
AddPartitionsToTxnResultCollection();
+ results.add(transaction1);
+ results.add(transaction2);
+
+ AddPartitionsToTxnResponse response = new
AddPartitionsToTxnResponse(new
AddPartitionsToTxnResponseData().setResultsByTransaction(results));
+
+ assertEquals(txn1Errors,
errorsForTransaction(response.getTransactionTopicResults("txn1")));
+ assertEquals(txn2Errors,
errorsForTransaction(response.getTransactionTopicResults("txn2")));
+ }
Review Comment:
nit: Should we add a test for `errors()`?
##########
core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala:
##########
@@ -317,6 +322,34 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
}
}
}
+
+ def handleVerifyPartitionsInTransaction(transactionalId: String,
+ producerId: Long,
+ producerEpoch: Short,
+ partitions:
collection.Set[TopicPartition],
+ responseCallback:
VerifyPartitionsCallback): Unit = {
+ if (transactionalId == null || transactionalId.isEmpty) {
+ debug(s"Returning ${Errors.INVALID_REQUEST} error code to client for
$transactionalId's AddPartitions request")
+
responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId,
partitions.map(_ -> Errors.INVALID_REQUEST).toMap.asJava))
+ } else {
+ val result: ApiResult[(Int, TransactionMetadata)] =
getTransactionMetadata(transactionalId, producerId, producerEpoch, partitions)
+
+ result match {
+ case Left(err) =>
+ debug(s"Returning $err error code to client for $transactionalId's
AddPartitions request")
Review Comment:
nit: Should we update this log line to mention that we are only validating here?
##########
clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponse.java:
##########
@@ -80,45 +105,44 @@ topicName, new
AddPartitionsToTxnPartitionResultCollection()
AddPartitionsToTxnTopicResultCollection topicCollection = new
AddPartitionsToTxnTopicResultCollection();
for (Map.Entry<String, AddPartitionsToTxnPartitionResultCollection>
entry : resultMap.entrySet()) {
topicCollection.add(new AddPartitionsToTxnTopicResult()
- .setName(entry.getKey())
- .setResults(entry.getValue()));
+ .setName(entry.getKey())
+ .setResultsByPartition(entry.getValue()));
}
-
- this.data = new AddPartitionsToTxnResponseData()
- .setThrottleTimeMs(throttleTimeMs)
- .setResults(topicCollection);
+ return topicCollection;
}
- @Override
- public int throttleTimeMs() {
- return data.throttleTimeMs();
+ public static AddPartitionsToTxnResult resultForTransaction(String
transactionalId, Map<TopicPartition, Errors> errors) {
+ return new
AddPartitionsToTxnResult().setTransactionalId(transactionalId).setTopicResults(topicCollectionForErrors(errors));
}
- @Override
- public void maybeSetThrottleTimeMs(int throttleTimeMs) {
- data.setThrottleTimeMs(throttleTimeMs);
+ public AddPartitionsToTxnTopicResultCollection
getTransactionTopicResults(String transactionalId) {
+ return
data.resultsByTransaction().find(transactionalId).topicResults();
}
- public Map<TopicPartition, Errors> errors() {
- if (cachedErrorsMap != null) {
- return cachedErrorsMap;
- }
-
- cachedErrorsMap = new HashMap<>();
-
- for (AddPartitionsToTxnTopicResult topicResult : this.data.results()) {
- for (AddPartitionsToTxnPartitionResult partitionResult :
topicResult.results()) {
- cachedErrorsMap.put(new TopicPartition(
- topicResult.name(), partitionResult.partitionIndex()),
- Errors.forCode(partitionResult.errorCode()));
+ public static Map<TopicPartition, Errors>
errorsForTransaction(AddPartitionsToTxnTopicResultCollection topicCollection) {
+ Map<TopicPartition, Errors> topicResults = new HashMap<>();
+ for (AddPartitionsToTxnTopicResult topicResult : topicCollection) {
+ for (AddPartitionsToTxnPartitionResult partitionResult :
topicResult.resultsByPartition()) {
+ topicResults.put(
+ new TopicPartition(topicResult.name(),
partitionResult.partitionIndex()),
Errors.forCode(partitionResult.partitionErrorCode()));
}
}
- return cachedErrorsMap;
+ return topicResults;
}
@Override
public Map<Errors, Integer> errorCounts() {
- return errorCounts(errors().values());
+ List<Errors> allErrors = new ArrayList<>();
+
+ // If we are not using this field, we have request 4 or later
+ if (this.data.resultsByTopicV3AndBelow().size() == 0) {
+ allErrors.add(Errors.forCode(data.errorCode()));
Review Comment:
nit: Should we use `updateErrorCounts` from `AbstractResponse` instead of
creating `allErrors`?
##########
core/src/main/scala/kafka/server/KafkaApis.scala:
##########
@@ -2384,66 +2386,111 @@ class KafkaApis(val requestChannel: RequestChannel,
if (config.interBrokerProtocolVersion.isLessThan(version))
throw new UnsupportedVersionException(s"inter.broker.protocol.version:
${config.interBrokerProtocolVersion.version} is less than the required version:
${version.version}")
}
-
- def handleAddPartitionToTxnRequest(request: RequestChannel.Request,
requestLocal: RequestLocal): Unit = {
+ def handleAddPartitionsToTxnRequest(request: RequestChannel.Request,
requestLocal: RequestLocal): Unit = {
ensureInterBrokerVersion(IBP_0_11_0_IV0)
- val addPartitionsToTxnRequest = request.body[AddPartitionsToTxnRequest]
- val transactionalId = addPartitionsToTxnRequest.data.transactionalId
- val partitionsToAdd = addPartitionsToTxnRequest.partitions.asScala
- if (!authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID,
transactionalId))
- requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs =>
- addPartitionsToTxnRequest.getErrorResponse(requestThrottleMs,
Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception))
- else {
- val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]()
- val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]()
- val authorizedPartitions = mutable.Set[TopicPartition]()
-
- val authorizedTopics = authHelper.filterByAuthorized(request.context,
WRITE, TOPIC,
- partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic)
- for (topicPartition <- partitionsToAdd) {
- if (!authorizedTopics.contains(topicPartition.topic))
- unauthorizedTopicErrors += topicPartition ->
Errors.TOPIC_AUTHORIZATION_FAILED
- else if (!metadataCache.contains(topicPartition))
- nonExistingTopicErrors += topicPartition ->
Errors.UNKNOWN_TOPIC_OR_PARTITION
- else
- authorizedPartitions.add(topicPartition)
+ val addPartitionsToTxnRequest =
+ if (request.context.apiVersion() < 4)
+ request.body[AddPartitionsToTxnRequest].normalizeRequest()
+ else
+ request.body[AddPartitionsToTxnRequest]
+ val version = addPartitionsToTxnRequest.version
+ val responses = new AddPartitionsToTxnResultCollection()
+ val partitionsByTransaction =
addPartitionsToTxnRequest.partitionsByTransaction()
+
+ // Newer versions of the request should only come from other brokers.
+ if (version >= 4) authHelper.authorizeClusterOperation(request,
CLUSTER_ACTION)
+
+ // V4 requests introduced batches of transactions. We need all
transactions to be handled before sending the
+ // response so there are a few differences in handling errors and sending
responses.
+ def createResponse(requestThrottleMs: Int): AbstractResponse = {
+ if (version < 4) {
+ // There will only be one response in data. Add it to the response
data object.
+ val data = new AddPartitionsToTxnResponseData()
+ responses.forEach { result =>
+ data.setResultsByTopicV3AndBelow(result.topicResults())
+ data.setThrottleTimeMs(requestThrottleMs)
+ }
+ new AddPartitionsToTxnResponse(data)
+ } else {
+ new AddPartitionsToTxnResponse(new
AddPartitionsToTxnResponseData().setThrottleTimeMs(requestThrottleMs).setResultsByTransaction(responses))
}
+ }
- if (unauthorizedTopicErrors.nonEmpty || nonExistingTopicErrors.nonEmpty)
{
- // Any failed partition check causes the entire request to fail. We
send the appropriate error codes for the
- // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error
code for the partitions which succeeded
- // the authorization check to indicate that they were not added to the
transaction.
- val partitionErrors = unauthorizedTopicErrors ++
nonExistingTopicErrors ++
- authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED)
- requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs =>
- new AddPartitionsToTxnResponse(requestThrottleMs,
partitionErrors.asJava))
+ val txns = addPartitionsToTxnRequest.data.transactions
+ def addResultAndMaybeSendResponse(result: AddPartitionsToTxnResult): Unit
= {
+ val canSend = responses.synchronized {
+ responses.add(result)
+ responses.size() == txns.size()
+ }
+ if (canSend) {
+ requestHelper.sendResponseMaybeThrottle(request, createResponse)
+ }
+ }
+
+ txns.forEach { transaction =>
+ val transactionalId = transaction.transactionalId
+ val partitionsToAdd =
partitionsByTransaction.get(transactionalId).asScala
+
+ // Versions < 4 come from clients and must be authorized to write for
the given transaction and for the given topics.
+ if (version < 4 && !authHelper.authorize(request.context, WRITE,
TRANSACTIONAL_ID, transactionalId)) {
+
addResultAndMaybeSendResponse(addPartitionsToTxnRequest.errorResponseForTransaction(transactionalId,
Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED))
} else {
- def sendResponseCallback(error: Errors): Unit = {
- def createResponse(requestThrottleMs: Int): AbstractResponse = {
- val finalError =
- if (addPartitionsToTxnRequest.version < 2 && error ==
Errors.PRODUCER_FENCED) {
+ val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]()
+ val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]()
+ val authorizedPartitions = mutable.Set[TopicPartition]()
+
+ // Only request versions less than 4 need write authorization since
they come from clients.
+ val authorizedTopics =
+ if (version < 4)
+ authHelper.filterByAuthorized(request.context, WRITE, TOPIC,
partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic)
+ else
+ partitionsToAdd.map(_.topic).toSet
+ for (topicPartition <- partitionsToAdd) {
+ if (!authorizedTopics.contains(topicPartition.topic))
+ unauthorizedTopicErrors += topicPartition ->
Errors.TOPIC_AUTHORIZATION_FAILED
+ else if (!metadataCache.contains(topicPartition))
+ nonExistingTopicErrors += topicPartition ->
Errors.UNKNOWN_TOPIC_OR_PARTITION
+ else
+ authorizedPartitions.add(topicPartition)
+ }
+
+ if (unauthorizedTopicErrors.nonEmpty ||
nonExistingTopicErrors.nonEmpty) {
+ // Any failed partition check causes the entire transaction to fail.
We send the appropriate error codes for the
+ // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error
code for the partitions which succeeded
+ // the authorization check to indicate that they were not added to
the transaction.
+ val partitionErrors = unauthorizedTopicErrors ++
nonExistingTopicErrors ++
+ authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED)
+
addResultAndMaybeSendResponse(AddPartitionsToTxnResponse.resultForTransaction(transactionalId,
partitionErrors.asJava))
+ } else {
+ def sendResponseCallback(error: Errors): Unit = {
+ val finalError = {
+ if (version < 2 && error == Errors.PRODUCER_FENCED) {
// For older clients, they could not understand the new
PRODUCER_FENCED error code,
// so we need to return the old INVALID_PRODUCER_EPOCH to have
the same client handling logic.
Errors.INVALID_PRODUCER_EPOCH
} else {
error
}
-
- val responseBody: AddPartitionsToTxnResponse = new
AddPartitionsToTxnResponse(requestThrottleMs,
- partitionsToAdd.map{tp => (tp, finalError)}.toMap.asJava)
- trace(s"Completed $transactionalId's AddPartitionsToTxnRequest
with partitions $partitionsToAdd: errors: $error from client
${request.header.clientId}")
- responseBody
+ }
+
addResultAndMaybeSendResponse(addPartitionsToTxnRequest.errorResponseForTransaction(transactionalId,
finalError))
}
- requestHelper.sendResponseMaybeThrottle(request, createResponse)
- }
- txnCoordinator.handleAddPartitionsToTransaction(transactionalId,
- addPartitionsToTxnRequest.data.producerId,
- addPartitionsToTxnRequest.data.producerEpoch,
- authorizedPartitions,
- sendResponseCallback,
- requestLocal)
+ if (!transaction.verifyOnly) {
+ txnCoordinator.handleAddPartitionsToTransaction(transactionalId,
+ transaction.producerId,
+ transaction.producerEpoch,
+ authorizedPartitions,
+ sendResponseCallback,
+ requestLocal)
+ } else {
+ txnCoordinator.handleVerifyPartitionsInTransaction(transactionalId,
+ transaction.producerId,
+ transaction.producerEpoch,
+ authorizedPartitions,
+ addResultAndMaybeSendResponse)
+ }
+ }
Review Comment:
👍🏻
##########
core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala:
##########
@@ -314,6 +316,32 @@ class TransactionCoordinatorTest {
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
}
+ @Test
+ def
shouldRespondWithErrorsNoneOnAddPartitionWhenOngoingVerifyOnlyAndPartitionsTheSame():
Unit = {
+
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
+ .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
+ new TransactionMetadata(transactionalId, 0, 0, 0,
RecordBatch.NO_PRODUCER_EPOCH, 0, Ongoing, partitions, 0, 0)))))
+
+ coordinator.handleVerifyPartitionsInTransaction(transactionalId, 0L, 0,
partitions, verifyPartitionsInTxnCallback)
+ assertEquals(Errors.NONE, error)
+
verify(transactionManager).getTransactionState(ArgumentMatchers.eq(transactionalId))
+ }
+
+ @Test
+ def shouldRespondWithInvalidTxnStateWhenVerifyOnlyAndPartitionNotPresent():
Unit = {
+
when(transactionManager.getTransactionState(ArgumentMatchers.eq(transactionalId)))
+ .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch,
+ new TransactionMetadata(transactionalId, 0, 0, 0,
RecordBatch.NO_PRODUCER_EPOCH, 0, Empty, partitions, 0, 0)))))
+
Review Comment:
nit: Extra empty line.
##########
clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnRequest.java:
##########
@@ -103,26 +113,82 @@ public AddPartitionsToTxnRequest(final
AddPartitionsToTxnRequestData data, short
this.data = data;
}
- public List<TopicPartition> partitions() {
- if (cachedPartitions != null) {
- return cachedPartitions;
- }
- cachedPartitions = Builder.getPartitions(data);
- return cachedPartitions;
- }
-
@Override
public AddPartitionsToTxnRequestData data() {
return data;
}
@Override
public AddPartitionsToTxnResponse getErrorResponse(int throttleTimeMs,
Throwable e) {
- final HashMap<TopicPartition, Errors> errors = new HashMap<>();
- for (TopicPartition partition : partitions()) {
- errors.put(partition, Errors.forException(e));
+ Errors error = Errors.forException(e);
+ AddPartitionsToTxnResponseData response = new
AddPartitionsToTxnResponseData();
+ if (version() < 4) {
+
response.setResultsByTopicV3AndBelow(errorResponseForTopics(data.v3AndBelowTopics(),
error));
+ } else {
+ AddPartitionsToTxnResultCollection results = new
AddPartitionsToTxnResultCollection();
+ response.setResultsByTransaction(results);
Review Comment:
nit: I think that you can remove those two lines. The collection should be
initialized by default.
##########
core/src/test/scala/unit/kafka/server/KafkaApisTest.scala:
##########
@@ -2030,13 +2032,93 @@ class KafkaApisTest {
val response = capturedResponse.getValue
if (version < 2) {
- assertEquals(Collections.singletonMap(topicPartition,
Errors.INVALID_PRODUCER_EPOCH), response.errors())
+ assertEquals(Collections.singletonMap(topicPartition,
Errors.INVALID_PRODUCER_EPOCH),
response.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID))
} else {
- assertEquals(Collections.singletonMap(topicPartition,
Errors.PRODUCER_FENCED), response.errors())
+ assertEquals(Collections.singletonMap(topicPartition,
Errors.PRODUCER_FENCED),
response.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID))
}
}
}
+ @Test
+ def testBatchedRequest(): Unit = {
+ val topic = "topic"
+ addTopicToMetadataCache(topic, numPartitions = 2)
+
+ val capturedResponse: ArgumentCaptor[AddPartitionsToTxnResponse] =
ArgumentCaptor.forClass(classOf[AddPartitionsToTxnResponse])
+ val responseCallback: ArgumentCaptor[Errors => Unit] =
ArgumentCaptor.forClass(classOf[Errors => Unit])
+ val verifyPartitionsCallback: ArgumentCaptor[AddPartitionsToTxnResult =>
Unit] = ArgumentCaptor.forClass(classOf[AddPartitionsToTxnResult => Unit])
+
+ val transactionalId1 = "txnId1"
+ val transactionalId2 = "txnId2"
+ val producerId = 15L
+ val epoch = 0.toShort
+
+ val tp0 = new TopicPartition(topic, 0)
+ val tp1 = new TopicPartition(topic, 1)
+
+ val addPartitionsToTxnRequest =
AddPartitionsToTxnRequest.Builder.forBroker(
+ new AddPartitionsToTxnTransactionCollection(
+ List(new AddPartitionsToTxnTransaction()
+ .setTransactionalId(transactionalId1)
+ .setProducerId(producerId)
+ .setProducerEpoch(epoch)
+ .setVerifyOnly(false)
+ .setTopics(new AddPartitionsToTxnTopicCollection(
+ Collections.singletonList(new AddPartitionsToTxnTopic()
+ .setName(tp0.topic)
+ .setPartitions(Collections.singletonList(tp0.partition))
+ ).iterator())
+ ), new AddPartitionsToTxnTransaction()
+ .setTransactionalId(transactionalId2)
+ .setProducerId(producerId)
+ .setProducerEpoch(epoch)
+ .setVerifyOnly(true)
+ .setTopics(new AddPartitionsToTxnTopicCollection(
+ Collections.singletonList(new AddPartitionsToTxnTopic()
+ .setName(tp1.topic)
+ .setPartitions(Collections.singletonList(tp1.partition))
+ ).iterator())
+ )
+ ).asJava.iterator()
+ )
+ ).build(4.toShort)
+ val request = buildRequest(addPartitionsToTxnRequest)
+
+ val requestLocal = RequestLocal.withThreadConfinedCaching
+ when(txnCoordinator.handleAddPartitionsToTransaction(
+ ArgumentMatchers.eq(transactionalId1),
+ ArgumentMatchers.eq(producerId),
+ ArgumentMatchers.eq(epoch),
+ ArgumentMatchers.eq(Set(tp0)),
+ responseCallback.capture(),
+ ArgumentMatchers.eq(requestLocal)
+ )).thenAnswer(_ => responseCallback.getValue.apply(Errors.NONE))
+
+ when(txnCoordinator.handleVerifyPartitionsInTransaction(
+ ArgumentMatchers.eq(transactionalId2),
+ ArgumentMatchers.eq(producerId),
+ ArgumentMatchers.eq(epoch),
+ ArgumentMatchers.eq(Set(tp1)),
+ verifyPartitionsCallback.capture(),
+ )).thenAnswer(_ =>
verifyPartitionsCallback.getValue.apply(AddPartitionsToTxnResponse.resultForTransaction(transactionalId2,
Map(tp1 -> Errors.PRODUCER_FENCED).asJava)))
+
+ createKafkaApis().handleAddPartitionsToTxnRequest(request, requestLocal)
+
+ verify(requestChannel).sendResponse(
+ ArgumentMatchers.eq(request),
+ capturedResponse.capture(),
+ ArgumentMatchers.eq(None)
+ )
+ val response = capturedResponse.getValue
Review Comment:
nit: You can use `verifyNoThrottling`.
##########
clients/src/main/java/org/apache/kafka/common/requests/AddPartitionsToTxnResponse.java:
##########
@@ -80,45 +105,44 @@ topicName, new
AddPartitionsToTxnPartitionResultCollection()
AddPartitionsToTxnTopicResultCollection topicCollection = new
AddPartitionsToTxnTopicResultCollection();
for (Map.Entry<String, AddPartitionsToTxnPartitionResultCollection>
entry : resultMap.entrySet()) {
topicCollection.add(new AddPartitionsToTxnTopicResult()
- .setName(entry.getKey())
- .setResults(entry.getValue()));
+ .setName(entry.getKey())
+ .setResultsByPartition(entry.getValue()));
}
-
- this.data = new AddPartitionsToTxnResponseData()
- .setThrottleTimeMs(throttleTimeMs)
- .setResults(topicCollection);
+ return topicCollection;
}
- @Override
- public int throttleTimeMs() {
- return data.throttleTimeMs();
+ public static AddPartitionsToTxnResult resultForTransaction(String
transactionalId, Map<TopicPartition, Errors> errors) {
+ return new
AddPartitionsToTxnResult().setTransactionalId(transactionalId).setTopicResults(topicCollectionForErrors(errors));
}
- @Override
- public void maybeSetThrottleTimeMs(int throttleTimeMs) {
- data.setThrottleTimeMs(throttleTimeMs);
+ public AddPartitionsToTxnTopicResultCollection
getTransactionTopicResults(String transactionalId) {
+ return
data.resultsByTransaction().find(transactionalId).topicResults();
}
- public Map<TopicPartition, Errors> errors() {
- if (cachedErrorsMap != null) {
- return cachedErrorsMap;
- }
-
- cachedErrorsMap = new HashMap<>();
-
- for (AddPartitionsToTxnTopicResult topicResult : this.data.results()) {
- for (AddPartitionsToTxnPartitionResult partitionResult :
topicResult.results()) {
- cachedErrorsMap.put(new TopicPartition(
- topicResult.name(), partitionResult.partitionIndex()),
- Errors.forCode(partitionResult.errorCode()));
+ public static Map<TopicPartition, Errors>
errorsForTransaction(AddPartitionsToTxnTopicResultCollection topicCollection) {
+ Map<TopicPartition, Errors> topicResults = new HashMap<>();
+ for (AddPartitionsToTxnTopicResult topicResult : topicCollection) {
+ for (AddPartitionsToTxnPartitionResult partitionResult :
topicResult.resultsByPartition()) {
+ topicResults.put(
+ new TopicPartition(topicResult.name(),
partitionResult.partitionIndex()),
Errors.forCode(partitionResult.partitionErrorCode()));
}
}
- return cachedErrorsMap;
+ return topicResults;
}
@Override
public Map<Errors, Integer> errorCounts() {
- return errorCounts(errors().values());
+ List<Errors> allErrors = new ArrayList<>();
+
+ // If we are not using this field, we have request 4 or later
+ if (this.data.resultsByTopicV3AndBelow().size() == 0) {
Review Comment:
nit: `isEmpty()`?
##########
core/src/test/scala/unit/kafka/server/KafkaApisTest.scala:
##########
@@ -2030,13 +2032,93 @@ class KafkaApisTest {
val response = capturedResponse.getValue
if (version < 2) {
- assertEquals(Collections.singletonMap(topicPartition,
Errors.INVALID_PRODUCER_EPOCH), response.errors())
+ assertEquals(Collections.singletonMap(topicPartition,
Errors.INVALID_PRODUCER_EPOCH),
response.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID))
} else {
- assertEquals(Collections.singletonMap(topicPartition,
Errors.PRODUCER_FENCED), response.errors())
+ assertEquals(Collections.singletonMap(topicPartition,
Errors.PRODUCER_FENCED),
response.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID))
}
}
}
+ @Test
+ def testBatchedRequest(): Unit = {
Review Comment:
nit: Could we include the related api in the name?
##########
core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala:
##########
@@ -317,6 +322,34 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
}
}
}
+
+ def handleVerifyPartitionsInTransaction(transactionalId: String,
+ producerId: Long,
+ producerEpoch: Short,
+ partitions:
collection.Set[TopicPartition],
+ responseCallback:
VerifyPartitionsCallback): Unit = {
+ if (transactionalId == null || transactionalId.isEmpty) {
+ debug(s"Returning ${Errors.INVALID_REQUEST} error code to client for
$transactionalId's AddPartitions request")
+
responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId,
partitions.map(_ -> Errors.INVALID_REQUEST).toMap.asJava))
+ } else {
+ val result: ApiResult[(Int, TransactionMetadata)] =
getTransactionMetadata(transactionalId, producerId, producerEpoch, partitions)
+
+ result match {
+ case Left(err) =>
+ debug(s"Returning $err error code to client for $transactionalId's
AddPartitions request")
+
responseCallback(AddPartitionsToTxnResponse.resultForTransaction(transactionalId,
partitions.map(_ -> err).toMap.asJava))
+
+ case Right((_, txnMetadata)) =>
+ val txnMetadataPartitions = txnMetadata.topicPartitions
+ val addedPartitions = partitions.intersect(txnMetadataPartitions)
+ val nonAddedPartitions = partitions.diff(txnMetadataPartitions)
+ val errors = mutable.Map[TopicPartition, Errors]()
+ addedPartitions.foreach(errors.put(_, Errors.NONE))
+ nonAddedPartitions.foreach(errors.put(_, Errors.INVALID_TXN_STATE))
Review Comment:
nit: I am not sure if it makes a real difference, but did you consider doing
something like this:
```
partitions.foreach { tp =>
if (txnMetadata.topicPartitions.contains(tp))
...
else
...
}
```
If it works, it would avoid allocating the intermediate collections. I leave
this up to you.
##########
clients/src/test/java/org/apache/kafka/common/requests/RequestResponseTest.java:
##########
@@ -2598,12 +2604,37 @@ private OffsetsForLeaderEpochResponse
createLeaderEpochResponse() {
}
private AddPartitionsToTxnRequest createAddPartitionsToTxnRequest(short
version) {
- return new AddPartitionsToTxnRequest.Builder("tid", 21L, (short) 42,
- singletonList(new TopicPartition("topic", 73))).build(version);
+ if (version < 4) {
+ return AddPartitionsToTxnRequest.Builder.forClient("tid", 21L,
(short) 42,
+ singletonList(new TopicPartition("topic",
73))).build(version);
+ } else {
+ AddPartitionsToTxnTransactionCollection transactions = new
AddPartitionsToTxnTransactionCollection(
+ singletonList(new AddPartitionsToTxnTransaction()
Review Comment:
nit: Indentation of this line looks weird. Should it be on the previous line?
##########
core/src/main/scala/kafka/server/KafkaApis.scala:
##########
@@ -2384,66 +2386,111 @@ class KafkaApis(val requestChannel: RequestChannel,
if (config.interBrokerProtocolVersion.isLessThan(version))
throw new UnsupportedVersionException(s"inter.broker.protocol.version:
${config.interBrokerProtocolVersion.version} is less than the required version:
${version.version}")
}
-
- def handleAddPartitionToTxnRequest(request: RequestChannel.Request,
requestLocal: RequestLocal): Unit = {
+ def handleAddPartitionsToTxnRequest(request: RequestChannel.Request,
requestLocal: RequestLocal): Unit = {
ensureInterBrokerVersion(IBP_0_11_0_IV0)
- val addPartitionsToTxnRequest = request.body[AddPartitionsToTxnRequest]
- val transactionalId = addPartitionsToTxnRequest.data.transactionalId
- val partitionsToAdd = addPartitionsToTxnRequest.partitions.asScala
- if (!authHelper.authorize(request.context, WRITE, TRANSACTIONAL_ID,
transactionalId))
- requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs =>
- addPartitionsToTxnRequest.getErrorResponse(requestThrottleMs,
Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception))
- else {
- val unauthorizedTopicErrors = mutable.Map[TopicPartition, Errors]()
- val nonExistingTopicErrors = mutable.Map[TopicPartition, Errors]()
- val authorizedPartitions = mutable.Set[TopicPartition]()
-
- val authorizedTopics = authHelper.filterByAuthorized(request.context,
WRITE, TOPIC,
- partitionsToAdd.filterNot(tp => Topic.isInternal(tp.topic)))(_.topic)
- for (topicPartition <- partitionsToAdd) {
- if (!authorizedTopics.contains(topicPartition.topic))
- unauthorizedTopicErrors += topicPartition ->
Errors.TOPIC_AUTHORIZATION_FAILED
- else if (!metadataCache.contains(topicPartition))
- nonExistingTopicErrors += topicPartition ->
Errors.UNKNOWN_TOPIC_OR_PARTITION
- else
- authorizedPartitions.add(topicPartition)
+ val addPartitionsToTxnRequest =
+ if (request.context.apiVersion() < 4)
+ request.body[AddPartitionsToTxnRequest].normalizeRequest()
+ else
+ request.body[AddPartitionsToTxnRequest]
+ val version = addPartitionsToTxnRequest.version
+ val responses = new AddPartitionsToTxnResultCollection()
+ val partitionsByTransaction =
addPartitionsToTxnRequest.partitionsByTransaction()
+
+ // Newer versions of the request should only come from other brokers.
+ if (version >= 4) authHelper.authorizeClusterOperation(request,
CLUSTER_ACTION)
+
+ // V4 requests introduced batches of transactions. We need all
transactions to be handled before sending the
+ // response so there are a few differences in handling errors and sending
responses.
+ def createResponse(requestThrottleMs: Int): AbstractResponse = {
+ if (version < 4) {
+ // There will only be one response in data. Add it to the response
data object.
+ val data = new AddPartitionsToTxnResponseData()
+ responses.forEach { result =>
+ data.setResultsByTopicV3AndBelow(result.topicResults())
+ data.setThrottleTimeMs(requestThrottleMs)
+ }
+ new AddPartitionsToTxnResponse(data)
+ } else {
+ new AddPartitionsToTxnResponse(new
AddPartitionsToTxnResponseData().setThrottleTimeMs(requestThrottleMs).setResultsByTransaction(responses))
}
+ }
- if (unauthorizedTopicErrors.nonEmpty || nonExistingTopicErrors.nonEmpty)
{
- // Any failed partition check causes the entire request to fail. We
send the appropriate error codes for the
- // partitions which failed, and an 'OPERATION_NOT_ATTEMPTED' error
code for the partitions which succeeded
- // the authorization check to indicate that they were not added to the
transaction.
- val partitionErrors = unauthorizedTopicErrors ++
nonExistingTopicErrors ++
- authorizedPartitions.map(_ -> Errors.OPERATION_NOT_ATTEMPTED)
- requestHelper.sendResponseMaybeThrottle(request, requestThrottleMs =>
- new AddPartitionsToTxnResponse(requestThrottleMs,
partitionErrors.asJava))
+ val txns = addPartitionsToTxnRequest.data.transactions
+ def addResultAndMaybeSendResponse(result: AddPartitionsToTxnResult): Unit
= {
+ val canSend = responses.synchronized {
+ responses.add(result)
+ responses.size() == txns.size()
Review Comment:
nit: You can remove the `()`.
##########
core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala:
##########
@@ -330,44 +363,53 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
} else {
// try to update the transaction metadata and append the updated
metadata to txn log;
// if there is no such metadata treat it as invalid producerId mapping
error.
- val result: ApiResult[(Int, TxnTransitMetadata)] =
txnManager.getTransactionState(transactionalId).flatMap {
- case None => Left(Errors.INVALID_PRODUCER_ID_MAPPING)
-
- case Some(epochAndMetadata) =>
- val coordinatorEpoch = epochAndMetadata.coordinatorEpoch
- val txnMetadata = epochAndMetadata.transactionMetadata
-
- // generate the new transaction metadata with added partitions
- txnMetadata.inLock {
- if (txnMetadata.producerId != producerId) {
- Left(Errors.INVALID_PRODUCER_ID_MAPPING)
- } else if (txnMetadata.producerEpoch != producerEpoch) {
- Left(Errors.PRODUCER_FENCED)
- } else if (txnMetadata.pendingTransitionInProgress) {
- // return a retriable exception to let the client backoff and
retry
- Left(Errors.CONCURRENT_TRANSACTIONS)
- } else if (txnMetadata.state == PrepareCommit || txnMetadata.state
== PrepareAbort) {
- Left(Errors.CONCURRENT_TRANSACTIONS)
- } else if (txnMetadata.state == Ongoing &&
partitions.subsetOf(txnMetadata.topicPartitions)) {
- // this is an optimization: if the partitions are already in the
metadata reply OK immediately
- Left(Errors.NONE)
- } else {
- Right(coordinatorEpoch,
txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds()))
- }
- }
- }
+ val result: ApiResult[(Int, TransactionMetadata)] =
getTransactionMetadata(transactionalId, producerId, producerEpoch, partitions)
result match {
case Left(err) =>
debug(s"Returning $err error code to client for $transactionalId's
AddPartitions request")
responseCallback(err)
- case Right((coordinatorEpoch, newMetadata)) =>
- txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch,
newMetadata,
+ case Right((coordinatorEpoch, txnMetadata)) =>
+ txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch,
txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds()),
responseCallback, requestLocal = requestLocal)
}
}
}
+
+ private def getTransactionMetadata(transactionalId: String,
+ producerId: Long,
+ producerEpoch: Short,
+ partitions: collection.Set[TopicPartition]):
ApiResult[(Int, TransactionMetadata)] = {
Review Comment:
nit: Indentation is off.
##########
core/src/main/scala/kafka/coordinator/transaction/TransactionCoordinator.scala:
##########
@@ -330,44 +363,53 @@ class TransactionCoordinator(txnConfig: TransactionConfig,
} else {
// try to update the transaction metadata and append the updated
metadata to txn log;
// if there is no such metadata treat it as invalid producerId mapping
error.
- val result: ApiResult[(Int, TxnTransitMetadata)] =
txnManager.getTransactionState(transactionalId).flatMap {
- case None => Left(Errors.INVALID_PRODUCER_ID_MAPPING)
-
- case Some(epochAndMetadata) =>
- val coordinatorEpoch = epochAndMetadata.coordinatorEpoch
- val txnMetadata = epochAndMetadata.transactionMetadata
-
- // generate the new transaction metadata with added partitions
- txnMetadata.inLock {
- if (txnMetadata.producerId != producerId) {
- Left(Errors.INVALID_PRODUCER_ID_MAPPING)
- } else if (txnMetadata.producerEpoch != producerEpoch) {
- Left(Errors.PRODUCER_FENCED)
- } else if (txnMetadata.pendingTransitionInProgress) {
- // return a retriable exception to let the client backoff and
retry
- Left(Errors.CONCURRENT_TRANSACTIONS)
- } else if (txnMetadata.state == PrepareCommit || txnMetadata.state
== PrepareAbort) {
- Left(Errors.CONCURRENT_TRANSACTIONS)
- } else if (txnMetadata.state == Ongoing &&
partitions.subsetOf(txnMetadata.topicPartitions)) {
- // this is an optimization: if the partitions are already in the
metadata reply OK immediately
- Left(Errors.NONE)
- } else {
- Right(coordinatorEpoch,
txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds()))
- }
- }
- }
+ val result: ApiResult[(Int, TransactionMetadata)] =
getTransactionMetadata(transactionalId, producerId, producerEpoch, partitions)
result match {
case Left(err) =>
debug(s"Returning $err error code to client for $transactionalId's
AddPartitions request")
responseCallback(err)
- case Right((coordinatorEpoch, newMetadata)) =>
- txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch,
newMetadata,
+ case Right((coordinatorEpoch, txnMetadata)) =>
+ txnManager.appendTransactionToLog(transactionalId, coordinatorEpoch,
txnMetadata.prepareAddPartitions(partitions.toSet, time.milliseconds()),
responseCallback, requestLocal = requestLocal)
}
}
}
+
+ private def getTransactionMetadata(transactionalId: String,
+ producerId: Long,
+ producerEpoch: Short,
+ partitions: collection.Set[TopicPartition]):
ApiResult[(Int, TransactionMetadata)] = {
+ // try to update the transaction metadata and append the updated metadata
to txn log;
+ // if there is no such metadata treat it as invalid producerId mapping
error.
Review Comment:
nit: Is this comment relevant here?
##########
core/src/test/scala/unit/kafka/server/AddPartitionsToTxnRequestServerTest.scala:
##########
@@ -55,22 +65,146 @@ class AddPartitionsToTxnRequestServerTest extends
BaseRequestTest {
val producerId = 1000L
val producerEpoch: Short = 0
- val request = new AddPartitionsToTxnRequest.Builder(
- transactionalId,
- producerId,
- producerEpoch,
- List(createdTopicPartition, nonExistentTopic).asJava)
- .build()
+ val request =
+ if (version < 4) {
+ AddPartitionsToTxnRequest.Builder.forClient(
+ transactionalId,
+ producerId,
+ producerEpoch,
+ List(createdTopicPartition, nonExistentTopic).asJava
+ ).build()
Review Comment:
Should we set the version here?
##########
core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorTest.scala:
##########
@@ -1185,4 +1213,8 @@ class TransactionCoordinatorTest {
def errorsCallback(ret: Errors): Unit = {
error = ret
}
+
+ def verifyPartitionsInTxnCallback(result: AddPartitionsToTxnResult): Unit = {
+ errors =
AddPartitionsToTxnResponse.errorsForTransaction(result.topicResults()).asScala.toMap
Review Comment:
Relying on a global variable is risky here. It would be much better to
define the callback within the test itself.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]