This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 0a0ec3fb176 [AINode] Fix model management bugs (#15890)
0a0ec3fb176 is described below

commit 0a0ec3fb1763cc234481b72dbac4fda95fb34b72
Author: Yongzao <[email protected]>
AuthorDate: Wed Jul 9 11:58:43 2025 +0800

    [AINode] Fix model management bugs (#15890)
---
 iotdb-core/ainode/ainode/core/config.py            |  3 +-
 iotdb-core/ainode/ainode/core/constant.py          |  3 +
 iotdb-core/ainode/ainode/core/handler.py           |  5 +-
 .../ainode/core/manager/inference_manager.py       |  2 +-
 .../ainode/ainode/core/manager/model_manager.py    |  9 +-
 .../ainode/ainode/core/model/model_storage.py      | 98 ++++++++++++++++------
 .../iotdb/confignode/manager/ModelManager.java     |  7 +-
 .../procedure/impl/model/DropModelProcedure.java   | 57 +++++++------
 .../iotdb/commons/client/ainode/AINodeClient.java  |  5 +-
 .../thrift-ainode/src/main/thrift/ainode.thrift    |  6 +-
 10 files changed, 127 insertions(+), 68 deletions(-)

diff --git a/iotdb-core/ainode/ainode/core/config.py b/iotdb-core/ainode/ainode/core/config.py
index b7dbf3fc94c..b4694cb9c3b 100644
--- a/iotdb-core/ainode/ainode/core/config.py
+++ b/iotdb-core/ainode/ainode/core/config.py
@@ -38,6 +38,7 @@ from ainode.core.constant import (
     AINODE_ROOT_DIR,
     AINODE_SYSTEM_DIR,
     AINODE_SYSTEM_FILE_NAME,
+    AINODE_TARGET_CONFIG_NODE_LIST,
     AINODE_THRIFT_COMPRESSION_ENABLED,
     AINODE_VERSION_INFO,
 )
@@ -73,7 +74,7 @@ class AINodeConfig(object):
         self._ain_model_storage_cache_size = 30
 
         # Target ConfigNode to be connected by AINode
-        self._ain_target_config_node_list: TEndPoint = TEndPoint("127.0.0.1", 10710)
+        self._ain_target_config_node_list: TEndPoint = AINODE_TARGET_CONFIG_NODE_LIST
 
         # use for node management
         self._ainode_id = 0
diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py
index c0e021ceebf..c307dbafe63 100644
--- a/iotdb-core/ainode/ainode/core/constant.py
+++ b/iotdb-core/ainode/ainode/core/constant.py
@@ -21,6 +21,8 @@ import os
 from enum import Enum
 from typing import List
 
+from ainode.thrift.common.ttypes import TEndPoint
+
 AINODE_CONF_DIRECTORY_NAME = "conf"
 AINODE_ROOT_CONF_DIRECTORY_NAME = "conf"
 AINODE_CONF_FILE_NAME = "iotdb-ainode.properties"
@@ -49,6 +51,7 @@ AINODE_CLUSTER_INGRESS_PORT = 6667
 AINODE_CLUSTER_INGRESS_USERNAME = "root"
 AINODE_CLUSTER_INGRESS_PASSWORD = "root"
 AINODE_CLUSTER_INGRESS_TIME_ZONE = "UTC+8"
+AINODE_TARGET_CONFIG_NODE_LIST = TEndPoint("127.0.0.1", 10710)
 
 # AINode log
 AINODE_LOG_FILE_NAMES = [
diff --git a/iotdb-core/ainode/ainode/core/handler.py b/iotdb-core/ainode/ainode/core/handler.py
index 804be637934..524b80a88d8 100644
--- a/iotdb-core/ainode/ainode/core/handler.py
+++ b/iotdb-core/ainode/ainode/core/handler.py
@@ -30,6 +30,7 @@ from ainode.thrift.ainode.ttypes import (
     TInferenceResp,
     TRegisterModelReq,
     TRegisterModelResp,
+    TShowModelsReq,
     TShowModelsResp,
     TTrainingReq,
 )
@@ -58,8 +59,8 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
     def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
         return ClusterManager.get_heart_beat(req)
 
-    def showModels(self) -> TShowModelsResp:
-        return self._model_manager.show_models()
+    def showModels(self, req: TShowModelsReq) -> TShowModelsResp:
+        return self._model_manager.show_models(req)
 
     def createTrainingTask(self, req: TTrainingReq) -> TSStatus:
         pass
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index 9d092c78c7c..9eda1c22651 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -130,7 +130,7 @@ class InferenceManager:
             return TimerXLStrategy(model)
         if isinstance(model, SundialForPrediction):
             return SundialStrategy(model)
-        if self.model_manager.model_storage._is_built_in(model_id):
+        if self.model_manager.model_storage._is_built_in_or_fine_tuned(model_id):
             return BuiltInStrategy(model)
         return RegisteredStrategy(model)
 
diff --git a/iotdb-core/ainode/ainode/core/manager/model_manager.py b/iotdb-core/ainode/ainode/core/manager/model_manager.py
index 4688edf04c3..bb589a281bf 100644
--- a/iotdb-core/ainode/ainode/core/manager/model_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/model_manager.py
@@ -33,6 +33,7 @@ from ainode.thrift.ainode.ttypes import (
     TDeleteModelReq,
     TRegisterModelReq,
     TRegisterModelResp,
+    TShowModelsReq,
     TShowModelsResp,
 )
 from ainode.thrift.common.ttypes import TSStatus
@@ -55,19 +56,16 @@ class ModelManager:
             )
         except InvalidUriError as e:
             logger.warning(e)
-            self.model_storage.delete_model(req.modelId)
             return TRegisterModelResp(
                 get_status(TSStatusCode.INVALID_URI_ERROR, e.message)
             )
         except BadConfigValueError as e:
             logger.warning(e)
-            self.model_storage.delete_model(req.modelId)
             return TRegisterModelResp(
                 get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, e.message)
             )
         except YAMLError as e:
             logger.warning(e)
-            self.model_storage.delete_model(req.modelId)
             if hasattr(e, "problem_mark"):
                 mark = e.problem_mark
                 return TRegisterModelResp(
@@ -85,7 +83,6 @@ class ModelManager:
             )
         except Exception as e:
             logger.warning(e)
-            self.model_storage.delete_model(req.modelId)
             return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR))
 
     def delete_model(self, req: TDeleteModelReq) -> TSStatus:
