This is an automated email from the ASF dual-hosted git repository.

jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 9e0316fe21e Fix concurrent last cache query bug caused by a device spanning multiple regions
9e0316fe21e is described below

commit 9e0316fe21edf9cde1874106cace0782caabcc3f
Author: Weihao Li <[email protected]>
AuthorDate: Thu Sep 18 18:32:44 2025 +0800

    Fix concurrent last cache query bug caused by a device spanning multiple regions
---
 .../it/query/recent/IoTDBTableAggregation2IT.java  |  1 -
 .../recent/IoTDBTableAggregationNonStream2IT.java  |  1 -
 .../execution/fragment/DataNodeQueryContext.java   | 68 ++++++++++++++-
 .../relational/LastQueryAggTableScanOperator.java  | 99 ++++++++++++++--------
 .../plan/planner/TableOperatorGenerator.java       | 60 +++++++++++--
 .../distribute/TableDistributedPlanGenerator.java  |  3 +
 .../distribute/TableModelQueryFragmentPlanner.java | 42 +++++++++
 .../planner/node/AggregationTableScanNode.java     | 62 ++++++++++++++
 8 files changed, 290 insertions(+), 46 deletions(-)

diff --git 
a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregation2IT.java
 
b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregation2IT.java
index 76d861fd85e..04538a55a84 100644
--- 
a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregation2IT.java
+++ 
b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregation2IT.java
@@ -39,7 +39,6 @@ public class IoTDBTableAggregation2IT extends 
IoTDBTableAggregationIT {
     
EnvFactory.getEnv().getConfig().getCommonConfig().setMaxTsBlockSizeInByte(4 * 
1024);
     
EnvFactory.getEnv().getConfig().getCommonConfig().setTimePartitionInterval(5_000);
     
EnvFactory.getEnv().getConfig().getCommonConfig().setDataPartitionAllocationStrategy(SHUFFLE);
-    
EnvFactory.getEnv().getConfig().getCommonConfig().setEnableLastCache(false);
     EnvFactory.getEnv().initClusterEnvironment();
     prepareTableData(createSqls);
   }
diff --git 
a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationNonStream2IT.java
 
b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationNonStream2IT.java
index a430979eaee..27cd9ad89c5 100644
--- 
a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationNonStream2IT.java
+++ 
b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationNonStream2IT.java
@@ -39,7 +39,6 @@ public class IoTDBTableAggregationNonStream2IT extends 
IoTDBTableAggregationNonS
     
EnvFactory.getEnv().getConfig().getCommonConfig().setMaxTsBlockSizeInByte(4 * 
1024);
     
EnvFactory.getEnv().getConfig().getCommonConfig().setTimePartitionInterval(5_000);
     
EnvFactory.getEnv().getConfig().getCommonConfig().setDataPartitionAllocationStrategy(SHUFFLE);
-    
EnvFactory.getEnv().getConfig().getCommonConfig().setEnableLastCache(false);
     EnvFactory.getEnv().initClusterEnvironment();
     String original = createSqls[2];
     // make 'province', 'city', 'region' be FIELD to cover cases using 
GroupedAccumulator
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/fragment/DataNodeQueryContext.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/fragment/DataNodeQueryContext.java
index ffa3ead32e1..814fcc7df63 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/fragment/DataNodeQueryContext.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/fragment/DataNodeQueryContext.java
@@ -20,22 +20,43 @@
 package org.apache.iotdb.db.queryengine.execution.fragment;
 
 import org.apache.iotdb.commons.path.PartialPath;
+import org.apache.iotdb.db.queryengine.plan.relational.metadata.DeviceEntry;
+import 
org.apache.iotdb.db.queryengine.plan.relational.metadata.QualifiedObjectName;
+import 
org.apache.iotdb.db.queryengine.plan.relational.metadata.fetcher.cache.TableDeviceSchemaCache;
 
 import org.apache.tsfile.read.TimeValuePair;
 import org.apache.tsfile.utils.Pair;
 
 import javax.annotation.concurrent.GuardedBy;
 
+import java.util.HashMap;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.locks.ReentrantLock;
 
