KAFKA-3489; Update request metrics if a client closes a connection while the broker response is in flight
I also fixed a few issues in `SocketServerTest` and included a few clean-ups. Author: Ismael Juma <[email protected]> Reviewers: Jun Rao <[email protected]> Closes #1172 from ijuma/kafka-3489-update-request-metrics-if-client-closes Project: http://git-wip-us.apache.org/repos/asf/kafka/repo Commit: http://git-wip-us.apache.org/repos/asf/kafka/commit/e733d8c2 Tree: http://git-wip-us.apache.org/repos/asf/kafka/tree/e733d8c2 Diff: http://git-wip-us.apache.org/repos/asf/kafka/diff/e733d8c2 Branch: refs/heads/0.10.0 Commit: e733d8c2fbcee19ee77c436e66abb29850a2f7c2 Parents: 7030148 Author: Ismael Juma <[email protected]> Authored: Tue Apr 5 18:16:48 2016 -0400 Committer: Gwen Shapira <[email protected]> Committed: Tue Apr 5 17:08:53 2016 -0700 ---------------------------------------------------------------------- .../apache/kafka/common/network/Selector.java | 6 +- .../scala/kafka/network/RequestChannel.scala | 53 +++--- .../main/scala/kafka/network/SocketServer.scala | 185 +++++++++++-------- .../unit/kafka/network/SocketServerTest.scala | 141 +++++++++++--- 4 files changed, 257 insertions(+), 128 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/kafka/blob/e733d8c2/clients/src/main/java/org/apache/kafka/common/network/Selector.java ---------------------------------------------------------------------- diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java index 698b99c..c333741 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java +++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java @@ -491,7 +491,7 @@ public class Selector implements Selectable { private KafkaChannel channelOrFail(String id) { KafkaChannel channel = this.channels.get(id); if (channel == null) - throw new IllegalStateException("Attempt to retrieve channel for which there is no open 
connection. Connection id " + id + " existing connections " + channels.keySet().toString()); + throw new IllegalStateException("Attempt to retrieve channel for which there is no open connection. Connection id " + id + " existing connections " + channels.keySet()); return channel; } @@ -551,7 +551,7 @@ public class Selector implements Selectable { * checks if there are any staged receives and adds to completedReceives */ private void addToCompletedReceives() { - if (this.stagedReceives.size() > 0) { + if (!this.stagedReceives.isEmpty()) { Iterator<Map.Entry<KafkaChannel, Deque<NetworkReceive>>> iter = this.stagedReceives.entrySet().iterator(); while (iter.hasNext()) { Map.Entry<KafkaChannel, Deque<NetworkReceive>> entry = iter.next(); @@ -561,7 +561,7 @@ public class Selector implements Selectable { NetworkReceive networkReceive = deque.poll(); this.completedReceives.add(networkReceive); this.sensors.recordBytesReceived(channel.id(), networkReceive.payload().limit()); - if (deque.size() == 0) + if (deque.isEmpty()) iter.remove(); } } http://git-wip-us.apache.org/repos/asf/kafka/blob/e733d8c2/core/src/main/scala/kafka/network/RequestChannel.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/kafka/network/RequestChannel.scala b/core/src/main/scala/kafka/network/RequestChannel.scala index 1105802..17c5b9b 100644 --- a/core/src/main/scala/kafka/network/RequestChannel.scala +++ b/core/src/main/scala/kafka/network/RequestChannel.scala @@ -117,36 +117,39 @@ object RequestChannel extends Logging { if (apiRemoteCompleteTimeMs < 0) apiRemoteCompleteTimeMs = responseCompleteTimeMs - val requestQueueTime = (requestDequeueTimeMs - startTimeMs).max(0L) - val apiLocalTime = (apiLocalCompleteTimeMs - requestDequeueTimeMs).max(0L) - val apiRemoteTime = (apiRemoteCompleteTimeMs - apiLocalCompleteTimeMs).max(0L) - val apiThrottleTime = (responseCompleteTimeMs - apiRemoteCompleteTimeMs).max(0L) - val responseQueueTime = 
(responseDequeueTimeMs - responseCompleteTimeMs).max(0L) - val responseSendTime = (endTimeMs - responseDequeueTimeMs).max(0L) + val requestQueueTime = math.max(requestDequeueTimeMs - startTimeMs, 0) + val apiLocalTime = math.max(apiLocalCompleteTimeMs - requestDequeueTimeMs, 0) + val apiRemoteTime = math.max(apiRemoteCompleteTimeMs - apiLocalCompleteTimeMs, 0) + val apiThrottleTime = math.max(responseCompleteTimeMs - apiRemoteCompleteTimeMs, 0) + val responseQueueTime = math.max(responseDequeueTimeMs - responseCompleteTimeMs, 0) + val responseSendTime = math.max(endTimeMs - responseDequeueTimeMs, 0) val totalTime = endTimeMs - startTimeMs - var metricsList = List(RequestMetrics.metricsMap(ApiKeys.forId(requestId).name)) - if (requestId == ApiKeys.FETCH.id) { - val isFromFollower = requestObj.asInstanceOf[FetchRequest].isFromFollower - metricsList ::= ( if (isFromFollower) - RequestMetrics.metricsMap(RequestMetrics.followFetchMetricName) - else - RequestMetrics.metricsMap(RequestMetrics.consumerFetchMetricName) ) - } - metricsList.foreach{ - m => m.requestRate.mark() - m.requestQueueTimeHist.update(requestQueueTime) - m.localTimeHist.update(apiLocalTime) - m.remoteTimeHist.update(apiRemoteTime) - m.throttleTimeHist.update(apiThrottleTime) - m.responseQueueTimeHist.update(responseQueueTime) - m.responseSendTimeHist.update(responseSendTime) - m.totalTimeHist.update(totalTime) + val fetchMetricNames = + if (requestId == ApiKeys.FETCH.id) { + val isFromFollower = requestObj.asInstanceOf[FetchRequest].isFromFollower + Seq( + if (isFromFollower) RequestMetrics.followFetchMetricName + else RequestMetrics.consumerFetchMetricName + ) + } + else Seq.empty + val metricNames = fetchMetricNames :+ ApiKeys.forId(requestId).name + metricNames.foreach { metricName => + val m = RequestMetrics.metricsMap(metricName) + m.requestRate.mark() + m.requestQueueTimeHist.update(requestQueueTime) + m.localTimeHist.update(apiLocalTime) + m.remoteTimeHist.update(apiRemoteTime) + 
m.throttleTimeHist.update(apiThrottleTime) + m.responseQueueTimeHist.update(responseQueueTime) + m.responseSendTimeHist.update(responseSendTime) + m.totalTimeHist.update(totalTime) } - if(requestLogger.isTraceEnabled) + if (requestLogger.isTraceEnabled) requestLogger.trace("Completed request:%s from connection %s;totalTime:%d,requestQueueTime:%d,localTime:%d,remoteTime:%d,responseQueueTime:%d,sendTime:%d,securityProtocol:%s,principal:%s" .format(requestDesc(true), connectionId, totalTime, requestQueueTime, apiLocalTime, apiRemoteTime, responseQueueTime, responseSendTime, securityProtocol, session.principal)) - else if(requestLogger.isDebugEnabled) + else if (requestLogger.isDebugEnabled) requestLogger.debug("Completed request:%s from connection %s;totalTime:%d,requestQueueTime:%d,localTime:%d,remoteTime:%d,responseQueueTime:%d,sendTime:%d,securityProtocol:%s,principal:%s" .format(requestDesc(false), connectionId, totalTime, requestQueueTime, apiLocalTime, apiRemoteTime, responseQueueTime, responseSendTime, securityProtocol, session.principal)) } http://git-wip-us.apache.org/repos/asf/kafka/blob/e733d8c2/core/src/main/scala/kafka/network/SocketServer.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index 5c31ac6..f1ec2ef 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -31,9 +31,8 @@ import kafka.common.KafkaException import kafka.metrics.KafkaMetricsGroup import kafka.server.KafkaConfig import kafka.utils._ -import org.apache.kafka.common.MetricName import org.apache.kafka.common.metrics._ -import org.apache.kafka.common.network.{Selector => KSelector, LoginType, Mode, ChannelBuilders} +import org.apache.kafka.common.network.{ChannelBuilders, KafkaChannel, LoginType, Mode, Selector => KSelector} import 
org.apache.kafka.common.security.auth.KafkaPrincipal import org.apache.kafka.common.protocol.SecurityProtocol import org.apache.kafka.common.protocol.types.SchemaException @@ -41,7 +40,7 @@ import org.apache.kafka.common.utils.{Time, Utils} import scala.collection._ import JavaConverters._ -import scala.util.control.{NonFatal, ControlThrowable} +import scala.util.control.{ControlThrowable, NonFatal} /** * An NIO socket server. The threading model is @@ -83,8 +82,6 @@ class SocketServer(val config: KafkaConfig, val metrics: Metrics, val time: Time val sendBufferSize = config.socketSendBufferBytes val recvBufferSize = config.socketReceiveBufferBytes - val maxRequestSize = config.socketRequestMaxBytes - val connectionsMaxIdleMs = config.connectionsMaxIdleMs val brokerId = config.brokerId var processorBeginIndex = 0 @@ -92,18 +89,8 @@ class SocketServer(val config: KafkaConfig, val metrics: Metrics, val time: Time val protocol = endpoint.protocolType val processorEndIndex = processorBeginIndex + numProcessorThreads - for (i <- processorBeginIndex until processorEndIndex) { - processors(i) = new Processor(i, - time, - maxRequestSize, - requestChannel, - connectionQuotas, - connectionsMaxIdleMs, - protocol, - config.values, - metrics - ) - } + for (i <- processorBeginIndex until processorEndIndex) + processors(i) = newProcessor(i, connectionQuotas, protocol) val acceptor = new Acceptor(endpoint, sendBufferSize, recvBufferSize, brokerId, processors.slice(processorBeginIndex, processorEndIndex), connectionQuotas) @@ -148,10 +135,27 @@ class SocketServer(val config: KafkaConfig, val metrics: Metrics, val time: Time } } + /* `protected` for test usage */ + protected[network] def newProcessor(id: Int, connectionQuotas: ConnectionQuotas, protocol: SecurityProtocol): Processor = { + new Processor(id, + time, + config.socketRequestMaxBytes, + requestChannel, + connectionQuotas, + config.connectionsMaxIdleMs, + protocol, + config.values, + metrics + ) + } + /* For test usage */ 
private[network] def connectionCount(address: InetAddress): Int = Option(connectionQuotas).fold(0)(_.get(address)) + /* For test usage */ + private[network] def processor(index: Int): Processor = processors(index) + } /** @@ -376,10 +380,7 @@ private[kafka] class Processor(val id: Int, private val newConnections = new ConcurrentLinkedQueue[SocketChannel]() private val inflightResponses = mutable.Map[String, RequestChannel.Response]() - private val channelBuilder = ChannelBuilders.create(protocol, Mode.SERVER, LoginType.SERVER, channelConfigs) - private val metricTags = new util.HashMap[String, String]() - metricTags.put("networkProcessor", id.toString) - + private val metricTags = Map("networkProcessor" -> id.toString).asJava newGauge("IdlePercent", new Gauge[Double] { @@ -398,65 +399,27 @@ private[kafka] class Processor(val id: Int, "socket-server", metricTags, false, - channelBuilder) + ChannelBuilders.create(protocol, Mode.SERVER, LoginType.SERVER, channelConfigs)) override def run() { startupComplete() - while(isRunning) { + while (isRunning) { try { // setup any new connections that have been queued up configureNewConnections() // register any new responses for writing processNewResponses() - - try { - selector.poll(300) - } catch { - case e @ (_: IllegalStateException | _: IOException) => - error("Closing processor %s due to illegal state or IO exception".format(id)) - swallow(closeAll()) - shutdownComplete() - throw e - } - selector.completedReceives.asScala.foreach { receive => - try { - val channel = selector.channel(receive.source) - val session = RequestChannel.Session(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, channel.principal.getName), - channel.socketAddress) - val req = RequestChannel.Request(processor = id, connectionId = receive.source, session = session, buffer = receive.payload, startTimeMs = time.milliseconds, securityProtocol = protocol) - requestChannel.sendRequest(req) - selector.mute(receive.source) - } catch { - case e @ (_: 
InvalidRequestException | _: SchemaException) => - // note that even though we got an exception, we can assume that receive.source is valid. Issues with constructing a valid receive object were handled earlier - error("Closing socket for " + receive.source + " because of error", e) - close(selector, receive.source) - } - } - - selector.completedSends.asScala.foreach { send => - val resp = inflightResponses.remove(send.destination).getOrElse { - throw new IllegalStateException(s"Send for ${send.destination} completed, but not in `inflightResponses`") - } - resp.request.updateRequestMetrics() - selector.unmute(send.destination) - } - - selector.disconnected.asScala.foreach { connectionId => - val remoteHost = ConnectionId.fromString(connectionId).getOrElse { - throw new IllegalStateException(s"connectionId has unexpected format: $connectionId") - }.remoteHost - // the channel has been closed by the selector but the quotas still need to be updated - connectionQuotas.dec(InetAddress.getByName(remoteHost)) - } - + poll() + processCompletedReceives() + processCompletedSends() + processDisconnected() } catch { // We catch all the throwables here to prevent the processor thread from exiting. We do this because - // letting a processor exit might cause bigger impact on the broker. Usually the exceptions thrown would + // letting a processor exit might cause a bigger impact on the broker. Usually the exceptions thrown would // be either associated with a specific socket channel or a bad request. We just ignore the bad socket channel // or request. This behavior might need to be reviewed if we see an exception that need the entire broker to stop. 
- case e : ControlThrowable => throw e - case e : Throwable => + case e: ControlThrowable => throw e + case e: Throwable => error("Processor got uncaught exception.", e) } } @@ -468,7 +431,7 @@ private[kafka] class Processor(val id: Int, private def processNewResponses() { var curr = requestChannel.receiveResponse(id) - while(curr != null) { + while (curr != null) { try { curr.responseAction match { case RequestChannel.NoOpAction => @@ -478,9 +441,7 @@ private[kafka] class Processor(val id: Int, trace("Socket server received empty response to send, registering for read: " + curr) selector.unmute(curr.request.connectionId) case RequestChannel.SendAction => - trace("Socket server received response to send, registering for write and sending data: " + curr) - selector.send(curr.responseSend) - inflightResponses += (curr.request.connectionId -> curr) + sendResponse(curr) case RequestChannel.CloseConnectionAction => curr.request.updateRequestMetrics trace("Closing socket connection actively according to the response code.") @@ -492,6 +453,71 @@ private[kafka] class Processor(val id: Int, } } + /* `protected` for test usage */ + protected[network] def sendResponse(response: RequestChannel.Response) { + trace(s"Socket server received response to send, registering for write and sending data: $response") + val channel = selector.channel(response.responseSend.destination) + // `channel` can be null if the selector closed the connection because it was idle for too long + if (channel == null) { + warn(s"Attempting to send response via channel for which there is no open connection, connection id ${response.request.connectionId}") + response.request.updateRequestMetrics() + } + else { + selector.send(response.responseSend) + inflightResponses += (response.request.connectionId -> response) + } + } + + private def poll() { + try selector.poll(300) + catch { + case e @ (_: IllegalStateException | _: IOException) => + error(s"Closing processor $id due to illegal state or IO exception") + swallow(closeAll()) +
shutdownComplete() + throw e + } + } + + private def processCompletedReceives() { + selector.completedReceives.asScala.foreach { receive => + try { + val channel = selector.channel(receive.source) + val session = RequestChannel.Session(new KafkaPrincipal(KafkaPrincipal.USER_TYPE, channel.principal.getName), + channel.socketAddress) + val req = RequestChannel.Request(processor = id, connectionId = receive.source, session = session, buffer = receive.payload, startTimeMs = time.milliseconds, securityProtocol = protocol) + requestChannel.sendRequest(req) + selector.mute(receive.source) + } catch { + case e @ (_: InvalidRequestException | _: SchemaException) => + // note that even though we got an exception, we can assume that receive.source is valid. Issues with constructing a valid receive object were handled earlier + error(s"Closing socket for ${receive.source} because of error", e) + close(selector, receive.source) + } + } + } + + private def processCompletedSends() { + selector.completedSends.asScala.foreach { send => + val resp = inflightResponses.remove(send.destination).getOrElse { + throw new IllegalStateException(s"Send for ${send.destination} completed, but not in `inflightResponses`") + } + resp.request.updateRequestMetrics() + selector.unmute(send.destination) + } + } + + private def processDisconnected() { + selector.disconnected.asScala.foreach { connectionId => + val remoteHost = ConnectionId.fromString(connectionId).getOrElse { + throw new IllegalStateException(s"connectionId has unexpected format: $connectionId") + }.remoteHost + inflightResponses.remove(connectionId).foreach(_.request.updateRequestMetrics()) + // the channel has been closed by the selector but the quotas still need to be updated + connectionQuotas.dec(InetAddress.getByName(remoteHost)) + } + } + /** * Queue up a new connection for reading */ @@ -504,10 +530,10 @@ private[kafka] class Processor(val id: Int, * Register any new connections that have been queued up */ private def 
configureNewConnections() { - while(!newConnections.isEmpty) { + while (!newConnections.isEmpty) { val channel = newConnections.poll() try { - debug("Processor " + id + " listening to new connection from " + channel.socket.getRemoteSocketAddress) + debug(s"Processor $id listening to new connection from ${channel.socket.getRemoteSocketAddress}") val localHost = channel.socket().getLocalAddress.getHostAddress val localPort = channel.socket().getLocalPort val remoteHost = channel.socket().getInetAddress.getHostAddress @@ -515,12 +541,12 @@ private[kafka] class Processor(val id: Int, val connectionId = ConnectionId(localHost, localPort, remoteHost, remotePort).toString selector.register(connectionId, channel) } catch { - // We explicitly catch all non fatal exceptions and close the socket to avoid socket leak. The other - // throwables will be caught in processor and logged as uncaught exception. + // We explicitly catch all non fatal exceptions and close the socket to avoid a socket leak. The other + // throwables will be caught in processor and logged as uncaught exceptions. case NonFatal(e) => - // need to close the channel here to avoid socket leak. + // need to close the channel here to avoid a socket leak. close(channel) - error("Processor " + id + " closed connection from " + channel.getRemoteAddress, e) + error(s"Processor $id closed connection from ${channel.getRemoteAddress}", e) } } } @@ -535,6 +561,9 @@ private[kafka] class Processor(val id: Int, selector.close() } + /* For test usage */ + private[network] def channel(connectionId: String): Option[KafkaChannel] = + Option(selector.channel(connectionId)) /** * Wakeup the thread for selection. 
http://git-wip-us.apache.org/repos/asf/kafka/blob/e733d8c2/core/src/test/scala/unit/kafka/network/SocketServerTest.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index 5d28894..81e5232 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -39,7 +39,7 @@ import org.junit.Assert._ import org.junit._ import org.scalatest.junit.JUnitSuite -import scala.collection.Map +import scala.collection.mutable.ArrayBuffer class SocketServerTest extends JUnitSuite { val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0) @@ -55,6 +55,7 @@ class SocketServerTest extends JUnitSuite { val metrics = new Metrics val server = new SocketServer(config, metrics, new SystemTime) server.startup() + val sockets = new ArrayBuffer[Socket] def sendRequest(socket: Socket, request: Array[Byte], id: Option[Short] = None) { val outgoing = new DataOutputStream(socket.getOutputStream) @@ -79,7 +80,12 @@ class SocketServerTest extends JUnitSuite { /* A simple request handler that just echos back the response */ def processRequest(channel: RequestChannel) { - val request = channel.receiveRequest + val request = channel.receiveRequest(2000) + assertNotNull("receiveRequest timed out", request) + processRequest(channel, request) + } + + def processRequest(channel: RequestChannel, request: RequestChannel.Request) { val byteBuffer = ByteBuffer.allocate(request.header.sizeOf + request.body.sizeOf) request.header.writeTo(byteBuffer) request.body.writeTo(byteBuffer) @@ -89,13 +95,18 @@ class SocketServerTest extends JUnitSuite { channel.sendResponse(new RequestChannel.Response(request.processor, request, send)) } - def connect(s: SocketServer = server, protocol: SecurityProtocol = SecurityProtocol.PLAINTEXT) = - new 
Socket("localhost", server.boundPort(protocol)) + def connect(s: SocketServer = server, protocol: SecurityProtocol = SecurityProtocol.PLAINTEXT) = { + val socket = new Socket("localhost", s.boundPort(protocol)) + sockets += socket + socket + } @After - def cleanup() { + def tearDown() { metrics.close() server.shutdown() + sockets.foreach(_.close()) + sockets.clear() } private def producerRequestBytes: Array[Byte] = { @@ -183,7 +194,7 @@ class SocketServerTest extends JUnitSuite { @Test def testMaxConnectionsPerIp() { - // make the maximum allowable number of connections and then leak them + // make the maximum allowable number of connections val conns = (0 until server.config.maxConnectionsPerIp).map(_ => connect()) // now try one more (should fail) val conn = connect() @@ -201,27 +212,30 @@ class SocketServerTest extends JUnitSuite { sendRequest(conn2, serializedBytes) val request = server.requestChannel.receiveRequest(2000) assertNotNull(request) - conn2.close() - conns.tail.foreach(_.close()) } @Test - def testMaxConnectionsPerIPOverrides() { - val overrideNum = 6 - val overrides = Map("localhost" -> overrideNum) + def testMaxConnectionsPerIpOverrides() { + val overrideNum = server.config.maxConnectionsPerIp + 1 val overrideProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0) + overrideProps.put(KafkaConfig.MaxConnectionsPerIpOverridesProp, s"localhost:$overrideNum") val serverMetrics = new Metrics() - val overrideServer: SocketServer = new SocketServer(KafkaConfig.fromProps(overrideProps), serverMetrics, new SystemTime()) + val overrideServer = new SocketServer(KafkaConfig.fromProps(overrideProps), serverMetrics, new SystemTime()) try { overrideServer.startup() - // make the maximum allowable number of connections and then leak them - val conns = ((0 until overrideNum).map(i => connect(overrideServer))) + // make the maximum allowable number of connections + val conns = (0 until overrideNum).map(_ => connect(overrideServer)) + + // it 
should succeed + val serializedBytes = producerRequestBytes + sendRequest(conns.last, serializedBytes) + val request = overrideServer.requestChannel.receiveRequest(2000) + assertNotNull(request) + // now try one more (should fail) val conn = connect(overrideServer) conn.setSoTimeout(3000) assertEquals(-1, conn.getInputStream.read()) - conn.close() - conns.foreach(_.close()) } finally { overrideServer.shutdown() serverMetrics.close() @@ -229,16 +243,16 @@ class SocketServerTest extends JUnitSuite { } @Test - def testSslSocketServer(): Unit = { + def testSslSocketServer() { val trustStoreFile = File.createTempFile("truststore", ".jks") val overrideProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, interBrokerSecurityProtocol = Some(SecurityProtocol.SSL), trustStoreFile = Some(trustStoreFile)) overrideProps.put(KafkaConfig.ListenersProp, "SSL://localhost:0") val serverMetrics = new Metrics - val overrideServer: SocketServer = new SocketServer(KafkaConfig.fromProps(overrideProps), serverMetrics, new SystemTime) - overrideServer.startup() + val overrideServer = new SocketServer(KafkaConfig.fromProps(overrideProps), serverMetrics, new SystemTime) try { + overrideServer.startup() val sslContext = SSLContext.getInstance("TLSv1.2") sslContext.init(null, Array(TestUtils.trustAllCerts), new java.security.SecureRandom()) val socketFactory = sslContext.getSocketFactory @@ -271,12 +285,95 @@ class SocketServerTest extends JUnitSuite { } @Test - def testSessionPrincipal(): Unit = { + def testSessionPrincipal() { val socket = connect() val bytes = new Array[Byte](40) sendRequest(socket, bytes, Some(0)) - assertEquals(KafkaPrincipal.ANONYMOUS, server.requestChannel.receiveRequest().session.principal) - socket.close() + assertEquals(KafkaPrincipal.ANONYMOUS, server.requestChannel.receiveRequest(2000).session.principal) + } + + /* Test that we update request metrics if the client closes the connection while the broker response is in flight. 
*/ + @Test + def testClientDisconnectionUpdatesRequestMetrics() { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0) + val serverMetrics = new Metrics + var conn: Socket = null + val overrideServer = new SocketServer(KafkaConfig.fromProps(props), serverMetrics, new SystemTime) { + override def newProcessor(id: Int, connectionQuotas: ConnectionQuotas, protocol: SecurityProtocol): Processor = { + new Processor(id, time, config.socketRequestMaxBytes, requestChannel, connectionQuotas, + config.connectionsMaxIdleMs, protocol, config.values, metrics) { + override protected[network] def sendResponse(response: RequestChannel.Response) { + conn.close() + super.sendResponse(response) + } + } + } + } + try { + overrideServer.startup() + conn = connect(overrideServer) + val serializedBytes = producerRequestBytes + sendRequest(conn, serializedBytes) + + val channel = overrideServer.requestChannel + val request = channel.receiveRequest(2000) + + val requestMetrics = RequestMetrics.metricsMap(ApiKeys.forId(request.requestId).name) + def totalTimeHistCount(): Long = requestMetrics.totalTimeHist.count + val expectedTotalTimeCount = totalTimeHistCount() + 1 + + // send a large buffer to ensure that the broker detects the client disconnection while writing to the socket channel. + // On Mac OS X, the initial write seems to always succeed and it is able to write up to 102400 bytes on the initial + // write. If the buffer is smaller than this, the write is considered complete and the disconnection is not + // detected. If the buffer is larger than 102400 bytes, a second write is attempted and it fails with an + // IOException. 
+ val send = new NetworkSend(request.connectionId, ByteBuffer.allocate(550000)) + channel.sendResponse(new RequestChannel.Response(request.processor, request, send)) + TestUtils.waitUntilTrue(() => totalTimeHistCount() == expectedTotalTimeCount, + s"request metrics not updated, expected: $expectedTotalTimeCount, actual: ${totalTimeHistCount()}") + + } finally { + overrideServer.shutdown() + serverMetrics.close() + } + } + + /* + * Test that we update request metrics if the channel has been removed from the selector when the broker calls + * `selector.send` (selector closes old connections, for example). + */ + @Test + def testBrokerSendAfterChannelClosedUpdatesRequestMetrics() { + val props = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, port = 0) + props.setProperty(KafkaConfig.ConnectionsMaxIdleMsProp, "100") + val serverMetrics = new Metrics + var conn: Socket = null + val overrideServer = new SocketServer(KafkaConfig.fromProps(props), serverMetrics, new SystemTime) + try { + overrideServer.startup() + conn = connect(overrideServer) + val serializedBytes = producerRequestBytes + sendRequest(conn, serializedBytes) + val channel = overrideServer.requestChannel + val request = channel.receiveRequest(2000) + + TestUtils.waitUntilTrue(() => overrideServer.processor(request.processor).channel(request.connectionId).isEmpty, + s"Idle connection `${request.connectionId}` was not closed by selector") + + val requestMetrics = RequestMetrics.metricsMap(ApiKeys.forId(request.requestId).name) + def totalTimeHistCount(): Long = requestMetrics.totalTimeHist.count + val expectedTotalTimeCount = totalTimeHistCount() + 1 + + processRequest(channel, request) + + TestUtils.waitUntilTrue(() => totalTimeHistCount() == expectedTotalTimeCount, + s"request metrics not updated, expected: $expectedTotalTimeCount, actual: ${totalTimeHistCount()}") + + } finally { + overrideServer.shutdown() + serverMetrics.close() + } + } }
