This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 0a0ec3fb176 [AINode] Fix model management bugs (#15890)
0a0ec3fb176 is described below
commit 0a0ec3fb1763cc234481b72dbac4fda95fb34b72
Author: Yongzao <[email protected]>
AuthorDate: Wed Jul 9 11:58:43 2025 +0800
[AINode] Fix model management bugs (#15890)
---
iotdb-core/ainode/ainode/core/config.py | 3 +-
iotdb-core/ainode/ainode/core/constant.py | 3 +
iotdb-core/ainode/ainode/core/handler.py | 5 +-
.../ainode/core/manager/inference_manager.py | 2 +-
.../ainode/ainode/core/manager/model_manager.py | 9 +-
.../ainode/ainode/core/model/model_storage.py | 98 ++++++++++++++++------
.../iotdb/confignode/manager/ModelManager.java | 7 +-
.../procedure/impl/model/DropModelProcedure.java | 57 +++++++------
.../iotdb/commons/client/ainode/AINodeClient.java | 5 +-
.../thrift-ainode/src/main/thrift/ainode.thrift | 6 +-
10 files changed, 127 insertions(+), 68 deletions(-)
diff --git a/iotdb-core/ainode/ainode/core/config.py b/iotdb-core/ainode/ainode/core/config.py
index b7dbf3fc94c..b4694cb9c3b 100644
--- a/iotdb-core/ainode/ainode/core/config.py
+++ b/iotdb-core/ainode/ainode/core/config.py
@@ -38,6 +38,7 @@ from ainode.core.constant import (
AINODE_ROOT_DIR,
AINODE_SYSTEM_DIR,
AINODE_SYSTEM_FILE_NAME,
+ AINODE_TARGET_CONFIG_NODE_LIST,
AINODE_THRIFT_COMPRESSION_ENABLED,
AINODE_VERSION_INFO,
)
@@ -73,7 +74,7 @@ class AINodeConfig(object):
self._ain_model_storage_cache_size = 30
# Target ConfigNode to be connected by AINode
- self._ain_target_config_node_list: TEndPoint = TEndPoint("127.0.0.1", 10710)
+ self._ain_target_config_node_list: TEndPoint = AINODE_TARGET_CONFIG_NODE_LIST
# use for node management
self._ainode_id = 0
diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py
index c0e021ceebf..c307dbafe63 100644
--- a/iotdb-core/ainode/ainode/core/constant.py
+++ b/iotdb-core/ainode/ainode/core/constant.py
@@ -21,6 +21,8 @@ import os
from enum import Enum
from typing import List
+from ainode.thrift.common.ttypes import TEndPoint
+
AINODE_CONF_DIRECTORY_NAME = "conf"
AINODE_ROOT_CONF_DIRECTORY_NAME = "conf"
AINODE_CONF_FILE_NAME = "iotdb-ainode.properties"
@@ -49,6 +51,7 @@ AINODE_CLUSTER_INGRESS_PORT = 6667
AINODE_CLUSTER_INGRESS_USERNAME = "root"
AINODE_CLUSTER_INGRESS_PASSWORD = "root"
AINODE_CLUSTER_INGRESS_TIME_ZONE = "UTC+8"
+AINODE_TARGET_CONFIG_NODE_LIST = TEndPoint("127.0.0.1", 10710)
# AINode log
AINODE_LOG_FILE_NAMES = [
diff --git a/iotdb-core/ainode/ainode/core/handler.py b/iotdb-core/ainode/ainode/core/handler.py
index 804be637934..524b80a88d8 100644
--- a/iotdb-core/ainode/ainode/core/handler.py
+++ b/iotdb-core/ainode/ainode/core/handler.py
@@ -30,6 +30,7 @@ from ainode.thrift.ainode.ttypes import (
TInferenceResp,
TRegisterModelReq,
TRegisterModelResp,
+ TShowModelsReq,
TShowModelsResp,
TTrainingReq,
)
@@ -58,8 +59,8 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
return ClusterManager.get_heart_beat(req)
- def showModels(self) -> TShowModelsResp:
- return self._model_manager.show_models()
+ def showModels(self, req: TShowModelsReq) -> TShowModelsResp:
+ return self._model_manager.show_models(req)
def createTrainingTask(self, req: TTrainingReq) -> TSStatus:
pass
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index 9d092c78c7c..9eda1c22651 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -130,7 +130,7 @@ class InferenceManager:
return TimerXLStrategy(model)
if isinstance(model, SundialForPrediction):
return SundialStrategy(model)
- if self.model_manager.model_storage._is_built_in(model_id):
+ if self.model_manager.model_storage._is_built_in_or_fine_tuned(model_id):
return BuiltInStrategy(model)
return RegisteredStrategy(model)
diff --git a/iotdb-core/ainode/ainode/core/manager/model_manager.py b/iotdb-core/ainode/ainode/core/manager/model_manager.py
index 4688edf04c3..bb589a281bf 100644
--- a/iotdb-core/ainode/ainode/core/manager/model_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/model_manager.py
@@ -33,6 +33,7 @@ from ainode.thrift.ainode.ttypes import (
TDeleteModelReq,
TRegisterModelReq,
TRegisterModelResp,
+ TShowModelsReq,
TShowModelsResp,
)
from ainode.thrift.common.ttypes import TSStatus
@@ -55,19 +56,16 @@ class ModelManager:
)
except InvalidUriError as e:
logger.warning(e)
- self.model_storage.delete_model(req.modelId)
return TRegisterModelResp(
get_status(TSStatusCode.INVALID_URI_ERROR, e.message)
)
except BadConfigValueError as e:
logger.warning(e)
- self.model_storage.delete_model(req.modelId)
return TRegisterModelResp(
get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, e.message)
)
except YAMLError as e:
logger.warning(e)
- self.model_storage.delete_model(req.modelId)
if hasattr(e, "problem_mark"):
mark = e.problem_mark
return TRegisterModelResp(
@@ -85,7 +83,6 @@ class ModelManager:
)
except Exception as e:
logger.warning(e)
- self.model_storage.delete_model(req.modelId)
return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR))
def delete_model(self, req: TDeleteModelReq) -> TSStatus:
@@ -141,8 +138,8 @@ class ModelManager:
"""
return self.model_storage.get_ckpt_path(model_id)
- def show_models(self) -> TShowModelsResp:
- return self.model_storage.show_models()
+ def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
+ return self.model_storage.show_models(req)
def register_built_in_model(self, model_info: ModelInfo):
self.model_storage.register_built_in_model(model_info)
diff --git a/iotdb-core/ainode/ainode/core/model/model_storage.py b/iotdb-core/ainode/ainode/core/model/model_storage.py
index 15727392e0b..5a544cc0fea 100644
--- a/iotdb-core/ainode/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/ainode/core/model/model_storage.py
@@ -50,7 +50,7 @@ from ainode.core.model.model_info import (
get_built_in_model_type,
)
from ainode.core.util.lock import ModelLockPool
-from ainode.thrift.ainode.ttypes import TShowModelsResp
+from ainode.thrift.ainode.ttypes import TShowModelsReq, TShowModelsResp
from ainode.thrift.common.ttypes import TSStatus
logger = Logger()
@@ -211,23 +211,30 @@ class ModelStorage(object):
configs: TConfigs
attributes: str
"""
- storage_path = os.path.join(self._model_dir, f"{model_id}")
- # create storage dir if not exist
- if not os.path.exists(storage_path):
- os.makedirs(storage_path)
- model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME)
- config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME)
- configs, attributes = fetch_model_by_uri(
- uri, model_storage_path, config_storage_path
- )
- model_info = ModelInfo(
- model_id=model_id,
- model_type="",
- category=ModelCategory.USER_DEFINED,
- state=ModelStates.ACTIVE,
- )
- self.register_built_in_model(model_info)
- return configs, attributes
+ with self._lock_pool.get_lock(model_id).write_lock():
+ storage_path = os.path.join(self._model_dir, f"{model_id}")
+ # create storage dir if not exist
+ if not os.path.exists(storage_path):
+ os.makedirs(storage_path)
+ model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME)
+ config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME)
+ self._model_info_map[model_id] = ModelInfo(
+ model_id=model_id,
+ model_type="",
+ category=ModelCategory.USER_DEFINED,
+ state=ModelStates.LOADING,
+ )
+ try:
+ # TODO: The uri should be fetched asynchronously
+ configs, attributes = fetch_model_by_uri(
+ uri, model_storage_path, config_storage_path
+ )
+ self._model_info_map[model_id].state = ModelStates.ACTIVE
+ return configs, attributes
+ except Exception as e:
+ logger.error(f"Failed to register model {model_id}: {e}")
+ self._model_info_map[model_id].state = ModelStates.INACTIVE
+ raise e
def delete_model(self, model_id: str) -> None:
"""
@@ -241,9 +248,12 @@ class ModelStorage(object):
if self._is_built_in(model_id):
raise BuiltInModelDeletionError(model_id)
- # delete the user-defined model
- storage_path = os.path.join(self._model_dir, f"{model_id}")
+ # delete the user-defined or fine-tuned model
with self._lock_pool.get_lock(model_id).write_lock():
+ storage_path = os.path.join(self._model_dir, f"{model_id}")
+ if os.path.exists(storage_path):
+ shutil.rmtree(storage_path)
+ storage_path = os.path.join(self._builtin_model_dir, f"{model_id}")
if os.path.exists(storage_path):
shutil.rmtree(storage_path)
if model_id in self._model_info_map:
@@ -260,6 +270,21 @@ class ModelStorage(object):
Returns:
bool: True if the model is built-in, False otherwise.
"""
+ return (
+ model_id in self._model_info_map
+ and self._model_info_map[model_id].category == ModelCategory.BUILT_IN
+ )
+
+ def _is_built_in_or_fine_tuned(self, model_id: str) -> bool:
+ """
+ Check if the model_id corresponds to a built-in or fine-tuned model.
+
+ Args:
+ model_id (str): The ID of the model.
+
+ Returns:
+ bool: True if the model is built-in or fine_tuned, False otherwise.
+ """
return model_id in self._model_info_map and (
self._model_info_map[model_id].category == ModelCategory.BUILT_IN
or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED
@@ -275,7 +300,7 @@ class ModelStorage(object):
model: The model instance corresponding to specific model_id
"""
with self._lock_pool.get_lock(model_id).read_lock():
- if self._is_built_in(model_id):
+ if self._is_built_in_or_fine_tuned(model_id):
model_dir = os.path.join(self._builtin_model_dir, f"{model_id}")
return fetch_built_in_model(
get_built_in_model_type(self._model_info_map[model_id].model_type),
@@ -312,7 +337,7 @@ class ModelStorage(object):
Whether saving succeeded
"""
with self._lock_pool.get_lock(model_id).write_lock():
- if self._is_built_in(model_id):
+ if self._is_built_in_or_fine_tuned(model_id):
model_dir = os.path.join(self._builtin_model_dir, f"{model_id}")
model.save_pretrained(model_dir)
else:
@@ -343,12 +368,31 @@ class ModelStorage(object):
# Only support built-in models for now
return os.path.join(self._builtin_model_dir, f"{model_id}")
- def show_models(self) -> TShowModelsResp:
+ def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
+ resp_status = TSStatus(
+ code=TSStatusCode.SUCCESS_STATUS.value,
+ message="Show models successfully",
+ )
+ if req.modelId:
+ if req.modelId in self._model_info_map:
+ model_info = self._model_info_map[req.modelId]
+ return TShowModelsResp(
+ status=resp_status,
+ modelIdList=[req.modelId],
+ modelTypeMap={req.modelId: model_info.model_type},
+ categoryMap={req.modelId: model_info.category.value},
+ stateMap={req.modelId: model_info.state.value},
+ )
+ else:
+ return TShowModelsResp(
+ status=resp_status,
+ modelIdList=[],
+ modelTypeMap={},
+ categoryMap={},
+ stateMap={},
+ )
return TShowModelsResp(
- status=TSStatus(
- code=TSStatusCode.SUCCESS_STATUS.value,
- message="Show models successfully",
- ),
+ status=resp_status,
modelIdList=list(self._model_info_map.keys()),
modelTypeMap=dict(
(model_id, model_info.model_type)
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
index e8ee032d0e5..ad1a6bed4ed 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
@@ -19,6 +19,7 @@
package org.apache.iotdb.confignode.manager;
+import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
@@ -105,7 +106,11 @@ public class ModelManager {
new TEndPoint(registeredAINode.getInternalAddress(),
registeredAINode.getInternalPort());
try (AINodeClient client =
AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) {
- TShowModelsResp resp = client.showModels();
+ TShowModelsReq showModelsReq = new TShowModelsReq();
+ if (req.isSetModelId()) {
+ showModelsReq.setModelId(req.getModelId());
+ }
+ TShowModelsResp resp = client.showModels(showModelsReq);
TShowModelResp res =
new TShowModelResp().setStatus(new
TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
res.setModelIdList(resp.getModelIdList());
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
index 5a8f8325403..23e02ea2e1d 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
@@ -19,6 +19,7 @@
package org.apache.iotdb.confignode.procedure.impl.model;
+import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
import org.apache.iotdb.commons.client.ainode.AINodeClient;
import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
@@ -101,33 +102,35 @@ public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> {
private void dropModelOnAINode(ConfigNodeProcedureEnv env) {
LOGGER.info("Start to drop model file [{}] on AI Node", modelName);
- List<Integer> nodeIds =
- env.getConfigManager().getModelManager().getModelDistributions(modelName);
- for (Integer nodeId : nodeIds) {
- try (AINodeClient client =
- AINodeClientManager.getInstance()
- .borrowClient(
- env.getConfigManager()
- .getNodeManager()
- .getRegisteredAINode(nodeId)
- .getLocation()
- .getInternalEndPoint())) {
- TSStatus status = client.deleteModel(modelName);
- if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- LOGGER.warn(
- "Failed to drop model [{}] on AINode [{}], status: {}",
- modelName,
- nodeId,
- status.getMessage());
- }
- } catch (Exception e) {
- LOGGER.warn(
- "Failed to drop model [{}] on AINode [{}], status: {}",
- modelName,
- nodeId,
- e.getMessage());
- }
- }
+ List<TAINodeConfiguration> aiNodes =
+ env.getConfigManager().getNodeManager().getRegisteredAINodes();
+ aiNodes.forEach(
+ aiNode -> {
+ int nodeId = aiNode.getLocation().getAiNodeId();
+ try (AINodeClient client =
+ AINodeClientManager.getInstance()
+ .borrowClient(
+ env.getConfigManager()
+ .getNodeManager()
+ .getRegisteredAINode(nodeId)
+ .getLocation()
+ .getInternalEndPoint())) {
+ TSStatus status = client.deleteModel(modelName);
+ if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+ LOGGER.warn(
+ "Failed to drop model [{}] on AINode [{}], status: {}",
+ modelName,
+ nodeId,
+ status.getMessage());
+ }
+ } catch (Exception e) {
+ LOGGER.warn(
+ "Failed to drop model [{}] on AINode [{}], status: {}",
+ modelName,
+ nodeId,
+ e.getMessage());
+ }
+ });
}
private void dropModelOnConfigNode(ConfigNodeProcedureEnv env) {
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
index f573ea03b6d..e52310d1505 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
@@ -28,6 +28,7 @@ import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq;
import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp;
import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq;
import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp;
+import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq;
import org.apache.iotdb.ainode.rpc.thrift.TWindowParams;
@@ -159,9 +160,9 @@ public class AINodeClient implements AutoCloseable, ThriftClient {
}
}
- public TShowModelsResp showModels() throws TException {
+ public TShowModelsResp showModels(TShowModelsReq req) throws TException {
try {
- return client.showModels();
+ return client.showModels(req);
} catch (TException e) {
logger.warn(
"Failed to connect to AINode from ConfigNode when executing {}: {}",
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index df6cf5daca3..a4ccef7e752 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -97,6 +97,10 @@ struct TForecastResp {
2: required binary forecastResult
}
+struct TShowModelsReq {
+ 1: optional string modelId
+}
+
struct TShowModelsResp {
1: required common.TSStatus status
2: optional list<string> modelIdList
@@ -108,7 +112,7 @@ struct TShowModelsResp {
service IAINodeRPCService {
// -------------- For Config Node --------------
- TShowModelsResp showModels()
+ TShowModelsResp showModels(TShowModelsReq req)
common.TSStatus deleteModel(TDeleteModelReq req)