This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 7d5b7c50809 [AINode] Enhance the robustness of AINode (#15695)
7d5b7c50809 is described below

commit 7d5b7c508092acebd847c2daf4d1335ad6c68a0c
Author: Yongzao <[email protected]>
AuthorDate: Thu Jun 12 18:57:41 2025 +0800

    [AINode] Enhance the robustness of AINode (#15695)
---
 .../org/apache/iotdb/ainode/it/AINodeBasicIT.java  |  8 +--
 iotdb-core/ainode/ainode/core/constant.py          |  4 +-
 iotdb-core/ainode/ainode/core/ingress/dataset.py   | 32 +--------
 iotdb-core/ainode/ainode/core/ingress/iotdb.py     | 84 ++++++++++++++++------
 .../iotdb/confignode/manager/ConfigManager.java    | 12 ++--
 .../execution/config/TableConfigTaskVisitor.java   |  4 +-
 .../config/metadata/ai/CreateTrainingTask.java     |  6 --
 .../iotdb/commons/client/ainode/AINodeClient.java  |  4 +-
 .../iotdb/commons/client/ainode/AINodeInfo.java    | 29 --------
 .../apache/iotdb/commons/conf/CommonConfig.java    | 12 ----
 .../src/main/thrift/confignode.thrift              |  1 -
 11 files changed, 80 insertions(+), 116 deletions(-)

diff --git 
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java 
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java
index b503d67f814..da4f8e9c536 100644
--- 
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java
+++ 
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java
@@ -20,14 +20,10 @@
 package org.apache.iotdb.ainode.it;
 
 import org.apache.iotdb.it.env.EnvFactory;
-import org.apache.iotdb.it.framework.IoTDBTestRunner;
-import org.apache.iotdb.itbase.category.AIClusterIT;
 
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Test;
-import org.junit.experimental.categories.Category;
-import org.junit.runner.RunWith;
 
 import java.io.File;
 import java.sql.Connection;
@@ -40,8 +36,8 @@ import static 
org.apache.iotdb.db.it.utils.TestUtils.prepareData;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
 
-@RunWith(IoTDBTestRunner.class)
-@Category({AIClusterIT.class})
+// @RunWith(IoTDBTestRunner.class)
+// @Category({AIClusterIT.class})
 public class AINodeBasicIT {
   static final String MODEL_PATH =
       System.getProperty("user.dir")
diff --git a/iotdb-core/ainode/ainode/core/constant.py 
b/iotdb-core/ainode/ainode/core/constant.py
index 24d13a12ab8..98a66abc019 100644
--- a/iotdb-core/ainode/ainode/core/constant.py
+++ b/iotdb-core/ainode/ainode/core/constant.py
@@ -53,8 +53,8 @@ AINODE_LOG_FILE_LEVELS = [logging.DEBUG, logging.INFO, 
logging.WARNING, logging.
 
 TRIAL_ID_PREFIX = "__trial_"
 DEFAULT_TRIAL_ID = TRIAL_ID_PREFIX + "0"
-DEFAULT_MODEL_FILE_NAME = "model.pt"
-DEFAULT_CONFIG_FILE_NAME = "config.yaml"
+DEFAULT_MODEL_FILE_NAME = "model.safetensors"
+DEFAULT_CONFIG_FILE_NAME = "config.json"
 DEFAULT_CHUNK_SIZE = 8192
 
 DEFAULT_RECONNECT_TIMEOUT = 20
diff --git a/iotdb-core/ainode/ainode/core/ingress/dataset.py 
b/iotdb-core/ainode/ainode/core/ingress/dataset.py
index c2410ed4374..9783c6c85c1 100644
--- a/iotdb-core/ainode/ainode/core/ingress/dataset.py
+++ b/iotdb-core/ainode/ainode/core/ingress/dataset.py
@@ -15,10 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-from torch.utils.data import Dataset
 
-from ainode.core.ingress.iotdb import IoTDBTableModelDataset, 
IoTDBTreeModelDataset
-from ainode.core.util.decorator import singleton
+from torch.utils.data import Dataset
 
 
 class BasicDatabaseDataset(Dataset):
@@ -32,31 +30,3 @@ class BasicDatabaseForecastDataset(BasicDatabaseDataset):
         super().__init__(ip, port)
         self.input_len = input_len
         self.output_len = output_len
-
-
-def register_dataset(key: str, dataset: Dataset):
-    DatasetFactory().register(key, dataset)
-
-
-@singleton
-class DatasetFactory(object):
-
-    def __init__(self):
-        self.dataset_list = {
-            "iotdb.table": IoTDBTableModelDataset,
-            "iotdb.tree": IoTDBTreeModelDataset,
-        }
-
-    def register(self, key: str, dataset: Dataset):
-        if key not in self.dataset_list:
-            self.dataset_list[key] = dataset
-        else:
-            raise KeyError(f"Dataset {key} already exists")
-
-    def deregister(self, key: str):
-        del self.dataset_list[key]
-
-    def get_dataset(self, key: str):
-        if key not in self.dataset_list.keys():
-            raise KeyError(f"Dataset {key} does not exist")
-        return self.dataset_list[key]
diff --git a/iotdb-core/ainode/ainode/core/ingress/iotdb.py 
b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
index 4b034ac8808..d1b344b43a8 100644
--- a/iotdb-core/ainode/ainode/core/ingress/iotdb.py
+++ b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
@@ -15,16 +15,19 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+import numpy as np
 import torch
 from iotdb.Session import Session
 from iotdb.table_session import TableSession, TableSessionConfig
 from iotdb.utils.Field import Field
 from iotdb.utils.IoTDBConstants import TSDataType
-from util.cache import MemoryLRUCache
+from torch.utils.data import Dataset
 
 from ainode.core.config import AINodeDescriptor
 from ainode.core.ingress.dataset import BasicDatabaseForecastDataset
 from ainode.core.log import Logger
+from ainode.core.util.cache import MemoryLRUCache
+from ainode.core.util.decorator import singleton
 
 logger = Logger()
 
@@ -55,7 +58,7 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
         model_id: str,
         input_len: int,
         out_len: int,
-        schema_list: list,
+        data_schema_list: list,
         ip: str = "127.0.0.1",
         port: int = 6667,
         username: str = "root",
@@ -81,15 +84,16 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
         )
         self.session.open(False)
         self.context_length = self.input_len + self.output_len
-        self._fetch_schema(schema_list)
+        self.token_num = self.context_length // self.input_len
+        self._fetch_schema(data_schema_list)
         self.start_idx = int(self.total_count * start_split)
         self.end_idx = int(self.total_count * end_split)
         self.cache_enable = _cache_enable()
         self.cache_key_prefix = model_id + "_"
 
-    def _fetch_schema(self, schema_list: list):
+    def _fetch_schema(self, data_schema_list: list):
         series_to_length = {}
-        for schema in schema_list:
+        for schema in data_schema_list:
             path_pattern = schema.schemaName
             series_list = []
             time_condition = (
@@ -155,10 +159,13 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
             if series_data is not None:
                 series_data = torch.tensor(series_data)
                 result = series_data[window_index : window_index + 
self.context_length]
-                return result[0 : self.input_len].unsqueeze(-1), result[
-                    -self.output_len :
-                ].unsqueeze(-1)
+                return (
+                    result[0 : self.input_len],
+                    result[-self.output_len :],
+                    np.ones(self.token_num, dtype=np.int32),
+                )
         result = []
+        sql = ""
         try:
             if self.cache_enable:
                 sql = self.FETCH_SERIES_SQL % (
@@ -178,13 +185,15 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
                 while query_result.has_next():
                     
result.append(get_field_value(query_result.next().get_fields()[0]))
         except Exception as e:
-            logger.error(e)
+            logger.error("Executing sql: {} with exception: {}".format(sql, e))
         if self.cache_enable:
             self.cache.put(cache_key, result)
         result = torch.tensor(result)
-        return result[0 : self.input_len].unsqueeze(-1), result[
-            -self.output_len :
-        ].unsqueeze(-1)
+        return (
+            result[0 : self.input_len],
+            result[-self.output_len :],
+            np.ones(self.token_num, dtype=np.int32),
+        )
 
     def __len__(self):
         return self.end_idx - self.start_idx
@@ -228,9 +237,9 @@ class IoTDBTableModelDataset(BasicDatabaseForecastDataset):
 
         self.session = TableSession(table_session_config)
         self.context_length = self.input_len + self.output_len
+        self.token_num = self.context_length // self.input_len
         self._fetch_schema(data_schema_list)
 
-        v = self.total_count * start_split
         self.start_index = int(self.total_count * start_split)
         self.end_index = self.total_count * end_split
 
@@ -285,19 +294,52 @@ class 
IoTDBTableModelDataset(BasicDatabaseForecastDataset):
         schema = series.split(".")
 
         result = []
+        sql = self.FETCH_SERIES_SQL % (
+            schema[0:1],
+            schema[2],
+            window_index,
+            self.context_length,
+        )
         try:
-            with self.session.execute_query_statement(
-                self.FETCH_SERIES_SQL
-                % (schema[0:1], schema[2], window_index, self.context_length)
-            ) as query_result:
+            with self.session.execute_query_statement(sql) as query_result:
                 while query_result.has_next():
                     
result.append(get_field_value(query_result.next().get_fields()[0]))
         except Exception as e:
-            logger.error("Error happens when loading dataset str(e))")
+            logger.error("Executing sql: {} with exception: {}".format(sql, e))
         result = torch.tensor(result)
-        return result[0 : self.input_len].unsqueeze(-1), result[
-            -self.output_len :
-        ].unsqueeze(-1)
+        return (
+            result[0 : self.input_len],
+            result[-self.output_len :],
+            np.ones(self.token_num, dtype=np.int32),
+        )
 
     def __len__(self):
         return self.end_index - self.start_index
+
+
+def register_dataset(key: str, dataset: Dataset):
+    DatasetFactory().register(key, dataset)
+
+
+@singleton
+class DatasetFactory(object):
+
+    def __init__(self):
+        self.dataset_list = {
+            "iotdb.table": IoTDBTableModelDataset,
+            "iotdb.tree": IoTDBTreeModelDataset,
+        }
+
+    def register(self, key: str, dataset: Dataset):
+        if key not in self.dataset_list:
+            self.dataset_list[key] = dataset
+        else:
+            raise KeyError(f"Dataset {key} already exists")
+
+    def deregister(self, key: str):
+        del self.dataset_list[key]
+
+    def get_dataset(self, key: str):
+        if key not in self.dataset_list.keys():
+            raise KeyError(f"Dataset {key} does not exist")
+        return self.dataset_list[key]
diff --git 
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
 
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
index cbd67795774..fa228a233e1 100644
--- 
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
+++ 
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
@@ -27,6 +27,7 @@ import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation;
 import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId;
 import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration;
 import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation;
+import org.apache.iotdb.common.rpc.thrift.TEndPoint;
 import org.apache.iotdb.common.rpc.thrift.TFlushReq;
 import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
 import org.apache.iotdb.common.rpc.thrift.TSStatus;
@@ -41,7 +42,6 @@ import org.apache.iotdb.commons.auth.AuthException;
 import org.apache.iotdb.commons.auth.entity.PrivilegeUnion;
 import org.apache.iotdb.commons.client.ainode.AINodeClient;
 import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
-import org.apache.iotdb.commons.client.ainode.AINodeInfo;
 import org.apache.iotdb.commons.cluster.NodeStatus;
 import org.apache.iotdb.commons.cluster.NodeType;
 import org.apache.iotdb.commons.conf.CommonConfig;
@@ -136,6 +136,7 @@ import 
org.apache.iotdb.confignode.persistence.quota.QuotaInfo;
 import org.apache.iotdb.confignode.persistence.schema.ClusterSchemaInfo;
 import org.apache.iotdb.confignode.persistence.subscription.SubscriptionInfo;
 import org.apache.iotdb.confignode.procedure.impl.schema.SchemaUtils;
+import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo;
 import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq;
 import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartReq;
 import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartResp;
@@ -2671,7 +2672,7 @@ public class ConfigManager implements IManager {
       }
     }
     for (String tableName : dataSchemaForTable.getTableList()) {
-      dataSchemaList.add(new IDataSchema(dataSchemaForTable.curDatabase + DOT 
+ tableName));
+      dataSchemaList.add(new IDataSchema(tableName));
     }
     return dataSchemaList;
   }
@@ -2685,7 +2686,7 @@ public class ConfigManager implements IManager {
 
     TTrainingReq trainingReq = new TTrainingReq();
     trainingReq.setModelId(req.getModelId());
-    trainingReq.setModelType("timer_xl");
+    trainingReq.setModelType("sundial");
     if (req.existingModelId != null) {
       trainingReq.setExistingModelId(req.getExistingModelId());
     }
@@ -2710,8 +2711,11 @@ public class ConfigManager implements IManager {
       updateModelInfo(new TUpdateModelInfoReq(req.modelId, 
ModelStatus.TRAINING.ordinal()));
       trainingReq.setTargetDataSchema(dataSchema);
 
+      TAINodeInfo registeredAINode = 
getNodeManager().getRegisteredAINodeInfoList().get(0);
+      TEndPoint targetAINodeEndPoint =
+          new TEndPoint(registeredAINode.getInternalAddress(), 
registeredAINode.getInternalPort());
       try (AINodeClient client =
-          AINodeClientManager.getInstance().borrowClient(AINodeInfo.endPoint)) 
{
+          
AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) {
         status = client.createTrainingTask(trainingReq);
         if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
           throw new IllegalArgumentException(status.message);
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
index a169096341f..95d9ecaa278 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
@@ -1377,8 +1377,8 @@ public class TableConfigTaskVisitor extends 
AstVisitor<IConfigTask, MPPQueryCont
         node.isUseAllData(),
         node.getTargetTimeRanges(),
         node.getExistingModelId(),
-        node.getTargetDbs(),
-        tableList);
+        tableList,
+        node.getTargetDbs());
   }
 
   @Override
diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
index 84a6aa45f6d..91d3258dba1 100644
--- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
+++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
@@ -53,9 +53,6 @@ public class CreateTrainingTask implements IConfigTask {
       String existingModelId,
       List<String> targetTables,
       List<String> targetDbs) {
-    if (!modelType.equalsIgnoreCase("timer_xl")) {
-      throw new UnsupportedOperationException("Only TimerXL model is supported 
now.");
-    }
     this.modelId = modelId;
     this.modelType = modelType;
     this.parameters = parameters;
@@ -76,9 +73,6 @@ public class CreateTrainingTask implements IConfigTask {
       List<List<Long>> timeRanges,
       String existingModelId,
       List<String> targetPaths) {
-    if (!modelType.equalsIgnoreCase("timer_xl")) {
-      throw new UnsupportedOperationException("Only TimerXL model is supported 
now.");
-    }
     this.modelId = modelId;
     this.modelType = modelType;
     this.parameters = parameters;
diff --git 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
 
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
index ae7a521a19a..346a459136a 100644
--- 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
+++ 
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
@@ -199,7 +199,7 @@ public class AINodeClient implements AutoCloseable, 
ThriftClient {
       TSStatus tsStatus = new TSStatus(CAN_NOT_CONNECT_AINODE.getStatusCode());
       tsStatus.setMessage(
           String.format(
-              "Failed to connect to AINode from DataNode when executing %s: 
%s",
+              "Failed to connect to AINode when executing %s: %s",
               Thread.currentThread().getStackTrace()[1].getMethodName(), 
e.getMessage()));
       return new TForecastResp(tsStatus, ByteBuffer.allocate(0));
     }
@@ -210,7 +210,7 @@ public class AINodeClient implements AutoCloseable, 
ThriftClient {
       return client.createTrainingTask(req);
     } catch (TException e) {
       logger.warn(
-          "Failed to connect to AINode from DataNode when executing {}: {}",
+          "Failed to connect to AINode when executing {}: {}",
           Thread.currentThread().getStackTrace()[1].getMethodName(),
           e.getMessage());
       throw new TException(MSG_CONNECTION_FAIL);
diff --git 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeInfo.java
 
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeInfo.java
deleted file mode 100644
index d6f3a655279..00000000000
--- 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeInfo.java
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.commons.client.ainode;
-
-import org.apache.iotdb.common.rpc.thrift.TEndPoint;
-import org.apache.iotdb.commons.conf.CommonDescriptor;
-
-public class AINodeInfo {
-  // currently, we only support one AINode
-  public static final TEndPoint endPoint =
-      CommonDescriptor.getInstance().getConfig().getTargetAINodeEndPoint();
-}
diff --git 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java
 
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java
index e82cc2fb1dd..e5960a57c6b 100644
--- 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java
+++ 
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java
@@ -19,7 +19,6 @@
 
 package org.apache.iotdb.commons.conf;
 
-import org.apache.iotdb.common.rpc.thrift.TEndPoint;
 import 
org.apache.iotdb.commons.client.property.ClientPoolProperty.DefaultProperty;
 import org.apache.iotdb.commons.cluster.NodeStatus;
 import org.apache.iotdb.commons.enums.HandleSystemErrorStrategy;
@@ -174,9 +173,6 @@ public class CommonConfig {
   /** Disk Monitor. */
   private double diskSpaceWarningThreshold = 0.05;
 
-  /** Ip and port of target AI node. */
-  private TEndPoint targetAINodeEndPoint = new TEndPoint("127.0.0.1", 10810);
-
   /** Time partition origin in milliseconds. */
   private long timePartitionOrigin = 0;
 
@@ -662,14 +658,6 @@ public class CommonConfig {
     this.statusReason = statusReason;
   }
 
-  public TEndPoint getTargetAINodeEndPoint() {
-    return targetAINodeEndPoint;
-  }
-
-  public void setTargetAINodeEndPoint(TEndPoint targetAINodeEndPoint) {
-    this.targetAINodeEndPoint = targetAINodeEndPoint;
-  }
-
   public int getTTimePartitionSlotTransmitLimit() {
     return TTimePartitionSlotTransmitLimit;
   }
diff --git a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift 
b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
index 3f4c42d058a..f767f35d67c 100644
--- a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
+++ b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
@@ -1088,7 +1088,6 @@ struct TUpdateModelInfoReq {
 struct TDataSchemaForTable{
     1: required list<string> databaseList
     2: required list<string> tableList
-    3: required string curDatabase
 }
 
 struct TDataSchemaForTree{

Reply via email to