jeffkbkim commented on code in PR #12845: URL: https://github.com/apache/kafka/pull/12845#discussion_r1025980136
########## core/src/main/scala/kafka/server/KafkaApis.scala: ########## @@ -161,6 +166,12 @@ class KafkaApis(val requestChannel: RequestChannel, * Top-level method that handles all requests and multiplexes to the right api */ override def handle(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { + def handleError(e: Throwable): Unit = { + trace(s"Unexpected error handling request ${request.requestDesc(true)} " + Review Comment: it looks like this message was an error level log and this change affects all other apis. what's the reason for changing it to trace? ########## core/src/test/scala/unit/kafka/server/KafkaApisTest.scala: ########## @@ -2524,196 +2530,208 @@ class KafkaApisTest { assertEquals(MemoryRecords.EMPTY, FetchResponse.recordsOrFail(partitionData)) } - @Test - def testJoinGroupProtocolsOrder(): Unit = { - val protocols = List( - ("first", "first".getBytes()), - ("second", "second".getBytes()) + @ParameterizedTest + @ApiKeyVersionsSource(apiKey = ApiKeys.JOIN_GROUP) + def testHandleJoinGroupRequest(version: Short): Unit = { + val joinGroupRequest = new JoinGroupRequestData() + .setGroupId("group") + .setMemberId("member") + .setProtocolType("consumer") + .setRebalanceTimeoutMs(1000) + .setSessionTimeoutMs(2000) + + val requestChannelRequest = buildRequest(new JoinGroupRequest.Builder(joinGroupRequest).build(version)) + + val expectedRequestContext = new GroupCoordinatorRequestContext( + version, + requestChannelRequest.context.clientId, + requestChannelRequest.context.clientAddress, + RequestLocal.NoCaching.bufferSupplier ) - val groupId = "group" - val memberId = "member1" - val protocolType = "consumer" - val rebalanceTimeoutMs = 10 - val sessionTimeoutMs = 5 - val capturedProtocols: ArgumentCaptor[List[(String, Array[Byte])]] = ArgumentCaptor.forClass(classOf[List[(String, Array[Byte])]]) + val expectedJoinGroupRequest = new JoinGroupRequestData() + .setGroupId(joinGroupRequest.groupId) + .setMemberId(joinGroupRequest.memberId) + 
.setProtocolType(joinGroupRequest.protocolType) + .setRebalanceTimeoutMs(if (version >= 1) joinGroupRequest.rebalanceTimeoutMs else joinGroupRequest.sessionTimeoutMs) + .setSessionTimeoutMs(joinGroupRequest.sessionTimeoutMs) - createKafkaApis().handleJoinGroupRequest( - buildRequest( - new JoinGroupRequest.Builder( - new JoinGroupRequestData() - .setGroupId(groupId) - .setMemberId(memberId) - .setProtocolType(protocolType) - .setRebalanceTimeoutMs(rebalanceTimeoutMs) - .setSessionTimeoutMs(sessionTimeoutMs) - .setProtocols(new JoinGroupRequestData.JoinGroupRequestProtocolCollection( - protocols.map { case (name, protocol) => new JoinGroupRequestProtocol() - .setName(name).setMetadata(protocol) - }.iterator.asJava)) - ).build() - ), - RequestLocal.withThreadConfinedCaching) + val future = new CompletableFuture[JoinGroupResponseData]() + when(newGroupCoordinator.joinGroup( + ArgumentMatchers.eq(expectedRequestContext), + ArgumentMatchers.eq(expectedJoinGroupRequest) + )).thenReturn(future) - verify(groupCoordinator).handleJoinGroup( - ArgumentMatchers.eq(groupId), - ArgumentMatchers.eq(memberId), - ArgumentMatchers.eq(None), - ArgumentMatchers.eq(true), - ArgumentMatchers.eq(true), - ArgumentMatchers.eq(clientId), - ArgumentMatchers.eq(InetAddress.getLocalHost.toString), - ArgumentMatchers.eq(rebalanceTimeoutMs), - ArgumentMatchers.eq(sessionTimeoutMs), - ArgumentMatchers.eq(protocolType), - capturedProtocols.capture(), - any(), - any(), - any() + createKafkaApis().handleJoinGroupRequest( + requestChannelRequest, + RequestLocal.NoCaching ) - val capturedProtocolsList = capturedProtocols.getValue - assertEquals(protocols.size, capturedProtocolsList.size) - protocols.zip(capturedProtocolsList).foreach { case ((expectedName, expectedBytes), (name, bytes)) => - assertEquals(expectedName, name) - assertArrayEquals(expectedBytes, bytes) - } - } - @Test - def testJoinGroupWhenAnErrorOccurs(): Unit = { - for (version <- ApiKeys.JOIN_GROUP.oldestVersion to 
ApiKeys.JOIN_GROUP.latestVersion) { - testJoinGroupWhenAnErrorOccurs(version.asInstanceOf[Short]) - } - } + val expectedJoinGroupResponse = new JoinGroupResponseData() + .setMemberId("member") + .setGenerationId(0) + .setLeader("leader") + .setProtocolType("consumer") + .setProtocolName("range") - def testJoinGroupWhenAnErrorOccurs(version: Short): Unit = { - reset(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager) + future.complete(expectedJoinGroupResponse) + val capturedResponse = verifyNoThrottling(requestChannelRequest) + val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse] + assertEquals(expectedJoinGroupResponse, response.data) + } - val groupId = "group" - val memberId = "member1" - val protocolType = "consumer" - val rebalanceTimeoutMs = 10 - val sessionTimeoutMs = 5 + @ParameterizedTest + @ApiKeyVersionsSource(apiKey = ApiKeys.JOIN_GROUP) + def testJoinGroupProtocolNameBackwardCompatibility(version: Short): Unit = { + val joinGroupRequest = new JoinGroupRequestData() + .setGroupId("group") + .setMemberId("member") + .setProtocolType("consumer") + .setRebalanceTimeoutMs(1000) + .setSessionTimeoutMs(2000) + + val requestChannelRequest = buildRequest(new JoinGroupRequest.Builder(joinGroupRequest).build(version)) + + val expectedRequestContext = new GroupCoordinatorRequestContext( + version, + requestChannelRequest.context.clientId, + requestChannelRequest.context.clientAddress, + RequestLocal.NoCaching.bufferSupplier + ) - val capturedCallback: ArgumentCaptor[JoinGroupCallback] = ArgumentCaptor.forClass(classOf[JoinGroupCallback]) + val expectedJoinGroupRequest = new JoinGroupRequestData() + .setGroupId(joinGroupRequest.groupId) + .setMemberId(joinGroupRequest.memberId) + .setProtocolType(joinGroupRequest.protocolType) + .setRebalanceTimeoutMs(if (version >= 1) joinGroupRequest.rebalanceTimeoutMs else joinGroupRequest.sessionTimeoutMs) + .setSessionTimeoutMs(joinGroupRequest.sessionTimeoutMs) - val 
joinGroupRequest = new JoinGroupRequest.Builder( - new JoinGroupRequestData() - .setGroupId(groupId) - .setMemberId(memberId) - .setProtocolType(protocolType) - .setRebalanceTimeoutMs(rebalanceTimeoutMs) - .setSessionTimeoutMs(sessionTimeoutMs) - ).build(version) + val future = new CompletableFuture[JoinGroupResponseData]() + when(newGroupCoordinator.joinGroup( + ArgumentMatchers.eq(expectedRequestContext), + ArgumentMatchers.eq(expectedJoinGroupRequest) + )).thenReturn(future) - val requestChannelRequest = buildRequest(joinGroupRequest) + createKafkaApis().handleJoinGroupRequest( + requestChannelRequest, + RequestLocal.NoCaching + ) - createKafkaApis().handleJoinGroupRequest(requestChannelRequest, RequestLocal.withThreadConfinedCaching) + val joinGroupResponse = new JoinGroupResponseData() + .setErrorCode(Errors.INCONSISTENT_GROUP_PROTOCOL.code) + .setMemberId("member") + .setProtocolName(null) - verify(groupCoordinator).handleJoinGroup( - ArgumentMatchers.eq(groupId), - ArgumentMatchers.eq(memberId), - ArgumentMatchers.eq(None), - ArgumentMatchers.eq(if (version >= 4) true else false), - ArgumentMatchers.eq(if (version >= 9) true else false), - ArgumentMatchers.eq(clientId), - ArgumentMatchers.eq(InetAddress.getLocalHost.toString), - ArgumentMatchers.eq(if (version >= 1) rebalanceTimeoutMs else sessionTimeoutMs), - ArgumentMatchers.eq(sessionTimeoutMs), - ArgumentMatchers.eq(protocolType), - ArgumentMatchers.eq(List.empty), - capturedCallback.capture(), - any(), - any() - ) - capturedCallback.getValue.apply(JoinGroupResult(memberId, Errors.INCONSISTENT_GROUP_PROTOCOL)) + val expectedJoinGroupResponse = new JoinGroupResponseData() + .setErrorCode(Errors.INCONSISTENT_GROUP_PROTOCOL.code) + .setMemberId("member") + .setProtocolName(if (version >= 7) null else GroupCoordinator.NoProtocol) + future.complete(joinGroupResponse) val capturedResponse = verifyNoThrottling(requestChannelRequest) val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse] - - 
assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL, response.error) - assertEquals(0, response.data.members.size) - assertEquals(memberId, response.data.memberId) - assertEquals(GroupCoordinator.NoGeneration, response.data.generationId) - assertEquals(GroupCoordinator.NoLeader, response.data.leader) - assertNull(response.data.protocolType) - - if (version >= 7) { - assertNull(response.data.protocolName) - } else { - assertEquals(GroupCoordinator.NoProtocol, response.data.protocolName) - } + assertEquals(expectedJoinGroupResponse, response.data) } @Test - def testJoinGroupProtocolType(): Unit = { - for (version <- ApiKeys.JOIN_GROUP.oldestVersion to ApiKeys.JOIN_GROUP.latestVersion) { - testJoinGroupProtocolType(version.asInstanceOf[Short]) - } - } + def testHandleJoinGroupRequestFutureFailed(): Unit = { + val joinGroupRequest = new JoinGroupRequestData() + .setGroupId("group") + .setMemberId("member") + .setProtocolType("consumer") + .setRebalanceTimeoutMs(1000) + .setSessionTimeoutMs(2000) - def testJoinGroupProtocolType(version: Short): Unit = { - reset(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager) + val requestChannelRequest = buildRequest(new JoinGroupRequest.Builder(joinGroupRequest).build()) - val groupId = "group" - val memberId = "member1" - val protocolType = "consumer" - val protocolName = "range" - val rebalanceTimeoutMs = 10 - val sessionTimeoutMs = 5 + val expectedRequestContext = new GroupCoordinatorRequestContext( + ApiKeys.JOIN_GROUP.latestVersion, + requestChannelRequest.context.clientId, + requestChannelRequest.context.clientAddress, + RequestLocal.NoCaching.bufferSupplier + ) - val capturedCallback: ArgumentCaptor[JoinGroupCallback] = ArgumentCaptor.forClass(classOf[JoinGroupCallback]) + val future = new CompletableFuture[JoinGroupResponseData]() + when(newGroupCoordinator.joinGroup( + ArgumentMatchers.eq(expectedRequestContext), + ArgumentMatchers.eq(joinGroupRequest) + )).thenReturn(future) - val joinGroupRequest = 
new JoinGroupRequest.Builder( - new JoinGroupRequestData() - .setGroupId(groupId) - .setMemberId(memberId) - .setProtocolType(protocolType) - .setRebalanceTimeoutMs(rebalanceTimeoutMs) - .setSessionTimeoutMs(sessionTimeoutMs) - ).build(version) + createKafkaApis().handleJoinGroupRequest( + requestChannelRequest, + RequestLocal.NoCaching + ) - val requestChannelRequest = buildRequest(joinGroupRequest) + future.completeExceptionally(Errors.REQUEST_TIMED_OUT.exception) + val capturedResponse = verifyNoThrottling(requestChannelRequest) + val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse] + assertEquals(Errors.REQUEST_TIMED_OUT, response.error) + } - createKafkaApis().handleJoinGroupRequest(requestChannelRequest, RequestLocal.withThreadConfinedCaching) + @Test + def testHandleJoinGroupRequestAuthorizationFailed(): Unit = { + val joinGroupRequest = new JoinGroupRequestData() + .setGroupId("group") + .setMemberId("member") + .setProtocolType("consumer") + .setRebalanceTimeoutMs(1000) + .setSessionTimeoutMs(2000) - verify(groupCoordinator).handleJoinGroup( - ArgumentMatchers.eq(groupId), - ArgumentMatchers.eq(memberId), - ArgumentMatchers.eq(None), - ArgumentMatchers.eq(if (version >= 4) true else false), - ArgumentMatchers.eq(if (version >= 9) true else false), - ArgumentMatchers.eq(clientId), - ArgumentMatchers.eq(InetAddress.getLocalHost.toString), - ArgumentMatchers.eq(if (version >= 1) rebalanceTimeoutMs else sessionTimeoutMs), - ArgumentMatchers.eq(sessionTimeoutMs), - ArgumentMatchers.eq(protocolType), - ArgumentMatchers.eq(List.empty), - capturedCallback.capture(), - any(), - any() + val requestChannelRequest = buildRequest(new JoinGroupRequest.Builder(joinGroupRequest).build()) + + val authorizer: Authorizer = mock(classOf[Authorizer]) + when(authorizer.authorize(any[RequestContext], any[util.List[Action]])) + .thenReturn(Seq(AuthorizationResult.DENIED).asJava) + + createKafkaApis(authorizer = Some(authorizer)).handleJoinGroupRequest( + 
requestChannelRequest, + RequestLocal.NoCaching ) - capturedCallback.getValue.apply(JoinGroupResult( - members = List.empty, - memberId = memberId, - generationId = 0, - protocolType = Some(protocolType), - protocolName = Some(protocolName), - leaderId = memberId, - skipAssignment = true, - error = Errors.NONE - )) + val capturedResponse = verifyNoThrottling(requestChannelRequest) val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse] + assertEquals(Errors.GROUP_AUTHORIZATION_FAILED, response.error) + } - assertEquals(Errors.NONE, response.error) - assertEquals(0, response.data.members.size) - assertEquals(memberId, response.data.memberId) - assertEquals(0, response.data.generationId) - assertEquals(memberId, response.data.leader) - assertEquals(protocolName, response.data.protocolName) - assertEquals(protocolType, response.data.protocolType) - assertTrue(response.data.skipAssignment) + @Test + def testHandleJoinGroupRequestUnexpectedException(): Unit = { + val joinGroupRequest = new JoinGroupRequestData() + .setGroupId("group") + .setMemberId("member") + .setProtocolType("consumer") + .setRebalanceTimeoutMs(1000) + .setSessionTimeoutMs(2000) + + val requestChannelRequest = buildRequest(new JoinGroupRequest.Builder(joinGroupRequest).build()) + + val expectedRequestContext = new GroupCoordinatorRequestContext( + ApiKeys.JOIN_GROUP.latestVersion, + requestChannelRequest.context.clientId, + requestChannelRequest.context.clientAddress, + RequestLocal.NoCaching.bufferSupplier + ) + + val future = new CompletableFuture[JoinGroupResponseData]() + when(newGroupCoordinator.joinGroup( + ArgumentMatchers.eq(expectedRequestContext), + ArgumentMatchers.eq(joinGroupRequest) + )).thenReturn(future) + + val response = new AtomicReference[JoinGroupResponse]() Review Comment: can you help me understand the reason for using an atomic reference? 
########## core/src/main/scala/kafka/server/KafkaApis.scala: ########## @@ -1647,69 +1656,51 @@ class KafkaApis(val requestChannel: RequestChannel, } } - def handleJoinGroupRequest(request: RequestChannel.Request, requestLocal: RequestLocal): Unit = { - val joinGroupRequest = request.body[JoinGroupRequest] + private def makeGroupCoordinatorRequestContext( + request: RequestChannel.Request, + requestLocal: RequestLocal + ): GroupCoordinatorRequestContext = { + new GroupCoordinatorRequestContext( + request.context.header.data.requestApiVersion, + request.context.header.data.clientId, + request.context.clientAddress, + requestLocal.bufferSupplier + ) + } - // the callback for sending a join-group response - def sendResponseCallback(joinResult: JoinGroupResult): Unit = { - def createResponse(requestThrottleMs: Int): AbstractResponse = { - val responseBody = new JoinGroupResponse( - new JoinGroupResponseData() - .setThrottleTimeMs(requestThrottleMs) - .setErrorCode(joinResult.error.code) - .setGenerationId(joinResult.generationId) - .setProtocolType(joinResult.protocolType.orNull) - .setProtocolName(joinResult.protocolName.orNull) - .setLeader(joinResult.leaderId) - .setSkipAssignment(joinResult.skipAssignment) - .setMemberId(joinResult.memberId) - .setMembers(joinResult.members.asJava), - request.context.apiVersion - ) + def handleJoinGroupRequest( + request: RequestChannel.Request, + requestLocal: RequestLocal + ): CompletableFuture[Unit] = { Review Comment: some questions for my understanding: 1. it looks like we handle exceptions thrown by the new joinGroup in `newGroupCoordinator.joinGroup(ctx, joinGroupRequest.data).handle[Unit]`. are we returning a future here to handle exceptions thrown during `sendResponse()`? it looks to me like that would be handled by `case e: Throwable => handleError(e)` or am i missing something? 2. would `GroupCoordinatorAdapter.joinGroup()` work without returning a CompletableFuture object? 3. 
are we returning CompletableFuture in GroupCoordinatorAdapter to prepare it for the new group coordinator since it will use multiple threads to handle a single join/sync group request? ########## core/src/test/scala/unit/kafka/server/KafkaApisTest.scala: ########## @@ -2524,196 +2530,208 @@ class KafkaApisTest { assertEquals(MemoryRecords.EMPTY, FetchResponse.recordsOrFail(partitionData)) } - @Test - def testJoinGroupProtocolsOrder(): Unit = { - val protocols = List( - ("first", "first".getBytes()), - ("second", "second".getBytes()) + @ParameterizedTest + @ApiKeyVersionsSource(apiKey = ApiKeys.JOIN_GROUP) + def testHandleJoinGroupRequest(version: Short): Unit = { + val joinGroupRequest = new JoinGroupRequestData() + .setGroupId("group") + .setMemberId("member") + .setProtocolType("consumer") + .setRebalanceTimeoutMs(1000) + .setSessionTimeoutMs(2000) + + val requestChannelRequest = buildRequest(new JoinGroupRequest.Builder(joinGroupRequest).build(version)) + + val expectedRequestContext = new GroupCoordinatorRequestContext( + version, + requestChannelRequest.context.clientId, + requestChannelRequest.context.clientAddress, + RequestLocal.NoCaching.bufferSupplier ) - val groupId = "group" - val memberId = "member1" - val protocolType = "consumer" - val rebalanceTimeoutMs = 10 - val sessionTimeoutMs = 5 - val capturedProtocols: ArgumentCaptor[List[(String, Array[Byte])]] = ArgumentCaptor.forClass(classOf[List[(String, Array[Byte])]]) + val expectedJoinGroupRequest = new JoinGroupRequestData() + .setGroupId(joinGroupRequest.groupId) + .setMemberId(joinGroupRequest.memberId) + .setProtocolType(joinGroupRequest.protocolType) + .setRebalanceTimeoutMs(if (version >= 1) joinGroupRequest.rebalanceTimeoutMs else joinGroupRequest.sessionTimeoutMs) + .setSessionTimeoutMs(joinGroupRequest.sessionTimeoutMs) - createKafkaApis().handleJoinGroupRequest( - buildRequest( - new JoinGroupRequest.Builder( - new JoinGroupRequestData() - .setGroupId(groupId) - .setMemberId(memberId) - 
.setProtocolType(protocolType) - .setRebalanceTimeoutMs(rebalanceTimeoutMs) - .setSessionTimeoutMs(sessionTimeoutMs) - .setProtocols(new JoinGroupRequestData.JoinGroupRequestProtocolCollection( - protocols.map { case (name, protocol) => new JoinGroupRequestProtocol() - .setName(name).setMetadata(protocol) - }.iterator.asJava)) - ).build() - ), - RequestLocal.withThreadConfinedCaching) + val future = new CompletableFuture[JoinGroupResponseData]() + when(newGroupCoordinator.joinGroup( + ArgumentMatchers.eq(expectedRequestContext), + ArgumentMatchers.eq(expectedJoinGroupRequest) + )).thenReturn(future) - verify(groupCoordinator).handleJoinGroup( - ArgumentMatchers.eq(groupId), - ArgumentMatchers.eq(memberId), - ArgumentMatchers.eq(None), - ArgumentMatchers.eq(true), - ArgumentMatchers.eq(true), - ArgumentMatchers.eq(clientId), - ArgumentMatchers.eq(InetAddress.getLocalHost.toString), - ArgumentMatchers.eq(rebalanceTimeoutMs), - ArgumentMatchers.eq(sessionTimeoutMs), - ArgumentMatchers.eq(protocolType), - capturedProtocols.capture(), - any(), - any(), - any() + createKafkaApis().handleJoinGroupRequest( + requestChannelRequest, + RequestLocal.NoCaching ) - val capturedProtocolsList = capturedProtocols.getValue - assertEquals(protocols.size, capturedProtocolsList.size) - protocols.zip(capturedProtocolsList).foreach { case ((expectedName, expectedBytes), (name, bytes)) => - assertEquals(expectedName, name) - assertArrayEquals(expectedBytes, bytes) - } - } - @Test - def testJoinGroupWhenAnErrorOccurs(): Unit = { - for (version <- ApiKeys.JOIN_GROUP.oldestVersion to ApiKeys.JOIN_GROUP.latestVersion) { - testJoinGroupWhenAnErrorOccurs(version.asInstanceOf[Short]) - } - } + val expectedJoinGroupResponse = new JoinGroupResponseData() + .setMemberId("member") + .setGenerationId(0) + .setLeader("leader") + .setProtocolType("consumer") + .setProtocolName("range") - def testJoinGroupWhenAnErrorOccurs(version: Short): Unit = { - reset(groupCoordinator, clientRequestQuotaManager, 
requestChannel, replicaManager) + future.complete(expectedJoinGroupResponse) + val capturedResponse = verifyNoThrottling(requestChannelRequest) + val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse] + assertEquals(expectedJoinGroupResponse, response.data) + } - val groupId = "group" - val memberId = "member1" - val protocolType = "consumer" - val rebalanceTimeoutMs = 10 - val sessionTimeoutMs = 5 + @ParameterizedTest + @ApiKeyVersionsSource(apiKey = ApiKeys.JOIN_GROUP) + def testJoinGroupProtocolNameBackwardCompatibility(version: Short): Unit = { + val joinGroupRequest = new JoinGroupRequestData() + .setGroupId("group") + .setMemberId("member") + .setProtocolType("consumer") + .setRebalanceTimeoutMs(1000) + .setSessionTimeoutMs(2000) + + val requestChannelRequest = buildRequest(new JoinGroupRequest.Builder(joinGroupRequest).build(version)) + + val expectedRequestContext = new GroupCoordinatorRequestContext( + version, + requestChannelRequest.context.clientId, + requestChannelRequest.context.clientAddress, + RequestLocal.NoCaching.bufferSupplier + ) - val capturedCallback: ArgumentCaptor[JoinGroupCallback] = ArgumentCaptor.forClass(classOf[JoinGroupCallback]) + val expectedJoinGroupRequest = new JoinGroupRequestData() + .setGroupId(joinGroupRequest.groupId) + .setMemberId(joinGroupRequest.memberId) + .setProtocolType(joinGroupRequest.protocolType) + .setRebalanceTimeoutMs(if (version >= 1) joinGroupRequest.rebalanceTimeoutMs else joinGroupRequest.sessionTimeoutMs) + .setSessionTimeoutMs(joinGroupRequest.sessionTimeoutMs) - val joinGroupRequest = new JoinGroupRequest.Builder( - new JoinGroupRequestData() - .setGroupId(groupId) - .setMemberId(memberId) - .setProtocolType(protocolType) - .setRebalanceTimeoutMs(rebalanceTimeoutMs) - .setSessionTimeoutMs(sessionTimeoutMs) - ).build(version) + val future = new CompletableFuture[JoinGroupResponseData]() + when(newGroupCoordinator.joinGroup( + ArgumentMatchers.eq(expectedRequestContext), + 
ArgumentMatchers.eq(expectedJoinGroupRequest) + )).thenReturn(future) - val requestChannelRequest = buildRequest(joinGroupRequest) + createKafkaApis().handleJoinGroupRequest( + requestChannelRequest, + RequestLocal.NoCaching + ) - createKafkaApis().handleJoinGroupRequest(requestChannelRequest, RequestLocal.withThreadConfinedCaching) + val joinGroupResponse = new JoinGroupResponseData() + .setErrorCode(Errors.INCONSISTENT_GROUP_PROTOCOL.code) + .setMemberId("member") + .setProtocolName(null) - verify(groupCoordinator).handleJoinGroup( - ArgumentMatchers.eq(groupId), - ArgumentMatchers.eq(memberId), - ArgumentMatchers.eq(None), - ArgumentMatchers.eq(if (version >= 4) true else false), - ArgumentMatchers.eq(if (version >= 9) true else false), - ArgumentMatchers.eq(clientId), - ArgumentMatchers.eq(InetAddress.getLocalHost.toString), - ArgumentMatchers.eq(if (version >= 1) rebalanceTimeoutMs else sessionTimeoutMs), - ArgumentMatchers.eq(sessionTimeoutMs), - ArgumentMatchers.eq(protocolType), - ArgumentMatchers.eq(List.empty), - capturedCallback.capture(), - any(), - any() - ) - capturedCallback.getValue.apply(JoinGroupResult(memberId, Errors.INCONSISTENT_GROUP_PROTOCOL)) + val expectedJoinGroupResponse = new JoinGroupResponseData() + .setErrorCode(Errors.INCONSISTENT_GROUP_PROTOCOL.code) + .setMemberId("member") + .setProtocolName(if (version >= 7) null else GroupCoordinator.NoProtocol) + future.complete(joinGroupResponse) val capturedResponse = verifyNoThrottling(requestChannelRequest) val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse] - - assertEquals(Errors.INCONSISTENT_GROUP_PROTOCOL, response.error) - assertEquals(0, response.data.members.size) - assertEquals(memberId, response.data.memberId) - assertEquals(GroupCoordinator.NoGeneration, response.data.generationId) - assertEquals(GroupCoordinator.NoLeader, response.data.leader) - assertNull(response.data.protocolType) - - if (version >= 7) { - assertNull(response.data.protocolName) - } else { 
- assertEquals(GroupCoordinator.NoProtocol, response.data.protocolName) - } + assertEquals(expectedJoinGroupResponse, response.data) } @Test - def testJoinGroupProtocolType(): Unit = { - for (version <- ApiKeys.JOIN_GROUP.oldestVersion to ApiKeys.JOIN_GROUP.latestVersion) { - testJoinGroupProtocolType(version.asInstanceOf[Short]) - } - } + def testHandleJoinGroupRequestFutureFailed(): Unit = { + val joinGroupRequest = new JoinGroupRequestData() + .setGroupId("group") + .setMemberId("member") + .setProtocolType("consumer") + .setRebalanceTimeoutMs(1000) + .setSessionTimeoutMs(2000) - def testJoinGroupProtocolType(version: Short): Unit = { - reset(groupCoordinator, clientRequestQuotaManager, requestChannel, replicaManager) + val requestChannelRequest = buildRequest(new JoinGroupRequest.Builder(joinGroupRequest).build()) - val groupId = "group" - val memberId = "member1" - val protocolType = "consumer" - val protocolName = "range" - val rebalanceTimeoutMs = 10 - val sessionTimeoutMs = 5 + val expectedRequestContext = new GroupCoordinatorRequestContext( + ApiKeys.JOIN_GROUP.latestVersion, + requestChannelRequest.context.clientId, + requestChannelRequest.context.clientAddress, + RequestLocal.NoCaching.bufferSupplier + ) - val capturedCallback: ArgumentCaptor[JoinGroupCallback] = ArgumentCaptor.forClass(classOf[JoinGroupCallback]) + val future = new CompletableFuture[JoinGroupResponseData]() + when(newGroupCoordinator.joinGroup( + ArgumentMatchers.eq(expectedRequestContext), + ArgumentMatchers.eq(joinGroupRequest) + )).thenReturn(future) - val joinGroupRequest = new JoinGroupRequest.Builder( - new JoinGroupRequestData() - .setGroupId(groupId) - .setMemberId(memberId) - .setProtocolType(protocolType) - .setRebalanceTimeoutMs(rebalanceTimeoutMs) - .setSessionTimeoutMs(sessionTimeoutMs) - ).build(version) + createKafkaApis().handleJoinGroupRequest( + requestChannelRequest, + RequestLocal.NoCaching + ) - val requestChannelRequest = buildRequest(joinGroupRequest) + 
future.completeExceptionally(Errors.REQUEST_TIMED_OUT.exception) + val capturedResponse = verifyNoThrottling(requestChannelRequest) + val response = capturedResponse.getValue.asInstanceOf[JoinGroupResponse] + assertEquals(Errors.REQUEST_TIMED_OUT, response.error) + } - createKafkaApis().handleJoinGroupRequest(requestChannelRequest, RequestLocal.withThreadConfinedCaching) + @Test + def testHandleJoinGroupRequestAuthorizationFailed(): Unit = { + val joinGroupRequest = new JoinGroupRequestData() + .setGroupId("group") + .setMemberId("member") + .setProtocolType("consumer") + .setRebalanceTimeoutMs(1000) + .setSessionTimeoutMs(2000) - verify(groupCoordinator).handleJoinGroup( - ArgumentMatchers.eq(groupId), - ArgumentMatchers.eq(memberId), - ArgumentMatchers.eq(None), - ArgumentMatchers.eq(if (version >= 4) true else false), - ArgumentMatchers.eq(if (version >= 9) true else false), - ArgumentMatchers.eq(clientId), - ArgumentMatchers.eq(InetAddress.getLocalHost.toString), - ArgumentMatchers.eq(if (version >= 1) rebalanceTimeoutMs else sessionTimeoutMs), - ArgumentMatchers.eq(sessionTimeoutMs), - ArgumentMatchers.eq(protocolType), - ArgumentMatchers.eq(List.empty), - capturedCallback.capture(), - any(), - any() + val requestChannelRequest = buildRequest(new JoinGroupRequest.Builder(joinGroupRequest).build()) + + val authorizer: Authorizer = mock(classOf[Authorizer]) + when(authorizer.authorize(any[RequestContext], any[util.List[Action]])) + .thenReturn(Seq(AuthorizationResult.DENIED).asJava) + + createKafkaApis(authorizer = Some(authorizer)).handleJoinGroupRequest( + requestChannelRequest, + RequestLocal.NoCaching ) - capturedCallback.getValue.apply(JoinGroupResult( - members = List.empty, - memberId = memberId, - generationId = 0, - protocolType = Some(protocolType), - protocolName = Some(protocolName), - leaderId = memberId, - skipAssignment = true, - error = Errors.NONE - )) + val capturedResponse = verifyNoThrottling(requestChannelRequest) val response = 
capturedResponse.getValue.asInstanceOf[JoinGroupResponse] + assertEquals(Errors.GROUP_AUTHORIZATION_FAILED, response.error) + } - assertEquals(Errors.NONE, response.error) - assertEquals(0, response.data.members.size) - assertEquals(memberId, response.data.memberId) - assertEquals(0, response.data.generationId) - assertEquals(memberId, response.data.leader) - assertEquals(protocolName, response.data.protocolName) - assertEquals(protocolType, response.data.protocolType) - assertTrue(response.data.skipAssignment) + @Test + def testHandleJoinGroupRequestUnexpectedException(): Unit = { + val joinGroupRequest = new JoinGroupRequestData() + .setGroupId("group") + .setMemberId("member") + .setProtocolType("consumer") + .setRebalanceTimeoutMs(1000) + .setSessionTimeoutMs(2000) + + val requestChannelRequest = buildRequest(new JoinGroupRequest.Builder(joinGroupRequest).build()) + + val expectedRequestContext = new GroupCoordinatorRequestContext( + ApiKeys.JOIN_GROUP.latestVersion, + requestChannelRequest.context.clientId, + requestChannelRequest.context.clientAddress, + RequestLocal.NoCaching.bufferSupplier + ) + + val future = new CompletableFuture[JoinGroupResponseData]() + when(newGroupCoordinator.joinGroup( + ArgumentMatchers.eq(expectedRequestContext), + ArgumentMatchers.eq(joinGroupRequest) + )).thenReturn(future) + + val response = new AtomicReference[JoinGroupResponse]() + when(requestChannel.sendResponse(any(), any(), any())).thenAnswer { _ => + throw new Exception("Something went wrong") + }.thenAnswer { invocation => + val resp = invocation.getArgument(1, classOf[JoinGroupResponse]) + response.set(resp) + } + + createKafkaApis().handle( + requestChannelRequest, + RequestLocal.NoCaching + ) Review Comment: i'm surely missing something here - shouldn't the test be blocked here? i can't seem to find the other thread running -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org