+import static com.google.common.base.Preconditions.checkArgument;
+
 public class DataNodeQueryContext {
-  // left of Pair is DataNodeSeriesScanNum, right of Pair is the last value 
waiting to be updated
+  // Used for TreeModel, left of Pair is DataNodeSeriesScanNum, right of Pair 
is the last value
+  // waiting to be updated
   @GuardedBy("lock")
   private final Map<PartialPath, Pair<AtomicInteger, TimeValuePair>> 
uncachedPathToSeriesScanInfo;
 
+  // Used for TableModel
+  // 1. Outer Map: record the info for each Table to make sure DeviceEntry is 
unique in the value
+  // Scope.
+  // 2. Inner Map: record DeviceEntry to last cache for each measurement, left 
of Pair is the
+  // count of device regions, right is the measurement values wait to be 
updated for last cache.
+  // Notice: only the device counts more than one will be recorded
+  @GuardedBy("lock")
+  private final Map<
+          QualifiedObjectName, Map<DeviceEntry, Pair<Integer, Map<String, 
TimeValuePair>>>>
+      deviceCountAndMeasurementValues;
+
+  private static final TableDeviceSchemaCache TABLE_DEVICE_SCHEMA_CACHE =
+      TableDeviceSchemaCache.getInstance();
+
   private final AtomicInteger dataNodeFINum;
 
   // TODO consider more fine-grained locks, now the AtomicInteger in 
uncachedPathToSeriesScanInfo is
@@ -45,6 +66,7 @@ public class DataNodeQueryContext {
   public DataNodeQueryContext(int dataNodeFINum) {
     this.uncachedPathToSeriesScanInfo = new ConcurrentHashMap<>();
     this.dataNodeFINum = new AtomicInteger(dataNodeFINum);
+    this.deviceCountAndMeasurementValues = new HashMap<>();
   }
 
   public boolean unCached(PartialPath path) {
@@ -55,6 +77,50 @@ public class DataNodeQueryContext {
     uncachedPathToSeriesScanInfo.put(path, new Pair<>(dataNodeSeriesScanNum, 
null));
   }
 
+  public void decreaseDeviceAndMayUpdateLastCache(
+      QualifiedObjectName tableName, DeviceEntry deviceEntry, Integer 
initialCount) {
+    checkArgument(initialCount != null, "initialCount shouldn't be null here");
+
+    Map<DeviceEntry, Pair<Integer, Map<String, TimeValuePair>>> deviceInfo =
+        deviceCountAndMeasurementValues.computeIfAbsent(tableName, t -> new 
HashMap<>());
+
+    Pair<Integer, Map<String, TimeValuePair>> info =
+        deviceInfo.computeIfAbsent(deviceEntry, d -> new Pair<>(initialCount, 
new HashMap<>()));
+    info.left--;
+    if (info.left == 0) {
+      updateLastCache(tableName, deviceEntry);
+    }
+  }
+
+  public void addUnCachedDeviceIfAbsent(
+      QualifiedObjectName tableName, DeviceEntry deviceEntry, Integer count) {
+    checkArgument(count != null, "count shouldn't be null here");
+
+    Map<DeviceEntry, Pair<Integer, Map<String, TimeValuePair>>> deviceInfo =
+        deviceCountAndMeasurementValues.computeIfAbsent(tableName, t -> new 
HashMap<>());
+
+    deviceInfo.putIfAbsent(deviceEntry, new Pair<>(count, new HashMap<>()));
+  }
+
+  public Pair<Integer, Map<String, TimeValuePair>> getDeviceInfo(
+      QualifiedObjectName tableName, DeviceEntry deviceEntry) {
+    return deviceCountAndMeasurementValues.get(tableName).get(deviceEntry);
+  }
+
+  /** Update the last cache when device count decrease to zero. */
+  public void updateLastCache(QualifiedObjectName tableName, DeviceEntry 
deviceEntry) {
+    Map<String, TimeValuePair> values =
+        
deviceCountAndMeasurementValues.get(tableName).get(deviceEntry).getRight();
+    // if a device hits cache each time, the values recorded in context will 
be null
+    if (values != null) {
+      TABLE_DEVICE_SCHEMA_CACHE.updateLastCacheIfExists(
+          tableName.getDatabaseName(),
+          deviceEntry.getDeviceID(),
+          values.keySet().toArray(new String[0]),
+          values.values().toArray(new TimeValuePair[0]));
+    }
+  }
+
   public Pair<AtomicInteger, TimeValuePair> getSeriesScanInfo(PartialPath 
path) {
     return uncachedPathToSeriesScanInfo.get(path);
   }
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/LastQueryAggTableScanOperator.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/LastQueryAggTableScanOperator.java
index e734145e470..fffd43043d3 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/LastQueryAggTableScanOperator.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/LastQueryAggTableScanOperator.java
@@ -21,6 +21,7 @@ package 
org.apache.iotdb.db.queryengine.execution.operator.source.relational;
 
 import org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory;
 import org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper;
+import org.apache.iotdb.db.queryengine.execution.fragment.DataNodeQueryContext;
 import 
org.apache.iotdb.db.queryengine.execution.operator.process.last.LastQueryUtil;
 import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.LastAccumulator;
 import 
org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.LastByDescAccumulator;
@@ -44,6 +45,7 @@ import org.apache.tsfile.write.UnSupportedDataTypeException;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
 import java.util.OptionalLong;
 import java.util.concurrent.TimeUnit;
@@ -65,6 +67,7 @@ public class LastQueryAggTableScanOperator extends 
AbstractAggTableScanOperator
   private static final TableDeviceSchemaCache TABLE_DEVICE_SCHEMA_CACHE =
       TableDeviceSchemaCache.getInstance();
 
+  private final QualifiedObjectName tableCompleteName;
   private final String dbName;
   private int outputDeviceIndex;
   private DeviceEntry currentDeviceEntry;
@@ -81,13 +84,18 @@ public class LastQueryAggTableScanOperator extends 
AbstractAggTableScanOperator
   // indicates the index of last(time) aggregation
   private int lastTimeAggregationIdx = -1;
 
+  private final Map<DeviceEntry, Integer> deviceCountMap;
+  private final DataNodeQueryContext dataNodeQueryContext;
+
   public LastQueryAggTableScanOperator(
       AbstractAggTableScanOperatorParameter parameter,
       List<DeviceEntry> cachedDeviceEntries,
       QualifiedObjectName qualifiedObjectName,
       List<Integer> hitCachesIndexes,
       List<Pair<OptionalLong, TsPrimitiveType[]>> lastRowCacheResults,
-      List<TimeValuePair[]> lastValuesCacheResults) {
+      List<TimeValuePair[]> lastValuesCacheResults,
+      Map<DeviceEntry, Integer> deviceCountMap,
+      DataNodeQueryContext dataNodeQueryContext) {
 
     super(parameter);
 
@@ -101,6 +109,7 @@ public class LastQueryAggTableScanOperator extends 
AbstractAggTableScanOperator
     this.hitCachesIndexes = hitCachesIndexes;
     this.lastRowCacheResults = lastRowCacheResults;
     this.lastValuesCacheResults = lastValuesCacheResults;
+    this.tableCompleteName = qualifiedObjectName;
     this.dbName = qualifiedObjectName.getDatabaseName();
 
     this.operatorContext.recordSpecifiedInfo(
@@ -110,6 +119,8 @@ public class LastQueryAggTableScanOperator extends 
AbstractAggTableScanOperator
         lastTimeAggregationIdx = i;
       }
     }
+    this.deviceCountMap = deviceCountMap;
+    this.dataNodeQueryContext = dataNodeQueryContext;
   }
 
   @Override
@@ -518,28 +529,23 @@ public class LastQueryAggTableScanOperator extends 
AbstractAggTableScanOperator
         case TIME:
           if (!hasSetLastTime) {
             hasSetLastTime = true;
+            updateMeasurementList.add("");
             if (i == lastTimeAggregationIdx) {
               LastDescAccumulator lastAccumulator =
                   (LastDescAccumulator) tableAggregator.getAccumulator();
               if (lastAccumulator.hasInitResult()) {
-                updateMeasurementList.add("");
                 updateTimeValuePairList.add(
                     new TimeValuePair(
                         lastAccumulator.getMaxTime(),
                         new 
TsPrimitiveType.TsLong(lastAccumulator.getMaxTime())));
               } else {
                 currentDeviceEntry = deviceEntries.get(currentDeviceIndex);
-                TABLE_DEVICE_SCHEMA_CACHE.updateLastCacheIfExists(
-                    dbName,
-                    currentDeviceEntry.getDeviceID(),
-                    new String[] {""},
-                    new TimeValuePair[] {EMPTY_TIME_VALUE_PAIR});
+                updateTimeValuePairList.add(EMPTY_TIME_VALUE_PAIR);
               }
             } else {
               LastByDescAccumulator lastByAccumulator =
                   (LastByDescAccumulator) tableAggregator.getAccumulator();
               if (lastByAccumulator.hasInitResult() && 
!lastByAccumulator.isXNull()) {
-                updateMeasurementList.add("");
                 updateTimeValuePairList.add(
                     new TimeValuePair(
                         lastByAccumulator.getLastTimeOfY(),
@@ -551,7 +557,7 @@ public class LastQueryAggTableScanOperator extends 
AbstractAggTableScanOperator
         case FIELD:
           LastByDescAccumulator lastByAccumulator =
               (LastByDescAccumulator) tableAggregator.getAccumulator();
-          // only can update LastCache when last_by return non-null value
+          updateMeasurementList.add(schema.getName());
           if (lastByAccumulator.hasInitResult() && 
!lastByAccumulator.isXNull()) {
             long lastByTime = lastByAccumulator.getLastTimeOfY();
 
@@ -562,10 +568,11 @@ public class LastQueryAggTableScanOperator extends 
AbstractAggTableScanOperator
                   new TimeValuePair(lastByTime, new 
TsPrimitiveType.TsLong(lastByTime)));
             }
 
-            updateMeasurementList.add(schema.getName());
             updateTimeValuePairList.add(
                 new TimeValuePair(
                     lastByTime, 
cloneTsPrimitiveType(lastByAccumulator.getXResult())));
+          } else {
+            updateTimeValuePairList.add(EMPTY_TIME_VALUE_PAIR);
           }
           break;
         default:
@@ -575,17 +582,7 @@ public class LastQueryAggTableScanOperator extends 
AbstractAggTableScanOperator
       channel += tableAggregator.getChannelCount();
     }
 
-    if (!updateMeasurementList.isEmpty()) {
-      String[] updateMeasurementArray = updateMeasurementList.toArray(new 
String[0]);
-      TimeValuePair[] updateTimeValuePairArray =
-          updateTimeValuePairList.toArray(new TimeValuePair[0]);
-      currentDeviceEntry = deviceEntries.get(currentDeviceIndex);
-      TABLE_DEVICE_SCHEMA_CACHE.updateLastCacheIfExists(
-          dbName,
-          currentDeviceEntry.getDeviceID(),
-          updateMeasurementArray,
-          updateTimeValuePairArray);
-    }
+    checkIfUpdated(updateMeasurementList, updateTimeValuePairList);
   }
 
   private void updateLastCacheUseLastValuesIfPossible() {
@@ -604,19 +601,15 @@ public class LastQueryAggTableScanOperator extends 
AbstractAggTableScanOperator
             hasSetLastTime = true;
             LastDescAccumulator lastAccumulator =
                 (LastDescAccumulator) tableAggregator.getAccumulator();
+            updateMeasurementList.add("");
             if (lastAccumulator.hasInitResult()) {
-              updateMeasurementList.add("");
               updateTimeValuePairList.add(
                   new TimeValuePair(
                       lastAccumulator.getMaxTime(),
                       new 
TsPrimitiveType.TsLong(lastAccumulator.getMaxTime())));
             } else {
               currentDeviceEntry = deviceEntries.get(currentDeviceIndex);
-              TABLE_DEVICE_SCHEMA_CACHE.updateLastCacheIfExists(
-                  dbName,
-                  currentDeviceEntry.getDeviceID(),
-                  new String[] {""},
-                  new TimeValuePair[] {EMPTY_TIME_VALUE_PAIR});
+              updateTimeValuePairList.add(EMPTY_TIME_VALUE_PAIR);
             }
           }
           break;
@@ -643,16 +636,52 @@ public class LastQueryAggTableScanOperator extends 
AbstractAggTableScanOperator
       channel += tableAggregator.getChannelCount();
     }
 
+    checkIfUpdated(updateMeasurementList, updateTimeValuePairList);
+  }
+
+  private void checkIfUpdated(
+      List<String> updateMeasurementList, List<TimeValuePair> 
updateTimeValuePairList) {
     if (!updateMeasurementList.isEmpty()) {
-      String[] updateMeasurementArray = updateMeasurementList.toArray(new 
String[0]);
-      TimeValuePair[] updateTimeValuePairArray =
-          updateTimeValuePairList.toArray(new TimeValuePair[0]);
       currentDeviceEntry = deviceEntries.get(currentDeviceIndex);
-      TABLE_DEVICE_SCHEMA_CACHE.updateLastCacheIfExists(
-          dbName,
-          currentDeviceEntry.getDeviceID(),
-          updateMeasurementArray,
-          updateTimeValuePairArray);
+
+      boolean deviceInMultiRegion =
+          deviceCountMap != null && 
deviceCountMap.containsKey(currentDeviceEntry);
+      if (!deviceInMultiRegion) {
+        TABLE_DEVICE_SCHEMA_CACHE.updateLastCacheIfExists(
+            dbName,
+            currentDeviceEntry.getDeviceID(),
+            updateMeasurementList.toArray(new String[0]),
+            updateTimeValuePairList.toArray(new TimeValuePair[0]));
+        return;
+      }
+
+      dataNodeQueryContext.lock(true);
+      try {
+        Pair<Integer, Map<String, TimeValuePair>> deviceInfo =
+            dataNodeQueryContext.getDeviceInfo(tableCompleteName, 
currentDeviceEntry);
+        Map<String, TimeValuePair> values = deviceInfo.getRight();
+
+        int size = updateMeasurementList.size();
+        for (int i = 0; i < size; i++) {
+          String measurementName = updateMeasurementList.get(i);
+          TimeValuePair timeValuePair = updateTimeValuePairList.get(i);
+          if (values.containsKey(measurementName)) {
+            TimeValuePair oldValue = values.get(measurementName);
+            if (timeValuePair.getTimestamp() > oldValue.getTimestamp()) {
+              values.put(measurementName, timeValuePair);
+            }
+          } else {
+            values.put(measurementName, timeValuePair);
+          }
+        }
+
+        deviceInfo.left--;
+        if (deviceInfo.left == 0) {
+          dataNodeQueryContext.updateLastCache(tableCompleteName, 
currentDeviceEntry);
+        }
+      } finally {
+        dataNodeQueryContext.unLock(true);
+      }
     }
   }
 
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
index 08be6377cb7..c33103f4ac2 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
@@ -3094,8 +3094,8 @@ public class TableOperatorGenerator extends 
PlanVisitor<Operator, LocalExecution
           allHitCache = false;
         }
 
