This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 7d5b7c50809 [AINode] Enhance the robustness of AINode (#15695)
7d5b7c50809 is described below
commit 7d5b7c508092acebd847c2daf4d1335ad6c68a0c
Author: Yongzao <[email protected]>
AuthorDate: Thu Jun 12 18:57:41 2025 +0800
[AINode] Enhance the robustness of AINode (#15695)
---
.../org/apache/iotdb/ainode/it/AINodeBasicIT.java | 8 +--
iotdb-core/ainode/ainode/core/constant.py | 4 +-
iotdb-core/ainode/ainode/core/ingress/dataset.py | 32 +--------
iotdb-core/ainode/ainode/core/ingress/iotdb.py | 84 ++++++++++++++++------
.../iotdb/confignode/manager/ConfigManager.java | 12 ++--
.../execution/config/TableConfigTaskVisitor.java | 4 +-
.../config/metadata/ai/CreateTrainingTask.java | 6 --
.../iotdb/commons/client/ainode/AINodeClient.java | 4 +-
.../iotdb/commons/client/ainode/AINodeInfo.java | 29 --------
.../apache/iotdb/commons/conf/CommonConfig.java | 12 ----
.../src/main/thrift/confignode.thrift | 1 -
11 files changed, 80 insertions(+), 116 deletions(-)
diff --git
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java
index b503d67f814..da4f8e9c536 100644
---
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java
+++
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java
@@ -20,14 +20,10 @@
package org.apache.iotdb.ainode.it;
import org.apache.iotdb.it.env.EnvFactory;
-import org.apache.iotdb.it.framework.IoTDBTestRunner;
-import org.apache.iotdb.itbase.category.AIClusterIT;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
-import org.junit.experimental.categories.Category;
-import org.junit.runner.RunWith;
import java.io.File;
import java.sql.Connection;
@@ -40,8 +36,8 @@ import static
org.apache.iotdb.db.it.utils.TestUtils.prepareData;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
-@RunWith(IoTDBTestRunner.class)
-@Category({AIClusterIT.class})
+// @RunWith(IoTDBTestRunner.class)
+// @Category({AIClusterIT.class})
public class AINodeBasicIT {
static final String MODEL_PATH =
System.getProperty("user.dir")
diff --git a/iotdb-core/ainode/ainode/core/constant.py
b/iotdb-core/ainode/ainode/core/constant.py
index 24d13a12ab8..98a66abc019 100644
--- a/iotdb-core/ainode/ainode/core/constant.py
+++ b/iotdb-core/ainode/ainode/core/constant.py
@@ -53,8 +53,8 @@ AINODE_LOG_FILE_LEVELS = [logging.DEBUG, logging.INFO,
logging.WARNING, logging.
TRIAL_ID_PREFIX = "__trial_"
DEFAULT_TRIAL_ID = TRIAL_ID_PREFIX + "0"
-DEFAULT_MODEL_FILE_NAME = "model.pt"
-DEFAULT_CONFIG_FILE_NAME = "config.yaml"
+DEFAULT_MODEL_FILE_NAME = "model.safetensors"
+DEFAULT_CONFIG_FILE_NAME = "config.json"
DEFAULT_CHUNK_SIZE = 8192
DEFAULT_RECONNECT_TIMEOUT = 20
diff --git a/iotdb-core/ainode/ainode/core/ingress/dataset.py
b/iotdb-core/ainode/ainode/core/ingress/dataset.py
index c2410ed4374..9783c6c85c1 100644
--- a/iotdb-core/ainode/ainode/core/ingress/dataset.py
+++ b/iotdb-core/ainode/ainode/core/ingress/dataset.py
@@ -15,10 +15,8 @@
# specific language governing permissions and limitations
# under the License.
#
-from torch.utils.data import Dataset
-from ainode.core.ingress.iotdb import IoTDBTableModelDataset,
IoTDBTreeModelDataset
-from ainode.core.util.decorator import singleton
+from torch.utils.data import Dataset
class BasicDatabaseDataset(Dataset):
@@ -32,31 +30,3 @@ class BasicDatabaseForecastDataset(BasicDatabaseDataset):
super().__init__(ip, port)
self.input_len = input_len
self.output_len = output_len
-
-
-def register_dataset(key: str, dataset: Dataset):
- DatasetFactory().register(key, dataset)
-
-
-@singleton
-class DatasetFactory(object):
-
- def __init__(self):
- self.dataset_list = {
- "iotdb.table": IoTDBTableModelDataset,
- "iotdb.tree": IoTDBTreeModelDataset,
- }
-
- def register(self, key: str, dataset: Dataset):
- if key not in self.dataset_list:
- self.dataset_list[key] = dataset
- else:
- raise KeyError(f"Dataset {key} already exists")
-
- def deregister(self, key: str):
- del self.dataset_list[key]
-
- def get_dataset(self, key: str):
- if key not in self.dataset_list.keys():
- raise KeyError(f"Dataset {key} does not exist")
- return self.dataset_list[key]
diff --git a/iotdb-core/ainode/ainode/core/ingress/iotdb.py
b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
index 4b034ac8808..d1b344b43a8 100644
--- a/iotdb-core/ainode/ainode/core/ingress/iotdb.py
+++ b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
@@ -15,16 +15,19 @@
# specific language governing permissions and limitations
# under the License.
#
+import numpy as np
import torch
from iotdb.Session import Session
from iotdb.table_session import TableSession, TableSessionConfig
from iotdb.utils.Field import Field
from iotdb.utils.IoTDBConstants import TSDataType
-from util.cache import MemoryLRUCache
+from torch.utils.data import Dataset
from ainode.core.config import AINodeDescriptor
from ainode.core.ingress.dataset import BasicDatabaseForecastDataset
from ainode.core.log import Logger
+from ainode.core.util.cache import MemoryLRUCache
+from ainode.core.util.decorator import singleton
logger = Logger()
@@ -55,7 +58,7 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
model_id: str,
input_len: int,
out_len: int,
- schema_list: list,
+ data_schema_list: list,
ip: str = "127.0.0.1",
port: int = 6667,
username: str = "root",
@@ -81,15 +84,16 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
)
self.session.open(False)
self.context_length = self.input_len + self.output_len
- self._fetch_schema(schema_list)
+ self.token_num = self.context_length // self.input_len
+ self._fetch_schema(data_schema_list)
self.start_idx = int(self.total_count * start_split)
self.end_idx = int(self.total_count * end_split)
self.cache_enable = _cache_enable()
self.cache_key_prefix = model_id + "_"
- def _fetch_schema(self, schema_list: list):
+ def _fetch_schema(self, data_schema_list: list):
series_to_length = {}
- for schema in schema_list:
+ for schema in data_schema_list:
path_pattern = schema.schemaName
series_list = []
time_condition = (
@@ -155,10 +159,13 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
if series_data is not None:
series_data = torch.tensor(series_data)
result = series_data[window_index : window_index +
self.context_length]
- return result[0 : self.input_len].unsqueeze(-1), result[
- -self.output_len :
- ].unsqueeze(-1)
+ return (
+ result[0 : self.input_len],
+ result[-self.output_len :],
+ np.ones(self.token_num, dtype=np.int32),
+ )
result = []
+ sql = ""
try:
if self.cache_enable:
sql = self.FETCH_SERIES_SQL % (
@@ -178,13 +185,15 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
while query_result.has_next():
result.append(get_field_value(query_result.next().get_fields()[0]))
except Exception as e:
- logger.error(e)
+ logger.error("Executing sql: {} with exception: {}".format(sql, e))
if self.cache_enable:
self.cache.put(cache_key, result)
result = torch.tensor(result)
- return result[0 : self.input_len].unsqueeze(-1), result[
- -self.output_len :
- ].unsqueeze(-1)
+ return (
+ result[0 : self.input_len],
+ result[-self.output_len :],
+ np.ones(self.token_num, dtype=np.int32),
+ )
def __len__(self):
return self.end_idx - self.start_idx
@@ -228,9 +237,9 @@ class IoTDBTableModelDataset(BasicDatabaseForecastDataset):
self.session = TableSession(table_session_config)
self.context_length = self.input_len + self.output_len
+ self.token_num = self.context_length // self.input_len
self._fetch_schema(data_schema_list)
- v = self.total_count * start_split
self.start_index = int(self.total_count * start_split)
self.end_index = self.total_count * end_split
@@ -285,19 +294,52 @@ class
IoTDBTableModelDataset(BasicDatabaseForecastDataset):
schema = series.split(".")
result = []
+ sql = self.FETCH_SERIES_SQL % (
+ schema[0:1],
+ schema[2],
+ window_index,
+ self.context_length,
+ )
try:
- with self.session.execute_query_statement(
- self.FETCH_SERIES_SQL
- % (schema[0:1], schema[2], window_index, self.context_length)
- ) as query_result:
+ with self.session.execute_query_statement(sql) as query_result:
while query_result.has_next():
result.append(get_field_value(query_result.next().get_fields()[0]))
except Exception as e:
- logger.error("Error happens when loading dataset str(e))")
+ logger.error("Executing sql: {} with exception: {}".format(sql, e))
result = torch.tensor(result)
- return result[0 : self.input_len].unsqueeze(-1), result[
- -self.output_len :
- ].unsqueeze(-1)
+ return (
+ result[0 : self.input_len],
+ result[-self.output_len :],
+ np.ones(self.token_num, dtype=np.int32),
+ )
def __len__(self):
return self.end_index - self.start_index
+
+
+def register_dataset(key: str, dataset: Dataset):
+ DatasetFactory().register(key, dataset)
+
+
+@singleton
+class DatasetFactory(object):
+
+ def __init__(self):
+ self.dataset_list = {
+ "iotdb.table": IoTDBTableModelDataset,
+ "iotdb.tree": IoTDBTreeModelDataset,
+ }
+
+ def register(self, key: str, dataset: Dataset):
+ if key not in self.dataset_list:
+ self.dataset_list[key] = dataset
+ else:
+ raise KeyError(f"Dataset {key} already exists")
+
+ def deregister(self, key: str):
+ del self.dataset_list[key]
+
+ def get_dataset(self, key: str):
+ if key not in self.dataset_list.keys():
+ raise KeyError(f"Dataset {key} does not exist")
+ return self.dataset_list[key]
diff --git
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
index cbd67795774..fa228a233e1 100644
---
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
+++
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
@@ -27,6 +27,7 @@ import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation;
import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId;
import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration;
import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation;
+import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.common.rpc.thrift.TFlushReq;
import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
@@ -41,7 +42,6 @@ import org.apache.iotdb.commons.auth.AuthException;
import org.apache.iotdb.commons.auth.entity.PrivilegeUnion;
import org.apache.iotdb.commons.client.ainode.AINodeClient;
import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
-import org.apache.iotdb.commons.client.ainode.AINodeInfo;
import org.apache.iotdb.commons.cluster.NodeStatus;
import org.apache.iotdb.commons.cluster.NodeType;
import org.apache.iotdb.commons.conf.CommonConfig;
@@ -136,6 +136,7 @@ import
org.apache.iotdb.confignode.persistence.quota.QuotaInfo;
import org.apache.iotdb.confignode.persistence.schema.ClusterSchemaInfo;
import org.apache.iotdb.confignode.persistence.subscription.SubscriptionInfo;
import org.apache.iotdb.confignode.procedure.impl.schema.SchemaUtils;
+import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo;
import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq;
import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartReq;
import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartResp;
@@ -2671,7 +2672,7 @@ public class ConfigManager implements IManager {
}
}
for (String tableName : dataSchemaForTable.getTableList()) {
- dataSchemaList.add(new IDataSchema(dataSchemaForTable.curDatabase + DOT
+ tableName));
+ dataSchemaList.add(new IDataSchema(tableName));
}
return dataSchemaList;
}
@@ -2685,7 +2686,7 @@ public class ConfigManager implements IManager {
TTrainingReq trainingReq = new TTrainingReq();
trainingReq.setModelId(req.getModelId());
- trainingReq.setModelType("timer_xl");
+ trainingReq.setModelType("sundial");
if (req.existingModelId != null) {
trainingReq.setExistingModelId(req.getExistingModelId());
}
@@ -2710,8 +2711,11 @@ public class ConfigManager implements IManager {
updateModelInfo(new TUpdateModelInfoReq(req.modelId,
ModelStatus.TRAINING.ordinal()));
trainingReq.setTargetDataSchema(dataSchema);
+ TAINodeInfo registeredAINode =
getNodeManager().getRegisteredAINodeInfoList().get(0);
+ TEndPoint targetAINodeEndPoint =
+ new TEndPoint(registeredAINode.getInternalAddress(),
registeredAINode.getInternalPort());
try (AINodeClient client =
- AINodeClientManager.getInstance().borrowClient(AINodeInfo.endPoint))
{
+
AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) {
status = client.createTrainingTask(trainingReq);
if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
throw new IllegalArgumentException(status.message);
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
index a169096341f..95d9ecaa278 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
@@ -1377,8 +1377,8 @@ public class TableConfigTaskVisitor extends
AstVisitor<IConfigTask, MPPQueryCont
node.isUseAllData(),
node.getTargetTimeRanges(),
node.getExistingModelId(),
- node.getTargetDbs(),
- tableList);
+ tableList,
+ node.getTargetDbs());
}
@Override
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
index 84a6aa45f6d..91d3258dba1 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
@@ -53,9 +53,6 @@ public class CreateTrainingTask implements IConfigTask {
String existingModelId,
List<String> targetTables,
List<String> targetDbs) {
- if (!modelType.equalsIgnoreCase("timer_xl")) {
- throw new UnsupportedOperationException("Only TimerXL model is supported
now.");
- }
this.modelId = modelId;
this.modelType = modelType;
this.parameters = parameters;
@@ -76,9 +73,6 @@ public class CreateTrainingTask implements IConfigTask {
List<List<Long>> timeRanges,
String existingModelId,
List<String> targetPaths) {
- if (!modelType.equalsIgnoreCase("timer_xl")) {
- throw new UnsupportedOperationException("Only TimerXL model is supported
now.");
- }
this.modelId = modelId;
this.modelType = modelType;
this.parameters = parameters;
diff --git
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
index ae7a521a19a..346a459136a 100644
---
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
+++
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
@@ -199,7 +199,7 @@ public class AINodeClient implements AutoCloseable,
ThriftClient {
TSStatus tsStatus = new TSStatus(CAN_NOT_CONNECT_AINODE.getStatusCode());
tsStatus.setMessage(
String.format(
- "Failed to connect to AINode from DataNode when executing %s:
%s",
+ "Failed to connect to AINode when executing %s: %s",
Thread.currentThread().getStackTrace()[1].getMethodName(),
e.getMessage()));
return new TForecastResp(tsStatus, ByteBuffer.allocate(0));
}
@@ -210,7 +210,7 @@ public class AINodeClient implements AutoCloseable,
ThriftClient {
return client.createTrainingTask(req);
} catch (TException e) {
logger.warn(
- "Failed to connect to AINode from DataNode when executing {}: {}",
+ "Failed to connect to AINode when executing {}: {}",
Thread.currentThread().getStackTrace()[1].getMethodName(),
e.getMessage());
throw new TException(MSG_CONNECTION_FAIL);
diff --git
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeInfo.java
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeInfo.java
deleted file mode 100644
index d6f3a655279..00000000000
---
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeInfo.java
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.commons.client.ainode;
-
-import org.apache.iotdb.common.rpc.thrift.TEndPoint;
-import org.apache.iotdb.commons.conf.CommonDescriptor;
-
-public class AINodeInfo {
- // currently, we only support one AINode
- public static final TEndPoint endPoint =
- CommonDescriptor.getInstance().getConfig().getTargetAINodeEndPoint();
-}
diff --git
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java
index e82cc2fb1dd..e5960a57c6b 100644
---
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java
+++
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/conf/CommonConfig.java
@@ -19,7 +19,6 @@
package org.apache.iotdb.commons.conf;
-import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import
org.apache.iotdb.commons.client.property.ClientPoolProperty.DefaultProperty;
import org.apache.iotdb.commons.cluster.NodeStatus;
import org.apache.iotdb.commons.enums.HandleSystemErrorStrategy;
@@ -174,9 +173,6 @@ public class CommonConfig {
/** Disk Monitor. */
private double diskSpaceWarningThreshold = 0.05;
- /** Ip and port of target AI node. */
- private TEndPoint targetAINodeEndPoint = new TEndPoint("127.0.0.1", 10810);
-
/** Time partition origin in milliseconds. */
private long timePartitionOrigin = 0;
@@ -662,14 +658,6 @@ public class CommonConfig {
this.statusReason = statusReason;
}
- public TEndPoint getTargetAINodeEndPoint() {
- return targetAINodeEndPoint;
- }
-
- public void setTargetAINodeEndPoint(TEndPoint targetAINodeEndPoint) {
- this.targetAINodeEndPoint = targetAINodeEndPoint;
- }
-
public int getTTimePartitionSlotTransmitLimit() {
return TTimePartitionSlotTransmitLimit;
}
diff --git a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
index 3f4c42d058a..f767f35d67c 100644
--- a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
+++ b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
@@ -1088,7 +1088,6 @@ struct TUpdateModelInfoReq {
struct TDataSchemaForTable{
1: required list<string> databaseList
2: required list<string> tableList
- 3: required string curDatabase
}
struct TDataSchemaForTree{