This is an automated email from the ASF dual-hosted git repository.

CRZbulabula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new e2322e9e9e5 [AINode] Remove Chronos2 DataLoader pin_memory option (#17822)
e2322e9e9e5 is described below

commit e2322e9e9e5d0ef5a4b982f27fa682c066e94cca
Author: Yongzao <[email protected]>
AuthorDate: Wed Jun 3 16:32:20 2026 +0800

    [AINode] Remove Chronos2 DataLoader pin_memory option (#17822)
---
 .../iotdb/ainode/it/AINodeSharedClusterIT.java     | 83 ++++++++++++++++++++++
 .../request_scheduler/basic_request_scheduler.py   | 19 +++--
 .../core/model/chronos2/pipeline_chronos2.py       |  1 -
 3 files changed, 91 insertions(+), 12 deletions(-)

diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java
index 4ea2b4af41a..cbd0f62f16e 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java
@@ -50,6 +50,7 @@ import java.util.List;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
 
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_LTSM_MAP;
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice;
@@ -90,6 +91,10 @@ public class AINodeSharedClusterIT {
       "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT 256\")";
   private static final int DEFAULT_INPUT_LENGTH = 256;
   private static final int DEFAULT_OUTPUT_LENGTH = 48;
+  private static final int LOADED_MODEL_SMOKE_INPUT_LENGTH = 96;
+  private static final int LOADED_MODEL_SMOKE_OUTPUT_LENGTH = 1;
+  private static final List<String> LTSM_LOAD_DEVICE_COMBINATIONS =
+      Arrays.asList("cpu", "0", "cpu,0");
 
   private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
       "SELECT * FROM FORECAST("
@@ -438,6 +443,84 @@ public class AINodeSharedClusterIT {
 
   // ========== Concurrent forecast tests ==========
 
+  @Test
+  public void largeTimeSeriesModelLoadInferenceAndForecastTest()
+      throws SQLException, InterruptedException {
+    try (Connection treeConnection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
+        Statement treeStatement = treeConnection.createStatement();
+        Connection tableConnection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+        Statement tableStatement = tableConnection.createStatement()) {
+      for (FakeModelInfo modelInfo : BUILTIN_LTSM_MAP.values()) {
+        for (String devices : LTSM_LOAD_DEVICE_COMBINATIONS) {
+          loadRunAndUnloadModelOnDevices(
+              treeStatement, tableStatement, modelInfo.getModelId(), devices);
+        }
+      }
+    }
+  }
+
+  private void loadRunAndUnloadModelOnDevices(
+      Statement treeStatement, Statement tableStatement, String modelId, String devices)
+      throws SQLException, InterruptedException {
+    boolean loadSubmitted = false;
+    try {
+      treeStatement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices));
+      loadSubmitted = true;
+      checkModelOnSpecifiedDevice(treeStatement, modelId, devices);
+      assertLoadedModelCallInferenceSucceeds(treeStatement, modelId);
+      assertLoadedModelForecastSucceeds(tableStatement, modelId);
+    } finally {
+      if (loadSubmitted) {
+        treeStatement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '%s'", modelId, devices));
+        checkModelNotOnSpecifiedDevice(treeStatement, modelId, devices);
+      }
+    }
+  }
+
+  private void assertLoadedModelCallInferenceSucceeds(Statement statement, String modelId)
+      throws SQLException {
+    String callInferenceSQL =
+        String.format(
+            CALL_INFERENCE_SQL_TEMPLATE,
+            modelId,
+            0,
+            LOADED_MODEL_SMOKE_INPUT_LENGTH,
+            LOADED_MODEL_SMOKE_OUTPUT_LENGTH);
+    try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) {
+      ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
+      checkHeader(resultSetMetaData, "Time,output");
+      Assert.assertEquals(Types.DOUBLE, resultSetMetaData.getColumnType(2));
+      int count = 0;
+      while (resultSet.next()) {
+        resultSet.getDouble("output");
+        count++;
+      }
+      Assert.assertEquals(LOADED_MODEL_SMOKE_OUTPUT_LENGTH, count);
+    }
+  }
+
+  private void assertLoadedModelForecastSucceeds(Statement statement, String modelId)
+      throws SQLException {
+    String forecastTableFunctionSQL =
+        String.format(
+            FORECAST_TABLE_FUNCTION_SQL_TEMPLATE,
+            modelId,
+            0,
+            5760,
+            LOADED_MODEL_SMOKE_INPUT_LENGTH,
+            5760,
+            LOADED_MODEL_SMOKE_OUTPUT_LENGTH,
+            1,
+            "time");
+    try (ResultSet resultSet = statement.executeQuery(forecastTableFunctionSQL)) {
+      int count = 0;
+      while (resultSet.next()) {
+        count++;
+      }
+      Assert.assertEquals(LOADED_MODEL_SMOKE_OUTPUT_LENGTH, count);
+    }
+  }
+
   @Test
   public void concurrentForecastTest() throws SQLException, InterruptedException {
     for (FakeModelInfo modelInfo : CONCURRENT_FORECAST_MODELS) {
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/request_scheduler/basic_request_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/request_scheduler/basic_request_scheduler.py
index ef5d37a18a1..2ee15fabca4 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/request_scheduler/basic_request_scheduler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/request_scheduler/basic_request_scheduler.py
@@ -16,8 +16,6 @@
 # under the License.
 #
 
-import os
-
 import psutil
 import torch
 
@@ -53,23 +51,22 @@ class BasicRequestScheduler(AbstractRequestScheduler):
 
     def memory_is_available(self):
         if "cuda" in self.device.type:
-            used = torch.cuda.memory_allocated(self.device)
-            reserved = torch.cuda.memory_reserved(self.device)
+            available, total = torch.cuda.mem_get_info(self.device)
         elif "cpu" in self.device.type:
-            process = psutil.Process(os.getpid())
-            used = process.memory_info().rss
-            reserved = used
+            memory = psutil.virtual_memory()
+            available = memory.available
+            total = memory.total
         else:
-            used = 0
-            reserved = 0
             logger.warning(
                 f"[Inference] Unsupported device type: {self.device.type}. Memory checks will not be performed."
             )
+            return True
         logger.debug(
             f"[Inference][Device-{self.device}][Pool-{self.pool_id}] "
-            f"Memory used: {used/1024**2:.2f} MB, Max memory: {self.max_memory_bytes/1024**2:.2f} MB"
+            f"Memory available: {available/1024**2:.2f} MB, Total memory: {total/1024**2:.2f} MB, "
+            f"Required free memory: {self.max_memory_bytes/1024**2:.2f} MB"
         )
-        return used < self.max_memory_bytes
+        return available > self.max_memory_bytes
 
     def schedule_activate(self) -> list:
         requests = []
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
index 01ff78ba48d..083125e4b99 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py
@@ -309,7 +309,6 @@ class Chronos2Pipeline(ForecastPipeline):
         test_loader = DataLoader(
             test_dataset,
             batch_size=None,
-            pin_memory=True,
             shuffle=False,
             drop_last=False,
         )

Reply via email to