xiangfu0 commented on code in PR #17811:
URL: https://github.com/apache/pinot/pull/17811#discussion_r2897664798


##########
pinot-spi/src/main/java/org/apache/pinot/spi/stream/PartitionGroupMetadataFetcher.java:
##########
@@ -122,29 +140,53 @@ private Boolean fetchMultipleStreams()
               .collect(Collectors.toList());
       try (StreamMetadataProvider streamMetadataProvider = streamConsumerFactory.createStreamMetadataProvider(
           StreamConsumerFactory.getUniqueClientId(clientId))) {
-        _newPartitionGroupMetadataList.addAll(
+        List<PartitionGroupMetadata> partitionGroupMetadataList =
             streamMetadataProvider.computePartitionGroupMetadata(clientId,
-                    streamConfig, topicPartitionGroupConsumptionStatusList, /*maxWaitTimeMs=*/15000,
+                    streamConfig, topicPartitionGroupConsumptionStatusList,
+                    /*maxWaitTimeMs=*/METADATA_FETCH_TIMEOUT_MS,
                     _forceGetOffsetFromStream)
                 .stream()
                 .map(metadata -> new PartitionGroupMetadata(
                     IngestionConfigUtils.getPinotPartitionIdFromStreamPartitionId(metadata.getPartitionGroupId(),
-                        index), metadata.getStartOffset()))
-                .collect(Collectors.toList()));
-        if (_exception != null) {
-          // We had at least one failure, but succeeded now. Log an info
-          LOGGER.info("Successfully retrieved PartitionGroupMetadata for topic {}", topicName);
-        }
+                        index), metadata.getStartOffset(), metadata.getSequenceNumber()))
+                .collect(Collectors.toList());
+        int partitionCount = getNumPartitions(streamMetadataProvider, partitionGroupMetadataList);
+        _streamMetadataList.add(
+            new StreamMetadata(streamConfig, partitionCount, partitionGroupMetadataList));
       } catch (TransientConsumerException e) {
-        LOGGER.warn("Transient Exception: Could not get partition count for topic {}", topicName, e);
+        LOGGER.warn("Transient Exception: Could not get StreamMetadata for topic {}", topicName, e);
         _exception = e;
         return Boolean.FALSE;
       } catch (Exception e) {
-        LOGGER.warn("Could not get partition count for topic {}", topicName, e);
+        LOGGER.warn("Could not get StreamMetadata for topic {}", topicName, e);
         _exception = e;
         throw e;
       }
     }
     return Boolean.TRUE;
   }
+
+  private int getNumPartitions(StreamMetadataProvider streamMetadataProvider,
+      List<PartitionGroupMetadata> partitionGroupMetadataList) {
+    if (usesDefaultComputePartitionGroupMetadata(streamMetadataProvider)) {
+      return partitionGroupMetadataList.size();
+    }
+    return streamMetadataProvider.fetchPartitionCount(/*timeoutMillis=*/METADATA_FETCH_TIMEOUT_MS);
+  }
+
+  private boolean usesDefaultComputePartitionGroupMetadata(StreamMetadataProvider streamMetadataProvider) {
+    Class<?> providerClass = streamMetadataProvider.getClass();
+    return isDefaultComputeMethod(providerClass, COMPUTE_PARTITION_GROUP_METADATA_ARGUMENT_TYPES)
+        && isDefaultComputeMethod(providerClass, COMPUTE_PARTITION_GROUP_METADATA_WITH_FORCE_ARGUMENT_TYPES);

Review Comment:
   Good catch. Updated `usesDefaultComputePartitionGroupMetadata()` to inspect
only the 5-arg `computePartitionGroupMetadata(...)` overload, because that is
the signature this fetcher invokes. Also added
`testFourArgOverrideWithDefaultFiveArgAvoidsExtraPartitionCountFetch` to cover
the "4-arg override + default 5-arg" provider case and to guard against extra
`fetchPartitionCount()` calls.
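For readers unfamiliar with the reflection check, here is a minimal sketch of the idea. It assumes a hypothetical `MetadataProvider` interface (not Pinot's actual `StreamMetadataProvider`) whose longer `computeGroups(...)` overload is the one the caller invokes and, by default, delegates to the shorter one. `Class.getMethod(...)` resolves to the interface's default method when an implementation does not override that exact signature, so `Method.isDefault()` indicates whether the invoked overload is still the default. All class names, `computeGroups`, `TIMEOUT_MS`, and the fetch counter are illustrative only.

```java
import java.lang.reflect.Method;
import java.util.List;

public class DefaultOverloadCheckSketch {

  // Hypothetical timeout, standing in for the old hard-coded 15000 ms.
  static final long TIMEOUT_MS = 15_000L;

  // Hypothetical provider interface (NOT Pinot's StreamMetadataProvider) with two overloads.
  interface MetadataProvider {
    default List<String> computeGroups(String clientId, long timeoutMs) {
      return List.of("group-0");
    }

    // The overload the caller invokes; by default it delegates to the shorter overload.
    default List<String> computeGroups(String clientId, long timeoutMs, boolean force) {
      return computeGroups(clientId, timeoutMs);
    }

    int fetchPartitionCount(long timeoutMs);
  }

  // Counts fetchPartitionCount() calls so an "extra fetch" is observable.
  abstract static class CountingProvider implements MetadataProvider {
    int partitionCountFetches = 0;

    @Override
    public int fetchPartitionCount(long timeoutMs) {
      partitionCountFetches++;
      return 3;
    }
  }

  // Overrides only the shorter overload; the invoked overload is still the interface default.
  static class ShortOverride extends CountingProvider {
    @Override
    public List<String> computeGroups(String clientId, long timeoutMs) {
      return List.of("group-0", "group-1");
    }
  }

  // Overrides the invoked overload itself.
  static class InvokedOverride extends CountingProvider {
    @Override
    public List<String> computeGroups(String clientId, long timeoutMs, boolean force) {
      return List.of("group-0");
    }
  }

  private static final Class<?>[] INVOKED_ARG_TYPES = {String.class, long.class, boolean.class};

  // True when the invoked overload resolves to the interface's default method.
  static boolean usesDefaultInvokedOverload(MetadataProvider provider) {
    try {
      Method method = provider.getClass().getMethod("computeGroups", INVOKED_ARG_TYPES);
      return method.isDefault();
    } catch (NoSuchMethodException e) {
      throw new IllegalStateException("Invoked overload not found", e);
    }
  }

  // Trust the list size when the invoked overload is the default; otherwise fetch the count.
  static int numPartitions(MetadataProvider provider, List<String> groups) {
    if (usesDefaultInvokedOverload(provider)) {
      return groups.size();
    }
    return provider.fetchPartitionCount(TIMEOUT_MS);
  }

  public static void main(String[] args) {
    ShortOverride shortOverride = new ShortOverride();
    List<String> groups = shortOverride.computeGroups("client", TIMEOUT_MS, false);
    // prints: 2 0
    System.out.println(numPartitions(shortOverride, groups) + " " + shortOverride.partitionCountFetches);

    InvokedOverride invokedOverride = new InvokedOverride();
    groups = invokedOverride.computeGroups("client", TIMEOUT_MS, false);
    // prints: 3 1
    System.out.println(numPartitions(invokedOverride, groups) + " " + invokedOverride.partitionCountFetches);
  }
}
```

In this sketch, overriding only the shorter overload leaves the invoked overload as the default, so the count comes from the list size and no extra fetch happens. Overriding the invoked overload falls back to `fetchPartitionCount()`.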



##########
pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/realtime/PinotLLCRealtimeSegmentManager.java:
##########
@@ -1704,12 +1723,12 @@ private boolean isAllInstancesInState(Map<String, String> instanceStateMap, Stri
    */
   @VisibleForTesting
   IdealState ensureAllPartitionsConsuming(TableConfig tableConfig, List<StreamConfig> streamConfigs,
-      IdealState idealState, List<PartitionGroupMetadata> partitionGroupMetadataList, OffsetCriteria offsetCriteria) {
+      IdealState idealState, List<StreamMetadata> streamMetadataList, OffsetCriteria offsetCriteria) {
     String realtimeTableName = tableConfig.getTableName();
 
     InstancePartitions instancePartitions = getConsumingInstancePartitions(tableConfig);
     int numReplicas = getNumReplicas(tableConfig, instancePartitions);
-    int numPartitions = partitionGroupMetadataList.size();
+    int numPartitions = streamMetadataList.stream().mapToInt(StreamMetadata::getNumPartitions).sum();

Review Comment:
   Updated this section to use an explicit loop for partition counting (instead 
of `stream().mapToInt(...).sum()`), per your suggestion for readability and 
lower overhead.
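The explicit-loop form can be sketched as follows, using a hypothetical minimal `StreamMetadata` stand-in that exposes only `getNumPartitions()` (the real class carries more fields). It computes the same total as `stream().mapToInt(StreamMetadata::getNumPartitions).sum()` without building a stream pipeline.

```java
import java.util.List;

public class PartitionCountSketch {

  // Hypothetical minimal stand-in for StreamMetadata; only the partition count matters here.
  static final class StreamMetadata {
    private final int _numPartitions;

    StreamMetadata(int numPartitions) {
      _numPartitions = numPartitions;
    }

    int getNumPartitions() {
      return _numPartitions;
    }
  }

  // Explicit-loop equivalent of stream().mapToInt(StreamMetadata::getNumPartitions).sum().
  static int countPartitions(List<StreamMetadata> streamMetadataList) {
    int numPartitions = 0;
    for (StreamMetadata streamMetadata : streamMetadataList) {
      numPartitions += streamMetadata.getNumPartitions();
    }
    return numPartitions;
  }

  public static void main(String[] args) {
    List<StreamMetadata> streamMetadataList = List.of(new StreamMetadata(4), new StreamMetadata(8));
    // prints: 12
    System.out.println(countPartitions(streamMetadataList));
  }
}
```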



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.