+        DeviceEntry deviceEntry = node.getDeviceEntries().get(i);
         if (!allHitCache) {
-          DeviceEntry deviceEntry = node.getDeviceEntries().get(i);
           AlignedFullPath alignedPath =
               constructAlignedPath(
                   deviceEntry,
@@ -3104,6 +3104,7 @@ public class TableOperatorGenerator extends 
PlanVisitor<Operator, LocalExecution
                   parameter.getAllSensors());
           ((DataDriverContext) 
context.getDriverContext()).addPath(alignedPath);
           unCachedDeviceEntries.add(deviceEntry);
+          addUncachedDeviceToContext(node, context, deviceEntry);
 
           // last cache updateColumns need to put "" as time column
           String[] updateColumns = new 
String[parameter.getMeasurementColumnNames().size() + 1];
@@ -3120,7 +3121,8 @@ public class TableOperatorGenerator extends 
PlanVisitor<Operator, LocalExecution
         } else {
           hitCachesIndexes.add(i);
           lastRowCacheResults.add(lastByResult.get());
-          cachedDeviceEntries.add(node.getDeviceEntries().get(i));
+          cachedDeviceEntries.add(deviceEntry);
+          decreaseDeviceCount(node, context, deviceEntry);
         }
       }
     } else {
@@ -3185,8 +3187,8 @@ public class TableOperatorGenerator extends 
PlanVisitor<Operator, LocalExecution
           allHitCache = false;
         }
 
