jerqi commented on code in PR #8450:
URL: https://github.com/apache/gravitino/pull/8450#discussion_r2343088577


##########
core/src/main/java/org/apache/gravitino/stats/storage/LancePartitionStatisticStorage.java:
##########
@@ -354,70 +446,101 @@ private static String getPartitionFilter(PartitionRange 
range) {
 
   private List<PersistedPartitionStatistics> listStatisticsImpl(
       Long tableId, String partitionFilter) {
-    String fileName = getFilePath(tableId);
-
-    try (Dataset dataset = open(fileName)) {
-
-      String filter = "table_id = " + tableId + partitionFilter;
-
-      try (LanceScanner scanner =
-          dataset.newScan(
-              new ScanOptions.Builder()
-                  .columns(
-                      Arrays.asList(
-                          TABLE_ID_COLUMN,
-                          PARTITION_NAME_COLUMN,
-                          STATISTIC_NAME_COLUMN,
-                          STATISTIC_VALUE_COLUMN,
-                          AUDIT_INFO_COLUMN))
-                  .withRowId(true)
-                  .batchSize(readBatchSize)
-                  .filter(filter)
-                  .build())) {
-        Map<String, List<PersistedStatistic>> partitionStatistics = 
Maps.newConcurrentMap();
-        try (ArrowReader reader = scanner.scanBatches()) {
-          while (reader.loadNextBatch()) {
-            VectorSchemaRoot root = reader.getVectorSchemaRoot();
-            List<FieldVector> fieldVectors = root.getFieldVectors();
-            VarCharVector partitionNameVector = (VarCharVector) 
fieldVectors.get(1);
-            VarCharVector statisticNameVector = (VarCharVector) 
fieldVectors.get(2);
-            LargeVarCharVector statisticValueVector = (LargeVarCharVector) 
fieldVectors.get(3);
-            VarCharVector auditInfoNameVector = (VarCharVector) 
fieldVectors.get(4);
-
-            for (int i = 0; i < root.getRowCount(); i++) {
-              String partitionName = new String(partitionNameVector.get(i), 
StandardCharsets.UTF_8);
-              String statisticName = new String(statisticNameVector.get(i), 
StandardCharsets.UTF_8);
-              String statisticValueStr =
-                  new String(statisticValueVector.get(i), 
StandardCharsets.UTF_8);
-              String auditInoStr = new String(auditInfoNameVector.get(i), 
StandardCharsets.UTF_8);
-
-              StatisticValue<?> statisticValue =
-                  JsonUtils.anyFieldMapper().readValue(statisticValueStr, 
StatisticValue.class);
-              AuditInfo auditInfo =
-                  JsonUtils.anyFieldMapper().readValue(auditInoStr, 
AuditInfo.class);
-
-              PersistedStatistic persistedStatistic =
-                  PersistedStatistic.of(statisticName, statisticValue, 
auditInfo);
-
-              partitionStatistics
-                  .computeIfAbsent(partitionName, k -> Lists.newArrayList())
-                  .add(persistedStatistic);
-            }
-          }
 
-          return partitionStatistics.entrySet().stream()
-              .map(entry -> PersistedPartitionStatistics.of(entry.getKey(), 
entry.getValue()))
-              .collect(Collectors.toList());
+    Dataset dataset = getDataset(tableId);
+
+    String filter = "table_id = " + tableId + partitionFilter;
+
+    try (LanceScanner scanner =
+        dataset.newScan(
+            new ScanOptions.Builder()
+                .columns(
+                    Arrays.asList(
+                        TABLE_ID_COLUMN,
+                        PARTITION_NAME_COLUMN,
+                        STATISTIC_NAME_COLUMN,
+                        STATISTIC_VALUE_COLUMN,
+                        AUDIT_INFO_COLUMN))
+                .withRowId(true)
+                .batchSize(readBatchSize)
+                .filter(filter)
+                .build())) {
+      Map<String, List<PersistedStatistic>> partitionStatistics = 
Maps.newConcurrentMap();
+      try (ArrowReader reader = scanner.scanBatches()) {
+        while (reader.loadNextBatch()) {
+          VectorSchemaRoot root = reader.getVectorSchemaRoot();
+          List<FieldVector> fieldVectors = root.getFieldVectors();
+          VarCharVector partitionNameVector = (VarCharVector) 
fieldVectors.get(1);
+          VarCharVector statisticNameVector = (VarCharVector) 
fieldVectors.get(2);
+          LargeVarCharVector statisticValueVector = (LargeVarCharVector) 
fieldVectors.get(3);
+          VarCharVector auditInfoNameVector = (VarCharVector) 
fieldVectors.get(4);
+
+          for (int i = 0; i < root.getRowCount(); i++) {
+            String partitionName = new String(partitionNameVector.get(i), 
StandardCharsets.UTF_8);
+            String statisticName = new String(statisticNameVector.get(i), 
StandardCharsets.UTF_8);
+            String statisticValueStr =
+                new String(statisticValueVector.get(i), 
StandardCharsets.UTF_8);
+            String auditInoStr = new String(auditInfoNameVector.get(i), 
StandardCharsets.UTF_8);
+
+            StatisticValue<?> statisticValue =
+                JsonUtils.anyFieldMapper().readValue(statisticValueStr, 
StatisticValue.class);
+            AuditInfo auditInfo =
+                JsonUtils.anyFieldMapper().readValue(auditInoStr, 
AuditInfo.class);
+
+            PersistedStatistic persistedStatistic =
+                PersistedStatistic.of(statisticName, statisticValue, 
auditInfo);
+
+            partitionStatistics
+                .computeIfAbsent(partitionName, k -> Lists.newArrayList())
+                .add(persistedStatistic);
+          }
         }
-      } catch (Exception e) {
-        throw new RuntimeException(e);
+
+        return partitionStatistics.entrySet().stream()
+            .map(entry -> PersistedPartitionStatistics.of(entry.getKey(), 
entry.getValue()))
+            .collect(Collectors.toList());
+      }
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    } finally {
+      if (!datasetCache.isPresent() && dataset != null) {
+        dataset.close();
       }
     }
   }
 
+  private Dataset getDataset(Long tableId) {
+    AtomicBoolean newlyCreated = new AtomicBoolean(false);
+    return datasetCache
+        .map(
+            cache -> {
+              Dataset cachedDataset =
+                  cache.get(
+                      tableId,
+                      id -> {
+                        newlyCreated.set(true);
+                        return open(getFilePath(id));
+                      });
+
+              // Ensure dataset uses the latest version
+              if (newlyCreated.get()) {

Review Comment:
   Yes, fixed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to