dajac commented on code in PR #13502:
URL: https://github.com/apache/kafka/pull/13502#discussion_r1164508565
##
core/src/test/scala/unit/kafka/server/KafkaApisTest.scala:
##
@@ -2113,6 +2113,115 @@ class KafkaApisTest {
assertEquals(expectedErrors, response.errors())
}
+ @ParameterizedTest
+ @ApiKeyVersionsSource(apiKey = ApiKeys.ADD_PARTITIONS_TO_TXN)
+ def testHandleAddPartitionsToTxnAuthorizationFailed(version: Short): Unit = {
+val topic = "topic"
+
+val transactionalId = "txnId1"
+val producerId = 15L
+val epoch = 0.toShort
+
+val tp = new TopicPartition(topic, 0)
+
+val addPartitionsToTxnRequest = if (version < 4)
+ AddPartitionsToTxnRequest.Builder.forClient(
+transactionalId,
+producerId,
+epoch,
+Collections.singletonList(tp)).build(version)
+else
+ AddPartitionsToTxnRequest.Builder.forBroker(
+new AddPartitionsToTxnTransactionCollection(
Review Comment:
nit: indentation seems to be off.
##
core/src/test/scala/unit/kafka/server/KafkaApisTest.scala:
##
@@ -2113,6 +2113,115 @@ class KafkaApisTest {
assertEquals(expectedErrors, response.errors())
}
+ @ParameterizedTest
+ @ApiKeyVersionsSource(apiKey = ApiKeys.ADD_PARTITIONS_TO_TXN)
+ def testHandleAddPartitionsToTxnAuthorizationFailed(version: Short): Unit = {
+val topic = "topic"
+
+val transactionalId = "txnId1"
+val producerId = 15L
+val epoch = 0.toShort
+
+val tp = new TopicPartition(topic, 0)
+
+val addPartitionsToTxnRequest = if (version < 4)
+ AddPartitionsToTxnRequest.Builder.forClient(
+transactionalId,
+producerId,
+epoch,
+Collections.singletonList(tp)).build(version)
+else
+ AddPartitionsToTxnRequest.Builder.forBroker(
+new AddPartitionsToTxnTransactionCollection(
+List(new AddPartitionsToTxnTransaction()
+ .setTransactionalId(transactionalId)
+ .setProducerId(producerId)
+ .setProducerEpoch(epoch)
+ .setVerifyOnly(true)
+ .setTopics(new AddPartitionsToTxnTopicCollection(
+Collections.singletonList(new AddPartitionsToTxnTopic()
+ .setName(tp.topic)
+ .setPartitions(Collections.singletonList(tp.partition))
+).iterator()))
+).asJava.iterator())).build(version)
+
+val requestChannelRequest = buildRequest(addPartitionsToTxnRequest)
+
+val authorizer: Authorizer = mock(classOf[Authorizer])
+when(authorizer.authorize(any[RequestContext], any[util.List[Action]]))
+ .thenReturn(Seq(AuthorizationResult.DENIED).asJava)
+
+ createKafkaApis(authorizer = Some(authorizer)).handle(
+requestChannelRequest,
+RequestLocal.NoCaching
+ )
+
+val response = verifyNoThrottling[AddPartitionsToTxnResponse](requestChannelRequest)
+val error = if (version < 4)
+ response.errors().get(AddPartitionsToTxnResponse.V3_AND_BELOW_TXN_ID).get(tp)
+else
+ Errors.forCode(response.data().errorCode())
+
+val expectedError = if (version < 4) Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED else Errors.CLUSTER_AUTHORIZATION_FAILED
+assertEquals(expectedError, error)
+ }
+
+ @ParameterizedTest
+ @ApiKeyVersionsSource(apiKey = ApiKeys.ADD_PARTITIONS_TO_TXN)
+ def testAddPartitionsToTxnOperationNotAttempted(version: Short): Unit = {
+val topic = "topic"
+addTopicToMetadataCache(topic, numPartitions = 1)
+
+val transactionalId = "txnId1"
+val producerId = 15L
+val epoch = 0.toShort
+
+val tp0 = new TopicPartition(topic, 0)
+val tp1 = new TopicPartition(topic, 1)
+
+val addPartitionsToTxnRequest = if (version < 4)
+ AddPartitionsToTxnRequest.Builder.forClient(
+transactionalId,
+producerId,
+epoch,
+List(tp0, tp1).asJava).build(version)
+else
+ AddPartitionsToTxnRequest.Builder.forBroker(
+new AddPartitionsToTxnTransactionCollection(
+ List(new AddPartitionsToTxnTransaction()
+.setTransactionalId(transactionalId)
+.setProducerId(producerId)
+.setProducerEpoch(epoch)
+.setVerifyOnly(true)
+.setTopics(new AddPartitionsToTxnTopicCollection(
+ Collections.singletonList(new AddPartitionsToTxnTopic()
+.setName(tp0.topic)
+.setPartitions(List[Integer](tp0.partition, tp1.partition()).asJava)
+ ).iterator()))
+ ).asJava.iterator())).build(version)
+
+val requestChannelRequest = buildRequest(addPartitionsToTxnRequest)
+
+createKafkaApis().handleAddPartitionsToTxnRequest(
+ requestChannelRequest,
+ RequestLocal.NoCaching
+)
+
+val response = verifyNoThrottling[AddPartitionsToTxnResponse](requestChannelRequest)
+
+def checkErrorForTp(tp: TopicPartition): Unit = {
+ val error = if (version < 4)
+