@@ -141,8 +138,8 @@ class ModelManager:
         """
         return self.model_storage.get_ckpt_path(model_id)
 
-    def show_models(self) -> TShowModelsResp:
-        return self.model_storage.show_models()
+    def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
+        return self.model_storage.show_models(req)
 
     def register_built_in_model(self, model_info: ModelInfo):
         self.model_storage.register_built_in_model(model_info)
diff --git a/iotdb-core/ainode/ainode/core/model/model_storage.py b/iotdb-core/ainode/ainode/core/model/model_storage.py
index 15727392e0b..5a544cc0fea 100644
--- a/iotdb-core/ainode/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/ainode/core/model/model_storage.py
@@ -50,7 +50,7 @@ from ainode.core.model.model_info import (
     get_built_in_model_type,
 )
 from ainode.core.util.lock import ModelLockPool
-from ainode.thrift.ainode.ttypes import TShowModelsResp
+from ainode.thrift.ainode.ttypes import TShowModelsReq, TShowModelsResp
 from ainode.thrift.common.ttypes import TSStatus
 
 logger = Logger()
@@ -211,23 +211,30 @@ class ModelStorage(object):
             configs: TConfigs
             attributes: str
         """
-        storage_path = os.path.join(self._model_dir, f"{model_id}")
-        # create storage dir if not exist
-        if not os.path.exists(storage_path):
-            os.makedirs(storage_path)
-        model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME)
-        config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME)
-        configs, attributes = fetch_model_by_uri(
-            uri, model_storage_path, config_storage_path
-        )
-        model_info = ModelInfo(
-            model_id=model_id,
-            model_type="",
-            category=ModelCategory.USER_DEFINED,
-            state=ModelStates.ACTIVE,
-        )
-        self.register_built_in_model(model_info)
-        return configs, attributes
+        with self._lock_pool.get_lock(model_id).write_lock():
+            storage_path = os.path.join(self._model_dir, f"{model_id}")
+            # create storage dir if not exist
+            if not os.path.exists(storage_path):
+                os.makedirs(storage_path)
+            model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME)
+            config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME)
+            self._model_info_map[model_id] = ModelInfo(
+                model_id=model_id,
+                model_type="",
+                category=ModelCategory.USER_DEFINED,
+                state=ModelStates.LOADING,
+            )
+            try:
+                # TODO: The uri should be fetched asynchronously
+                configs, attributes = fetch_model_by_uri(
+                    uri, model_storage_path, config_storage_path
+                )
+                self._model_info_map[model_id].state = ModelStates.ACTIVE
+                return configs, attributes
+            except Exception as e:
+                logger.error(f"Failed to register model {model_id}: {e}")
+                self._model_info_map[model_id].state = ModelStates.INACTIVE
+                raise e
 
     def delete_model(self, model_id: str) -> None:
         """
@@ -241,9 +248,12 @@ class ModelStorage(object):
             if self._is_built_in(model_id):
                 raise BuiltInModelDeletionError(model_id)
 
-        # delete the user-defined model
-        storage_path = os.path.join(self._model_dir, f"{model_id}")
+        # delete the user-defined or fine-tuned model
         with self._lock_pool.get_lock(model_id).write_lock():
+            storage_path = os.path.join(self._model_dir, f"{model_id}")
+            if os.path.exists(storage_path):
+                shutil.rmtree(storage_path)
+            storage_path = os.path.join(self._builtin_model_dir, f"{model_id}")
             if os.path.exists(storage_path):
                 shutil.rmtree(storage_path)
             if model_id in self._model_info_map:
@@ -260,6 +270,21 @@ class ModelStorage(object):
         Returns:
             bool: True if the model is built-in, False otherwise.
         """
