This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new be7a1de2b45 [AINode] Fix bug that AINode cannot register and invoke
user_defined model (#15849)
be7a1de2b45 is described below
commit be7a1de2b45d04c80ab5bf74bd62222ffe1e7f5c
Author: Leo <[email protected]>
AuthorDate: Wed Jul 2 12:37:33 2025 +0800
[AINode] Fix bug that AINode cannot register and invoke user_defined model
(#15849)
---
.../ainode/ainode/core/model/model_storage.py | 79 +++++++++++++++++++---
1 file changed, 71 insertions(+), 8 deletions(-)
diff --git a/iotdb-core/ainode/ainode/core/model/model_storage.py b/iotdb-core/ainode/ainode/core/model/model_storage.py
index 1c9daa78e3c..76899d7cb48 100644
--- a/iotdb-core/ainode/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/ainode/core/model/model_storage.py
@@ -23,6 +23,7 @@ import shutil
from collections.abc import Callable
from typing import Dict
+import torch
from torch import nn
from ainode.core.config import AINodeDescriptor
@@ -32,6 +33,7 @@ from ainode.core.constant import (
MODEL_CONFIG_FILE_IN_JSON,
TSStatusCode,
)
+from ainode.core.exception import ModelNotExistError
from ainode.core.log import Logger
from ainode.core.model.built_in_model_factory import (
download_ltsm_if_necessary,
@@ -104,7 +106,10 @@ class ModelStorage(object):
future.add_done_callback(
lambda f, mid=model_id: self._callback_model_download_result(f, mid)
)
- # TODO: retrieve user-defined models
+ # 4. retrieve user-defined models from the model directory
+ user_defined_models = self._retrieve_user_defined_models()
+ for model_id in user_defined_models:
+ self._model_info_map[model_id] = user_defined_models[model_id]
def _retrieve_fine_tuned_models(self):
"""
@@ -174,6 +179,28 @@ class ModelStorage(object):
else:
self._model_info_map[model_id].state = ModelStates.INACTIVE
+ def _retrieve_user_defined_models(self):
+ """
+ Retrieve user_defined models from the model directory.
+
+ Returns:
+ {"model_id": ModelInfo}
+ """
+ result = {}
+ user_dirs = [
+ d
+ for d in os.listdir(self._model_dir)
+ if os.path.isdir(os.path.join(self._model_dir, d)) and d != "weights"
+ ]
+ for model_id in user_dirs:
+ result[model_id] = ModelInfo(
+ model_id=model_id,
+ model_type="",
+ category=ModelCategory.USER_DEFINED,
+ state=ModelStates.ACTIVE,
+ )
+ return result
+
def register_model(self, model_id: str, uri: str):
"""
Args:
@@ -190,7 +217,16 @@ class ModelStorage(object):
os.makedirs(storage_path)
model_storage_path = os.path.join(storage_path, DEFAULT_MODEL_FILE_NAME)
config_storage_path = os.path.join(storage_path, DEFAULT_CONFIG_FILE_NAME)
- return fetch_model_by_uri(uri, model_storage_path, config_storage_path)
+ configs, attributes = fetch_model_by_uri(
+ uri, model_storage_path, config_storage_path
+ )
+ self._model_info_map[model_id] = ModelInfo(
+ model_id=model_id,
+ model_type="",
+ category=ModelCategory.USER_DEFINED,
+ state=ModelStates.ACTIVE,
+ )
+ return configs, attributes
def delete_model(self, model_id: str) -> None:
"""
@@ -241,9 +277,26 @@ class ModelStorage(object):
model_dir,
)
else:
- # TODO: support load the user-defined model
- # model_dir = os.path.join(self._model_dir, f"{model_id}")
- raise NotImplementedError
+ # load the user-defined model
+ model_dir = os.path.join(self._model_dir, f"{model_id}")
+ model_path = os.path.join(model_dir, DEFAULT_MODEL_FILE_NAME)
+
+ if not os.path.exists(model_path):
+ raise ModelNotExistError(model_path)
+ model = torch.jit.load(model_path)
+ if (
+ isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
+ or not acceleration
+ ):
+ return model
+
+ try:
+ model = torch.compile(model)
+ except Exception as e:
+ logger.warning(
+ f"acceleration failed, fallback to normal mode: {str(e)}"
+ )
+ return model
def save_model(self, model_id: str, model: nn.Module):
"""
@@ -257,9 +310,19 @@ class ModelStorage(object):
model_dir = os.path.join(self._builtin_model_dir, f"{model_id}")
model.save_pretrained(model_dir)
else:
- # TODO: support save the user-defined model
- # model_dir = os.path.join(self._model_dir, f"{model_id}")
- raise NotImplementedError
+ # save the user-defined model
+ model_dir = os.path.join(self._model_dir, f"{model_id}")
+ os.makedirs(model_dir, exist_ok=True)
+ model_path = os.path.join(model_dir, DEFAULT_MODEL_FILE_NAME)
+ try:
+ scripted_model = (
+ model
+ if isinstance(model, torch.jit.ScriptModule)
+ else torch.jit.script(model)
+ )
+ torch.jit.save(scripted_model, model_path)
+ except Exception as e:
+ logger.error(f"Failed to save scripted model: {e}")
def get_ckpt_path(self, model_id: str) -> str:
"""