This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new be7a1de2b45 [AINode] Fix bug that AINode cannot register and invoke user_defined model (#15849)
be7a1de2b45 is described below

commit be7a1de2b45d04c80ab5bf74bd62222ffe1e7f5c
Author: Leo <[email protected]>
AuthorDate: Wed Jul 2 12:37:33 2025 +0800

    [AINode] Fix bug that AINode cannot register and invoke user_defined model (#15849)
---
 .../ainode/ainode/core/model/model_storage.py      | 79 +++++++++++++++++++---
 1 file changed, 71 insertions(+), 8 deletions(-)

diff --git a/iotdb-core/ainode/ainode/core/model/model_storage.py b/iotdb-core/ainode/ainode/core/model/model_storage.py
index 1c9daa78e3c..76899d7cb48 100644
--- a/iotdb-core/ainode/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/ainode/core/model/model_storage.py
@@ -23,6 +23,7 @@ import shutil
 from collections.abc import Callable
 from typing import Dict
 
+import torch
 from torch import nn
 
 from ainode.core.config import AINodeDescriptor
@@ -32,6 +33,7 @@ from ainode.core.constant import (
     MODEL_CONFIG_FILE_IN_JSON,
     TSStatusCode,
 )
+from ainode.core.exception import ModelNotExistError
 from ainode.core.log import Logger
 from ainode.core.model.built_in_model_factory import (
     download_ltsm_if_necessary,
@@ -104,7 +106,10 @@ class ModelStorage(object):
             future.add_done_callback(
                 lambda f, mid=model_id: self._callback_model_download_result(f, mid)
             )
-        # TODO: retrieve user-defined models
+        # 4. retrieve user-defined models from the model directory
+        user_defined_models = self._retrieve_user_defined_models()
+        for model_id in user_defined_models:
+            self._model_info_map[model_id] = user_defined_models[model_id]
 
     def _retrieve_fine_tuned_models(self):
         """
@@ -174,6 +179,28 @@ class ModelStorage(object):
             else:
                 self._model_info_map[model_id].state = ModelStates.INACTIVE
 
+    def _retrieve_user_defined_models(self):
+        """
+        Retrieve user_defined models from the model directory.
+
+        Returns:
+            {"model_id": ModelInfo}
+        """
+        result = {}
+        user_dirs = [
+            d
+            for d in os.listdir(self._model_dir)
+            if os.path.isdir(os.path.join(self._model_dir, d)) and d != "weights"
+        ]
+        for model_id in user_dirs:
+            result[model_id] = ModelInfo(
+                model_id=model_id,
+                model_type="",
+                category=ModelCategory.USER_DEFINED,
+                state=ModelStates.ACTIVE,
+            )
+        return result
+
     def register_model(self, model_id: str, uri: str):
         """
         Args:
@@ -190,7 +217,16 @@ class ModelStorage(object):
             os.makedirs(storage_path)
         model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME)
         config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME)
-        return fetch_model_by_uri(uri, model_storage_path, config_storage_path)
+        configs, attributes = fetch_model_by_uri(
+            uri, model_storage_path, config_storage_path
+        )
+        self._model_info_map[model_id] = ModelInfo(
+            model_id=model_id,
+            model_type="",
+            category=ModelCategory.USER_DEFINED,
+            state=ModelStates.ACTIVE,
+        )
+        return configs, attributes
 
     def delete_model(self, model_id: str) -> None:
         """
@@ -241,9 +277,26 @@ class ModelStorage(object):
                     model_dir,
                 )
             else:
-                # TODO: support load the user-defined model
-                # model_dir = os.path.join(self._model_dir, f"{model_id}")
-                raise NotImplementedError
+                # load the user-defined model
+                model_dir = os.path.join(self._model_dir, f"{model_id}")
+                model_path = os.path.join(model_dir, DEFAULT_MODEL_FILE_NAME)
+
+                if not os.path.exists(model_path):
+                    raise ModelNotExistError(model_path)
+                model = torch.jit.load(model_path)
+                if (
+                    isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
+                    or not acceleration
+                ):
+                    return model
+
+                try:
+                    model = torch.compile(model)
+                except Exception as e:
+                    logger.warning(
+                        f"acceleration failed, fallback to normal mode: {str(e)}"
+                    )
+                return model
 
     def save_model(self, model_id: str, model: nn.Module):
         """
@@ -257,9 +310,19 @@ class ModelStorage(object):
                 model_dir = os.path.join(self._builtin_model_dir, f"{model_id}")
                 model.save_pretrained(model_dir)
             else:
-                # TODO: support save the user-defined model
-                # model_dir = os.path.join(self._model_dir, f"{model_id}")
-                raise NotImplementedError
+                # save the user-defined model
+                model_dir = os.path.join(self._model_dir, f"{model_id}")
+                os.makedirs(model_dir, exist_ok=True)
+                model_path = os.path.join(model_dir, DEFAULT_MODEL_FILE_NAME)
+                try:
+                    scripted_model = (
+                        model
+                        if isinstance(model, torch.jit.ScriptModule)
+                        else torch.jit.script(model)
+                    )
+                    torch.jit.save(scripted_model, model_path)
+                except Exception as e:
+                    logger.error(f"Failed to save scripted model: {e}")
 
     def get_ckpt_path(self, model_id: str) -> str:
         """

Reply via email to