+        DeviceEntry deviceEntry = node.getDeviceEntries().get(i);
         if (!allHitCache) {
-          DeviceEntry deviceEntry = node.getDeviceEntries().get(i);
           AlignedFullPath alignedPath =
               constructAlignedPath(
                   deviceEntry,
@@ -3195,19 +3197,21 @@ public class TableOperatorGenerator extends 
PlanVisitor<Operator, LocalExecution
                   parameter.getAllSensors());
           ((DataDriverContext) 
context.getDriverContext()).addPath(alignedPath);
           unCachedDeviceEntries.add(deviceEntry);
+          addUncachedDeviceToContext(node, context, deviceEntry);
 
           TableDeviceSchemaCache.getInstance()
               .initOrInvalidateLastCache(
                   node.getQualifiedObjectName().getDatabaseName(),
                   deviceEntry.getDeviceID(),
-                  needInitTime
-                      ? targetColumns
-                      : Arrays.copyOfRange(targetColumns, 0, 
targetColumns.length - 1),
+                  needInitTime && node.getGroupingKeys().isEmpty()
+                      ? Arrays.copyOfRange(targetColumns, 0, 
targetColumns.length - 1)
+                      : targetColumns,
                   false);
         } else {
           hitCachesIndexes.add(i);
           lastValuesCacheResults.add(lastResult);
-          cachedDeviceEntries.add(node.getDeviceEntries().get(i));
+          cachedDeviceEntries.add(deviceEntry);
+          decreaseDeviceCount(node, context, deviceEntry);
         }
       }
     }