+        return (
+            model_id in self._model_info_map
+            and self._model_info_map[model_id].category == ModelCategory.BUILT_IN
+        )
+
+    def _is_built_in_or_fine_tuned(self, model_id: str) -> bool:
+        """
+        Check if the model_id corresponds to a built-in or fine-tuned model.
+
+        Args:
+            model_id (str): The ID of the model.
+
+        Returns:
+            bool: True if the model is built-in or fine_tuned, False otherwise.
+        """
         return model_id in self._model_info_map and (
             self._model_info_map[model_id].category == ModelCategory.BUILT_IN
             or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED
@@ -275,7 +300,7 @@ class ModelStorage(object):
             model: The model instance corresponding to specific model_id
         """
         with self._lock_pool.get_lock(model_id).read_lock():
-            if self._is_built_in(model_id):
+            if self._is_built_in_or_fine_tuned(model_id):
                 model_dir = os.path.join(self._builtin_model_dir, f"{model_id}")
                 return fetch_built_in_model(
                     get_built_in_model_type(self._model_info_map[model_id].model_type),
@@ -312,7 +337,7 @@ class ModelStorage(object):
             Whether saving succeeded
         """
         with self._lock_pool.get_lock(model_id).write_lock():
-            if self._is_built_in(model_id):
+            if self._is_built_in_or_fine_tuned(model_id):
                 model_dir = os.path.join(self._builtin_model_dir, f"{model_id}")
                 model.save_pretrained(model_dir)
             else:
@@ -343,12 +368,31 @@ class ModelStorage(object):
         # Only support built-in models for now
         return os.path.join(self._builtin_model_dir, f"{model_id}")
 
-    def show_models(self) -> TShowModelsResp:
+    def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
+        resp_status = TSStatus(
+            code=TSStatusCode.SUCCESS_STATUS.value,
+            message="Show models successfully",
+        )
+        if req.modelId:
+            if req.modelId in self._model_info_map:
+                model_info = self._model_info_map[req.modelId]
+                return TShowModelsResp(
+                    status=resp_status,
+                    modelIdList=[req.modelId],
+                    modelTypeMap={req.modelId: model_info.model_type},
+                    categoryMap={req.modelId: model_info.category.value},
+                    stateMap={req.modelId: model_info.state.value},
+                )
+            else:
+                return TShowModelsResp(
+                    status=resp_status,
+                    modelIdList=[],
+                    modelTypeMap={},
+                    categoryMap={},
+                    stateMap={},
+                )
         return TShowModelsResp(
-            status=TSStatus(
-                code=TSStatusCode.SUCCESS_STATUS.value,
-                message="Show models successfully",
-            ),
+            status=resp_status,
             modelIdList=list(self._model_info_map.keys()),
             modelTypeMap=dict(
                 (model_id, model_info.model_type)
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
index e8ee032d0e5..ad1a6bed4ed 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
@@ -19,6 +19,7 @@
 
 package org.apache.iotdb.confignode.manager;
 
+import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
 import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
 import org.apache.iotdb.common.rpc.thrift.TEndPoint;
 import org.apache.iotdb.common.rpc.thrift.TSStatus;
@@ -105,7 +106,11 @@ public class ModelManager {
         new TEndPoint(registeredAINode.getInternalAddress(), registeredAINode.getInternalPort());
     try (AINodeClient client =
         AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) {
-      TShowModelsResp resp = client.showModels();
+      TShowModelsReq showModelsReq = new TShowModelsReq();
+      if (req.isSetModelId()) {
+        showModelsReq.setModelId(req.getModelId());
+      }
+      TShowModelsResp resp = client.showModels(showModelsReq);
       TShowModelResp res =
           new TShowModelResp().setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
       res.setModelIdList(resp.getModelIdList());
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
index 5a8f8325403..23e02ea2e1d 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
@@ -19,6 +19,7 @@
 
 package org.apache.iotdb.confignode.procedure.impl.model;
 
+import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration;
 import org.apache.iotdb.common.rpc.thrift.TSStatus;
 import org.apache.iotdb.commons.client.ainode.AINodeClient;
 import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
@@ -101,33 +102,35 @@ public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> {
   private void dropModelOnAINode(ConfigNodeProcedureEnv env) {
     LOGGER.info("Start to drop model file [{}] on AI Node", modelName);
 
-    List<Integer> nodeIds =
-        env.getConfigManager().getModelManager().getModelDistributions(modelName);
-    for (Integer nodeId : nodeIds) {
-      try (AINodeClient client =
-          AINodeClientManager.getInstance()
-              .borrowClient(
-                  env.getConfigManager()
-                      .getNodeManager()
-                      .getRegisteredAINode(nodeId)
-                      .getLocation()
-                      .getInternalEndPoint())) {
-        TSStatus status = client.deleteModel(modelName);
-        if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
-          LOGGER.warn(
-              "Failed to drop model [{}] on AINode [{}], status: {}",
-              modelName,
-              nodeId,
-              status.getMessage());
-        }
-      } catch (Exception e) {
-        LOGGER.warn(
-            "Failed to drop model [{}] on AINode [{}], status: {}",
-            modelName,
-            nodeId,
-            e.getMessage());
-      }
-    }
+    List<TAINodeConfiguration> aiNodes =
+        env.getConfigManager().getNodeManager().getRegisteredAINodes();
+    aiNodes.forEach(
+        aiNode -> {
+          int nodeId = aiNode.getLocation().getAiNodeId();
+          try (AINodeClient client =
+              AINodeClientManager.getInstance()
+                  .borrowClient(
+                      env.getConfigManager()
+                          .getNodeManager()
+                          .getRegisteredAINode(nodeId)
+                          .getLocation()
+                          .getInternalEndPoint())) {
+            TSStatus status = client.deleteModel(modelName);
+            if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
+              LOGGER.warn(
+                  "Failed to drop model [{}] on AINode [{}], status: {}",
+                  modelName,
+                  nodeId,
+                  status.getMessage());
+            }
+          } catch (Exception e) {
+            LOGGER.warn(
+                "Failed to drop model [{}] on AINode [{}], status: {}",
+                modelName,
+                nodeId,
+                e.getMessage());
+          }
+        });
   }
 
   private void dropModelOnConfigNode(ConfigNodeProcedureEnv env) {
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
index f573ea03b6d..e52310d1505 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
@@ -28,6 +28,7 @@ import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq;
 import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp;
 import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq;
 import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp;
+import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
 import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
 import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq;
 import org.apache.iotdb.ainode.rpc.thrift.TWindowParams;
@@ -159,9 +160,9 @@ public class AINodeClient implements AutoCloseable, ThriftClient {
     }
   }
 
-  public TShowModelsResp showModels() throws TException {
+  public TShowModelsResp showModels(TShowModelsReq req) throws TException {
     try {
-      return client.showModels();
+      return client.showModels(req);
     } catch (TException e) {
       logger.warn(
           "Failed to connect to AINode from ConfigNode when executing {}: {}",
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index df6cf5daca3..a4ccef7e752 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -97,6 +97,10 @@ struct TForecastResp {
   2: required binary forecastResult
 }
 
+struct TShowModelsReq {
+  1: optional string modelId
+}
+
 struct TShowModelsResp {
   1: required common.TSStatus status
   2: optional list<string> modelIdList
@@ -108,7 +112,7 @@ struct TShowModelsResp {
 service IAINodeRPCService {
 
   // -------------- For Config Node --------------
-  TShowModelsResp showModels()
+  TShowModelsResp showModels(TShowModelsReq req)
 
   common.TSStatus deleteModel(TDeleteModelReq req)
 

Reply via email to