chirag-wadhwa5 commented on code in PR #16456:
URL: https://github.com/apache/kafka/pull/16456#discussion_r1665597002


##########
core/src/main/scala/kafka/server/KafkaApis.scala:
##########
@@ -3955,11 +3960,490 @@ class KafkaApis(val requestChannel: RequestChannel,
     }
   }
 
+  /**
+   * Handle a shareFetch request
+   */
   def handleShareFetchRequest(request: RequestChannel.Request): Unit = {
     val shareFetchRequest = request.body[ShareFetchRequest]
-    // TODO: Implement the ShareFetchRequest handling
-    requestHelper.sendMaybeThrottle(request, 
shareFetchRequest.getErrorResponse(Errors.UNSUPPORTED_VERSION.exception))
-    CompletableFuture.completedFuture[Unit](())
+
+    if (!config.isNewGroupCoordinatorEnabled) {
+      // The API is not supported by the "old" group coordinator (the 
default). If the
+      // new one is not enabled, we fail directly here.
+      requestHelper.sendMaybeThrottle(request, 
shareFetchRequest.getErrorResponse(Errors.UNSUPPORTED_VERSION.exception))
+      CompletableFuture.completedFuture[Unit](())
+      return
+    } else if (!config.isShareGroupEnabled) {
+      // The API is not supported when the "share" rebalance protocol has not 
been set explicitly
+      requestHelper.sendMaybeThrottle(request, 
shareFetchRequest.getErrorResponse(Errors.UNSUPPORTED_VERSION.exception))
+      CompletableFuture.completedFuture[Unit](())
+      return
+    }
+    val sharePartitionManager : SharePartitionManager = 
this.sharePartitionManager match {
+      case Some(manager) => manager
+      case None => throw new IllegalStateException("ShareFetchRequest received 
but SharePartitionManager is not initialized")
+    }
+
+    val groupId = shareFetchRequest.data.groupId
+    val memberId = shareFetchRequest.data.memberId
+    val shareSessionEpoch = shareFetchRequest.data.shareSessionEpoch
+
+    var cachedTopicPartitions : util.List[TopicIdPartition] = null
+
+    if (shareSessionEpoch == ShareFetchMetadata.FINAL_EPOCH) {
+      try {
+        cachedTopicPartitions = 
sharePartitionManager.cachedTopicIdPartitionsInShareSession(groupId, 
Uuid.fromString(memberId))
+      } catch {
+        // Exception handling is needed when this value is being utilized on 
receiving FINAL_EPOCH.
+        case _: ShareSessionNotFoundException => cachedTopicPartitions = null
+      }
+    }
+
+    def isAcknowledgeDataPresentInFetchRequest() : Boolean = {
+      var isAcknowledgeDataPresent = false
+      shareFetchRequest.data.topics.forEach ( topic => {
+        breakable{
+          topic.partitions.forEach ( partition => {
+            if (partition.acknowledgementBatches != null && 
!partition.acknowledgementBatches.isEmpty) {
+              isAcknowledgeDataPresent = true
+              break
+            } else {
+              isAcknowledgeDataPresent = false
+            }
+          })
+        }
+      })
+      isAcknowledgeDataPresent
+    }
+
+    val isAcknowledgeDataPresent = isAcknowledgeDataPresentInFetchRequest()
+
+    def isInvalidShareFetchRequest() : Boolean = {
+      // The Initial Share Fetch Request should not Acknowledge any data
+      if (shareSessionEpoch == ShareFetchMetadata.INITIAL_EPOCH && 
isAcknowledgeDataPresent) {
+        return true
+      }
+      false
+    }
+
+    val topicNames = metadataCache.topicIdsToNames()
+    val shareFetchData = shareFetchRequest.shareFetchData(topicNames)
+    val forgottenTopics = shareFetchRequest.forgottenTopics(topicNames)
+
+    val newReqMetadata : ShareFetchMetadata = new 
ShareFetchMetadata(Uuid.fromString(memberId), shareSessionEpoch)
+    var shareFetchContext : ShareFetchContext = null
+
+    var shareFetchResponse : ShareFetchResponse = null
+
+    // check if the Request is Invalid
+    if(isInvalidShareFetchRequest()) {
+      shareFetchResponse = 
shareFetchRequest.getErrorResponse(AbstractResponse.DEFAULT_THROTTLE_TIME, 
Errors.INVALID_REQUEST.exception) match {
+        case response: ShareFetchResponse => response
+        case _ => null
+      }
+    }
+
+    try {
+      // Creating the shareFetchContext for Share Session Handling
+      shareFetchContext = sharePartitionManager.newContext(groupId, 
shareFetchData, forgottenTopics, newReqMetadata)
+    } catch {
+      case e: Exception => shareFetchResponse = 
shareFetchRequest.getErrorResponse(AbstractResponse.DEFAULT_THROTTLE_TIME, e) 
match {
+        case response: ShareFetchResponse => response
+        case _ => null
+      }
+    }
+
+    // Variable to store any error thrown while the handling piggybacked 
acknowledgements
+    var acknowledgeError : Errors = Errors.NONE
+    // Variable to store the topic partition wise result of piggybacked 
acknowledgements
+    var acknowledgeResult = mutable.Map[TopicIdPartition, 
ShareAcknowledgeResponseData.PartitionData]()
+
+    // This check is done to make sure that there was no Share Session related 
error while creating shareFetchContext
+    if(shareFetchResponse == null) {
+      val erroneousAndValidPartitionData : ErroneousAndValidPartitionData = 
shareFetchContext.getErroneousAndValidTopicIdPartitions
+      val topicIdPartitionSeq : mutable.Set[TopicIdPartition] = mutable.Set()
+      erroneousAndValidPartitionData.erroneous.forEach {
+        case(tp, _) => if (!topicIdPartitionSeq.contains(tp)) 
topicIdPartitionSeq += tp
+      }
+      erroneousAndValidPartitionData.validTopicIdPartitions.forEach {
+        case(tp, _) => if (!topicIdPartitionSeq.contains(tp)) 
topicIdPartitionSeq += tp
+      }
+      shareFetchData.forEach {
+        case(tp, _) => if (!topicIdPartitionSeq.contains(tp)) 
topicIdPartitionSeq += tp
+      }
+
+      val authorizedTopics = authHelper.filterByAuthorized(
+        request.context,
+        READ,
+        TOPIC,
+        topicIdPartitionSeq
+      )(_.topicPartition.topic)
+
+      // Handling the Acknowledgements from the ShareFetchRequest
+      // If this check is true, then we are sure that this is not an Initial 
ShareFetch Request, otherwise the request would have been invalid
+      if(isAcknowledgeDataPresent) {
+        if (!authHelper.authorize(request.context, READ, GROUP, groupId)) {
+          acknowledgeError = Errors.GROUP_AUTHORIZATION_FAILED
+        } else {
+          acknowledgeResult = handleAcknowledgements(request, topicNames, 
sharePartitionManager, authorizedTopics, groupId, memberId, true)
+        }
+      }
+
+      // Handling the Fetch from the ShareFetchRequest
+      try {
+        shareFetchResponse = handleFetchFromShareFetchRequest(
+          request,
+          erroneousAndValidPartitionData,
+          topicNames,
+          sharePartitionManager,
+          shareFetchContext,
+          authorizedTopics
+        )
+      } catch {
+        case throwable : Throwable =>
+          debug(s"Share fetch request with correlation from client 
${request.header.clientId}  " +
+            s"failed with error ${throwable.getMessage}")
+          requestHelper.handleError(request, throwable)
+          return
+      }
+    }
+
+    def combineShareFetchAndShareAcknowledgeResponses(
+                                                       shareFetchResponse: 
ShareFetchResponse,
+                                                       acknowledgeResult : 
mutable.Map[TopicIdPartition, ShareAcknowledgeResponseData.PartitionData],
+                                                       acknowledgeError : 
Errors
+                                                     ) : ShareFetchResponse = {
+
+      // The outer map has topicId as the key and the inner map has 
partitionIndex as the key
+      val topicPartitionAcknowledgements : mutable.Map[Uuid, mutable.Map[Int, 
Short]] = mutable.Map()
+      if(acknowledgeResult != null && acknowledgeResult.nonEmpty) {
+        acknowledgeResult.asJava.forEach { (tp, partitionData) =>
+          topicPartitionAcknowledgements.get(tp.topicId) match {
+            case Some(subMap) =>
+              subMap += tp.partition -> partitionData.errorCode
+            case None =>
+              val partitionAcknowledgementsMap : mutable.Map[Int, Short] = 
mutable.Map()
+              partitionAcknowledgementsMap += tp.partition -> 
partitionData.errorCode
+              topicPartitionAcknowledgements += tp.topicId -> 
partitionAcknowledgementsMap
+          }
+        }
+      }
+
+      shareFetchResponse.data.responses.forEach(topic => {
+        val topicId = topic.topicId
+        topicPartitionAcknowledgements.get(topicId) match {
+          case Some(subMap) =>
+            topic.partitions.forEach { partition =>
+              subMap.get(partition.partitionIndex) match {
+                case Some(value) =>
+                  val ackErrorCode = if(acknowledgeError.code != 
Errors.NONE.code) acknowledgeError.code else value
+                  partition.setAcknowledgeErrorCode(ackErrorCode)
+                  // Delete the element
+                  subMap.remove(partition.partitionIndex)
+                case None =>
+              }
+            }
+            // Add the remaining acknowledgements
+            subMap.foreach { case (partitionIndex, value) =>
+              val ackErrorCode = if(acknowledgeError.code != Errors.NONE.code) 
acknowledgeError.code else value
+              val fetchPartitionData = new 
ShareFetchResponseData.PartitionData()
+                .setPartitionIndex(partitionIndex)
+                .setErrorCode(Errors.NONE.code)
+                .setAcknowledgeErrorCode(ackErrorCode)
+              topic.partitions.add(fetchPartitionData)
+            }
+            topicPartitionAcknowledgements.remove(topicId)
+          case None =>
+        }
+      })
+      // Add the remaining acknowledgements
+      topicPartitionAcknowledgements.foreach{ case(topicId, subMap) =>
+        val topicData = new 
ShareFetchResponseData.ShareFetchableTopicResponse()
+          .setTopicId(topicId)
+        subMap.foreach { case (partitionIndex, value) =>
+          val ackErrorCode = if(acknowledgeError.code != Errors.NONE.code) 
acknowledgeError.code else value
+          val fetchPartitionData = new ShareFetchResponseData.PartitionData()
+            .setPartitionIndex(partitionIndex)
+            .setErrorCode(Errors.NONE.code)
+            .setAcknowledgeErrorCode(ackErrorCode)
+          topicData.partitions.add(fetchPartitionData)
+        }
+        shareFetchResponse.data.responses.add(topicData)
+      }
+
+      if (shareSessionEpoch == ShareFetchMetadata.FINAL_EPOCH && 
cachedTopicPartitions != null) {
+        sharePartitionManager.releaseAcquiredRecords(groupId, memberId, 
cachedTopicPartitions).
+          whenComplete((releaseAcquiredRecordsData, throwable) => {
+            if (throwable != null) {
+              debug(s"Release acquired records on share session close with 
correlation from client ${request.header.clientId}  " +
+                s"failed with error ${throwable.getMessage}")
+              requestHelper.handleError(request, throwable)
+            } else {
+              info(s"Release acquired records on share session close 
$releaseAcquiredRecordsData succeeded")
+            }
+          })
+      }
+      shareFetchResponse
+    }
+
+    def updateConversionStats(send: Send): Unit = {
+      send match {
+        case send: MultiRecordsSend if send.recordConversionStats != null =>
+          send.recordConversionStats.asScala.toMap.foreach {
+            case (tp, stats) => updateRecordConversionStats(request, tp, stats)
+          }
+        case send: NetworkSend =>
+          updateConversionStats(send.send())
+        case _ =>
+      }
+    }
+
+    // Send the response immediately.
+    requestChannel.sendResponse(request, 
combineShareFetchAndShareAcknowledgeResponses(shareFetchResponse, 
acknowledgeResult, acknowledgeError), Some(updateConversionStats))
+  }
+
+  def handleFetchFromShareFetchRequest(request: RequestChannel.Request,
+                                       erroneousAndValidPartitionData : 
ErroneousAndValidPartitionData,
+                                       topicNames : util.Map[Uuid, String],
+                                       sharePartitionManager : 
SharePartitionManager,
+                                       shareFetchContext : ShareFetchContext,
+                                       authorizedTopics: Set[String]
+                                      ): ShareFetchResponse = {
+
+    val erroneous = mutable.ArrayBuffer[(TopicIdPartition, 
ShareFetchResponseData.PartitionData)]()
+    // Regular Kafka consumers need READ permission on each partition they are 
fetching.
+    val partitionDatas = new mutable.ArrayBuffer[(TopicIdPartition, 
ShareFetchRequest.SharePartitionData)]
+    erroneousAndValidPartitionData.erroneous.forEach {
+      erroneousData => erroneous += erroneousData
+    }
+    erroneousAndValidPartitionData.validTopicIdPartitions.forEach {
+      validPartitionData => partitionDatas += validPartitionData
+    }
+
+    val interestingWithMaxBytes = new util.LinkedHashMap[TopicIdPartition, 
Integer]
+
+    partitionDatas.foreach { case (topicIdPartition, sharePartitionData) =>
+      if (!authorizedTopics.contains(topicIdPartition.topicPartition.topic))
+        erroneous += topicIdPartition -> 
ShareFetchResponse.partitionResponse(topicIdPartition, 
Errors.TOPIC_AUTHORIZATION_FAILED)
+      else if (!metadataCache.contains(topicIdPartition.topicPartition))
+        erroneous += topicIdPartition -> 
ShareFetchResponse.partitionResponse(topicIdPartition, 
Errors.UNKNOWN_TOPIC_OR_PARTITION)
+      else
+        interestingWithMaxBytes.put(topicIdPartition, 
sharePartitionData.maxBytes)
+    }
+
+    val clientId = request.header.clientId
+
+    def maybeConvertShareFetchedData(tp: TopicIdPartition,
+                                     partitionData: 
ShareFetchResponseData.PartitionData): ShareFetchResponseData.PartitionData = {
+      val unconvertedRecords = ShareFetchResponse.recordsOrFail(partitionData)
+      new ShareFetchResponseData.PartitionData()
+        .setPartitionIndex(tp.partition)
+        .setErrorCode(Errors.forCode(partitionData.errorCode).code)
+        .setRecords(unconvertedRecords)
+        .setAcquiredRecords(partitionData.acquiredRecords)
+        .setCurrentLeader(partitionData.currentLeader)
+    }
+
+    val shareFetchRequest = request.body[ShareFetchRequest]
+
+    val versionId = request.header.apiVersion
+    val groupId = shareFetchRequest.data.groupId
+    val memberId = shareFetchRequest.data.memberId
+
+    // the callback for processing a share fetch response, invoked before 
throttling
+    def processResponseCallback(responsePartitionData: Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData]): ShareFetchResponse = {
+      val partitions = new util.LinkedHashMap[TopicIdPartition, 
ShareFetchResponseData.PartitionData]
+      val nodeEndpoints = new mutable.HashMap[Int, Node]
+      responsePartitionData.foreach { case(tp, partitionData) =>
+        partitionData.errorCode match {
+          case errCode if errCode == Errors.NOT_LEADER_OR_FOLLOWER.code | 
errCode == Errors.FENCED_LEADER_EPOCH.code =>
+            val leaderNode = getCurrentLeader(tp.topicPartition, 
request.context.listenerName)
+            leaderNode.node.foreach { node =>
+              nodeEndpoints.put(node.id, node)
+            }
+            partitionData.currentLeader
+              .setLeaderId(leaderNode.leaderId)
+              .setLeaderEpoch(leaderNode.leaderEpoch)
+          case _ =>
+        }
+
+        partitions.put(tp, partitionData)
+      }
+      erroneous.foreach { case (tp, data) => partitions.put(tp, data) }
+
+      var unconvertedShareFetchResponse: ShareFetchResponse = null
+
+      def createResponse(throttleTimeMs: Int): ShareFetchResponse = {
+        // Down-convert messages for each partition if required
+        val convertedData = new util.LinkedHashMap[TopicIdPartition, 
ShareFetchResponseData.PartitionData]
+        unconvertedShareFetchResponse.data.responses.forEach { topicResponse =>
+          topicResponse.partitions.forEach { unconvertedPartitionData =>
+            val tp = new TopicIdPartition(topicResponse.topicId, new 
TopicPartition(topicNames.get(topicResponse.topicId),
+              unconvertedPartitionData.partitionIndex))
+            val error = Errors.forCode(unconvertedPartitionData.errorCode)
+            if (error != Errors.NONE)
+              debug(s"Share Fetch request with correlation id 
${request.header.correlationId} from client $clientId " +
+                s"on partition $tp failed due to ${error.exceptionName}")
+            convertedData.put(tp, maybeConvertShareFetchedData(tp, 
unconvertedPartitionData))
+          }
+        }
+
+        // Prepare share fetch response from converted data
+        val response =
+          ShareFetchResponse.of(unconvertedShareFetchResponse.error, 
throttleTimeMs, convertedData, nodeEndpoints.values.toList.asJava)
+        // record the bytes out metrics only when the response is being sent
+        response.data.responses.forEach { topicResponse =>
+          topicResponse.partitions.forEach { data =>
+            // If the topic name was not known, we will have no bytes out.
+            if (topicResponse.topicId != null) {
+              val tp = new TopicIdPartition(topicResponse.topicId, new 
TopicPartition(topicNames.get(topicResponse.topicId), data.partitionIndex))
+              brokerTopicStats.updateBytesOut(tp.topic, false, false, 
ShareFetchResponse.recordsSize(data))
+            }
+          }
+        }
+        response
+      }
+      // Share Fetch size used to determine throttle time is calculated before 
any down conversions.
+      // This may be slightly different from the actual response size. But 
since down conversions
+      // result in data being loaded into memory, we should do this only when 
we are not going to throttle.
+      //
+      // Record both bandwidth and request quota-specific values and throttle 
by muting the channel if any of the
+      // quotas have been violated. If both quotas have been violated, use the 
max throttle time between the two
+      // quotas. When throttled, we unrecord the recorded bandwidth quota value
+      val responseSize = shareFetchContext.responseSize(partitions, versionId)
+      val timeMs = time.milliseconds()
+      val requestThrottleTimeMs = 
quotas.request.maybeRecordAndGetThrottleTimeMs(request, timeMs)
+      val bandwidthThrottleTimeMs = 
quotas.fetch.maybeRecordAndGetThrottleTimeMs(request, responseSize, timeMs)
+
+      val maxThrottleTimeMs = math.max(bandwidthThrottleTimeMs, 
requestThrottleTimeMs)
+      if (maxThrottleTimeMs > 0) {
+        request.apiThrottleTimeMs = maxThrottleTimeMs
+        // Even if we need to throttle for request quota violation, we should 
"unrecord" the already recorded value
+        // from the fetch quota because we are going to return an empty 
response.
+        quotas.fetch.unrecordQuotaSensor(request, responseSize, timeMs)
+        if (bandwidthThrottleTimeMs > requestThrottleTimeMs) {
+          requestHelper.throttle(quotas.fetch, request, 
bandwidthThrottleTimeMs)
+        } else {
+          requestHelper.throttle(quotas.request, request, 
requestThrottleTimeMs)
+        }
+        // If throttling is required, return an empty response.
+        unconvertedShareFetchResponse = 
shareFetchContext.throttleResponse(maxThrottleTimeMs)
+      } else {
+        // Get the actual response. This will update the fetch context.
+        unconvertedShareFetchResponse = 
shareFetchContext.updateAndGenerateResponseData(groupId, 
Uuid.fromString(memberId), partitions)
+        val responsePartitionsSize = 
unconvertedShareFetchResponse.data.responses.stream().mapToInt(_.partitions.size()).sum()
+        trace(s"Sending Share Fetch response with 
partitions.size=$responsePartitionsSize")

Review Comment:
   Thanks for the review. I think I copied this comment from the regular fetch 
request implementation, but yes, it's a very trivial change. I will make it in 
the next commit.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: jira-unsubscr...@kafka.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to