@@ -3222,7 +3226,9 @@ public class TableOperatorGenerator extends 
PlanVisitor<Operator, LocalExecution
             node.getQualifiedObjectName(),
             hitCachesIndexes,
             lastRowCacheResults,
-            lastValuesCacheResults);
+            lastValuesCacheResults,
+            node.getDeviceCountMap(),
+            context.getInstanceContext().getDataNodeQueryContext());
 
     ((DataDriverContext) 
context.getDriverContext()).addSourceOperator(lastQueryOperator);
     parameter
@@ -3231,6 +3237,44 @@ public class TableOperatorGenerator extends 
PlanVisitor<Operator, LocalExecution
     return lastQueryOperator;
   }
 
+  private void addUncachedDeviceToContext(
+      AggregationTableScanNode node, LocalExecutionPlanContext context, 
DeviceEntry deviceEntry) {
+    boolean deviceInMultiRegion =
+        node.getDeviceCountMap() != null && 
node.getDeviceCountMap().containsKey(deviceEntry);
+    if (!deviceInMultiRegion) {
+      return;
+    }
+
+    context.dataNodeQueryContext.lock(true);
+    try {
+      context.dataNodeQueryContext.addUnCachedDeviceIfAbsent(
+          node.getQualifiedObjectName(), deviceEntry, 
node.getDeviceCountMap().get(deviceEntry));
+    } finally {
+      context.dataNodeQueryContext.unLock(true);
+    }
+  }
+
+  /**
+   * Decrease the device count when its last cache was hit. Notice that the 
count can also be zero
+   * after decrease, we need to update last cache if needed.
+   */
+  private void decreaseDeviceCount(
+      AggregationTableScanNode node, LocalExecutionPlanContext context, 
DeviceEntry deviceEntry) {
+    boolean deviceInMultiRegion =
+        node.getDeviceCountMap() != null && 
node.getDeviceCountMap().containsKey(deviceEntry);
+    if (!deviceInMultiRegion) {
+      return;
+    }
+
+    context.dataNodeQueryContext.lock(true);
+    try {
+      context.dataNodeQueryContext.decreaseDeviceAndMayUpdateLastCache(
+          node.getQualifiedObjectName(), deviceEntry, 
node.getDeviceCountMap().get(deviceEntry));
+    } finally {
+      context.dataNodeQueryContext.unLock(true);
+    }
+  }
+
   private SeriesScanOptions buildSeriesScanOptions(
       LocalExecutionPlanContext context,
       Map<Symbol, ColumnSchema> columnSchemaMap,
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java
index fb69e66f0db..5626f6865cb 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java
@@ -140,6 +140,7 @@ import static 
org.apache.tsfile.utils.Preconditions.checkArgument;
 /** This class is used to generate distributed plan for table model. */
 public class TableDistributedPlanGenerator
     extends PlanVisitor<List<PlanNode>, 
TableDistributedPlanGenerator.PlanContext> {
+  private final MPPQueryContext queryContext;
   private final QueryId queryId;
   private final Analysis analysis;
   private final SymbolAllocator symbolAllocator;
@@ -152,6 +153,7 @@ public class TableDistributedPlanGenerator
       final Analysis analysis,
       final SymbolAllocator symbolAllocator,
       final DataNodeLocationSupplierFactory.DataNodeLocationSupplier 
dataNodeLocationSupplier) {
+    this.queryContext = queryContext;
     this.queryId = queryContext.getQueryId();
     this.analysis = analysis;
     this.symbolAllocator = symbolAllocator;
@@ -1146,6 +1148,7 @@ public class TableDistributedPlanGenerator
         if (regionReplicaSets.size() > 1) {
           needSplit = true;
           context.deviceCrossRegion = true;
+          
queryContext.setNeedUpdateScanNumForLastQuery(node.mayUseLastCache());
         }
         regionReplicaSetsList.add(regionReplicaSets);
       }
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableModelQueryFragmentPlanner.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableModelQueryFragmentPlanner.java
index 8f7c2aed732..40ed57a7fa2 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableModelQueryFragmentPlanner.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableModelQueryFragmentPlanner.java
@@ -34,6 +34,9 @@ import 
org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
 import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNodeId;
 import 
org.apache.iotdb.db.queryengine.plan.planner.plan.node.sink.MultiChildrenSinkNode;
 import org.apache.iotdb.db.queryengine.plan.relational.analyzer.Analysis;
+import org.apache.iotdb.db.queryengine.plan.relational.metadata.DeviceEntry;
+import 
org.apache.iotdb.db.queryengine.plan.relational.metadata.QualifiedObjectName;
+import 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationTableScanNode;
 import 
org.apache.iotdb.db.queryengine.plan.relational.planner.node.ExchangeNode;
 import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CountDevice;
 import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ShowDevice;
@@ -95,6 +98,45 @@ public class TableModelQueryFragmentPlanner extends 
AbstractFragmentParallelPlan
 
     fragmentInstanceList.forEach(
         fi -> 
fi.setDataNodeFINum(dataNodeFIMap.get(fi.getHostDataNode()).size()));
+
+    if (queryContext.needUpdateScanNumForLastQuery()) {
+      dataNodeFIMap
+          .values()
+          .forEach(
+              fragmentInstances -> {
+                Map<QualifiedObjectName, Map<DeviceEntry, Integer>> 
deviceCountMapOfEachTable =
+                    new HashMap<>();
+                fragmentInstances.forEach(
+                    fragmentInstance ->
+                        updateScanNum(
+                            fragmentInstance.getFragment().getPlanNodeTree(),
+                            deviceCountMapOfEachTable));
+
+                // For less size of serde, remove the device which the region 
count is 1
+                deviceCountMapOfEachTable
+                    .values()
+                    .forEach(deviceMap -> deviceMap.entrySet().removeIf(v -> 
v.getValue() == 1));
+              });
+    }
+  }
+
+  private void updateScanNum(
+      PlanNode planNode,
+      Map<QualifiedObjectName, Map<DeviceEntry, Integer>> 
deviceCountMapOfEachTable) {
+    if (planNode instanceof AggregationTableScanNode) {
+      AggregationTableScanNode aggregationTableScanNode = 
(AggregationTableScanNode) planNode;
+      Map<DeviceEntry, Integer> deviceMap =
+          deviceCountMapOfEachTable.computeIfAbsent(
+              aggregationTableScanNode.getQualifiedObjectName(), name -> new 
HashMap<>());
+
+      aggregationTableScanNode
+          .getDeviceEntries()
+          .forEach(deviceEntry -> deviceMap.merge(deviceEntry, 1, 
Integer::sum));
+      // Each AggTableScanNode with the same complete tableName in this 
DataNode holds this map
+      aggregationTableScanNode.setDeviceCountMap(deviceMap);
+      return;
+    }
+    planNode.getChildren().forEach(node -> updateScanNum(node, 
deviceCountMapOfEachTable));
   }
 
   private void recordPlanNodeRelation(PlanNode root, PlanFragmentId 
planFragmentId) {
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/AggregationTableScanNode.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/AggregationTableScanNode.java
index b56d76db6c1..b6840d7200a 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/AggregationTableScanNode.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/AggregationTableScanNode.java
@@ -45,6 +45,7 @@ import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
@@ -54,6 +55,8 @@ import java.util.Set;
 
 import static com.google.common.base.Preconditions.checkArgument;
 import static java.util.Objects.requireNonNull;
+import static 
org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinAggregationFunction.LAST;
+import static 
org.apache.iotdb.commons.udf.builtin.relational.TableBuiltinAggregationFunction.LAST_BY;
 import static 
org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolAllocator.DATE_BIN_PREFIX;
 import static org.apache.iotdb.db.utils.constant.SqlConstant.COUNT;
 import static 
org.apache.iotdb.db.utils.constant.SqlConstant.TABLE_TIME_COLUMN_NAME;
@@ -68,6 +71,8 @@ public class AggregationTableScanNode extends 
DeviceTableScanNode {
   protected AggregationNode.Step step;
   protected Optional<Symbol> groupIdSymbol;
 
+  private Map<DeviceEntry, Integer> deviceCountMap;
+
   public AggregationTableScanNode(
       PlanNodeId id,
       QualifiedObjectName qualifiedObjectName,
@@ -388,6 +393,23 @@ public class AggregationTableScanNode extends 
DeviceTableScanNode {
         aggregationNode.getGroupIdSymbol());
   }
 
+  public boolean mayUseLastCache() {
+    // Only made a simple judgment is here, if Aggregations is not empty and 
all of them are LAST or
+    // LAST_BY
+    if (aggregations.isEmpty()) {
+      return false;
+    }
+
+    for (AggregationNode.Aggregation aggregation : aggregations.values()) {
+      String functionName = 
aggregation.getResolvedFunction().getSignature().getName();
+      if (!LAST_BY.getFunctionName().equals(functionName)
+          && !LAST.getFunctionName().equals(functionName)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
   @Override
   public boolean equals(Object o) {
     if (this == o) {
@@ -456,6 +478,17 @@ public class AggregationTableScanNode extends 
DeviceTableScanNode {
     if (node.groupIdSymbol.isPresent()) {
       Symbol.serialize(node.groupIdSymbol.get(), byteBuffer);
     }
+
+    if (node.deviceCountMap != null) {
+      ReadWriteIOUtils.write(true, byteBuffer);
+      ReadWriteIOUtils.write(node.deviceCountMap.size(), byteBuffer);
+      for (Map.Entry<DeviceEntry, Integer> entry : 
node.deviceCountMap.entrySet()) {
+        entry.getKey().serialize(byteBuffer);
+        ReadWriteIOUtils.write(entry.getValue(), byteBuffer);
+      }
+    } else {
+      ReadWriteIOUtils.write(false, byteBuffer);
+    }
   }
 
   protected static void serializeMemberVariables(
@@ -493,6 +526,17 @@ public class AggregationTableScanNode extends 
DeviceTableScanNode {
     if (node.groupIdSymbol.isPresent()) {
       Symbol.serialize(node.groupIdSymbol.get(), stream);
     }
+
+    if (node.deviceCountMap != null) {
+      ReadWriteIOUtils.write(true, stream);
+      ReadWriteIOUtils.write(node.deviceCountMap.size(), stream);
+      for (Map.Entry<DeviceEntry, Integer> entry : 
node.deviceCountMap.entrySet()) {
+        entry.getKey().serialize(stream);
+        ReadWriteIOUtils.write(entry.getValue(), stream);
+      }
+    } else {
+      ReadWriteIOUtils.write(false, stream);
+    }
   }
 
   protected static void deserializeMemberVariables(
@@ -538,6 +582,16 @@ public class AggregationTableScanNode extends 
DeviceTableScanNode {
     node.groupIdSymbol = groupIdSymbol;
 
     node.outputSymbols = constructOutputSymbols(node.getGroupingSets(), 
node.getAggregations());
+
+    if (ReadWriteIOUtils.readBool(byteBuffer)) {
+      size = ReadWriteIOUtils.readInt(byteBuffer);
+      Map<DeviceEntry, Integer> deviceRegionCountMap = new HashMap<>(size);
+      while (size-- > 0) {
+        DeviceEntry deviceEntry = DeviceEntry.deserialize(byteBuffer);
+        deviceRegionCountMap.put(deviceEntry, 
ReadWriteIOUtils.readInt(byteBuffer));
+      }
+      node.setDeviceCountMap(deviceRegionCountMap);
+    }
   }
 
   @Override
@@ -561,4 +615,12 @@ public class AggregationTableScanNode extends 
DeviceTableScanNode {
     node.setPlanNodeId(PlanNodeId.deserialize(byteBuffer));
     return node;
   }
+
+  public void setDeviceCountMap(Map<DeviceEntry, Integer> deviceCountMap) {
+    this.deviceCountMap = deviceCountMap;
+  }
+
+  public Map<DeviceEntry, Integer> getDeviceCountMap() {
+    return deviceCountMap;
+  }
 }


Reply via email to