This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new b04cd468c5a [AINode] Fix built-in model inference & support user parameters (#15868)
b04cd468c5a is described below

commit b04cd468c5a6714ca629c9f493f196fd1cf25705
Author: Leo <[email protected]>
AuthorDate: Thu Jul 3 19:14:58 2025 +0800

    [AINode] Fix built-in model inference & support user parameters (#15868)
---
 .../ainode/core/manager/inference_manager.py       | 27 +++++++++++-----------
 .../ainode/ainode/core/manager/model_manager.py    | 10 +++++---
 .../ainode/core/model/built_in_model_factory.py    |  8 +++++--
 .../ainode/ainode/core/model/model_storage.py      |  5 +++-
 4 files changed, 30 insertions(+), 20 deletions(-)

diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index ae55725eaf4..9d092c78c7c 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -84,7 +84,7 @@ class SundialStrategy(InferenceStrategy):
 
 
 class BuiltInStrategy(InferenceStrategy):
-    def infer(self, full_data, **_):
+    def infer(self, full_data):
         data = pd.DataFrame(full_data[1]).T
         output = self.model.inference(data)
         df = pd.DataFrame(output)
@@ -121,20 +121,19 @@ class RegisteredStrategy(InferenceStrategy):
         return [convert_to_binary(df) for df in results]
 
 
-def _get_strategy(model_id, model):
-    if isinstance(model, TimerForPrediction):
-        return TimerXLStrategy(model)
-    if isinstance(model, SundialForPrediction):
-        return SundialStrategy(model)
-    if model_id.startswith("_"):
-        return BuiltInStrategy(model)
-    return RegisteredStrategy(model)
-
-
 class InferenceManager:
     def __init__(self, model_manager: ModelManager):
         self.model_manager = model_manager
 
+    def _get_strategy(self, model_id, model):
+        if isinstance(model, TimerForPrediction):
+            return TimerXLStrategy(model)
+        if isinstance(model, SundialForPrediction):
+            return SundialStrategy(model)
+        if self.model_manager.model_storage._is_built_in(model_id):
+            return BuiltInStrategy(model)
+        return RegisteredStrategy(model)
+
     def _run(
         self,
         req,
@@ -156,11 +155,11 @@ class InferenceManager:
 
             # load model
             accel = str(inference_attrs.get("acceleration", "")).lower() == "true"
-            model = self.model_manager.load_model(model_id, accel)
+            model = self.model_manager.load_model(model_id, inference_attrs, accel)
 
             # inference by strategy
-            strategy = _get_strategy(model_id, model)
-            outputs = strategy.infer(full_data, **inference_attrs)
+            strategy = self._get_strategy(model_id, model)
+            outputs = strategy.infer(full_data)
 
             # construct response
             status = get_status(TSStatusCode.SUCCESS_STATUS)
diff --git a/iotdb-core/ainode/ainode/core/manager/model_manager.py b/iotdb-core/ainode/ainode/core/manager/model_manager.py
index c63d0d30c97..4688edf04c3 100644
--- a/iotdb-core/ainode/ainode/core/manager/model_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/model_manager.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-from typing import Callable
+from typing import Callable, Dict
 
 from torch import nn
 from yaml import YAMLError
@@ -97,13 +97,17 @@ class ModelManager:
             logger.warning(e)
             return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
 
-    def load_model(self, model_id: str, acceleration: bool = False) -> Callable:
+    def load_model(
+        self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool = False
+    ) -> Callable:
         """
         Load the model with the given model_id.
         """
         logger.info(f"Load model {model_id}")
         try:
-            model = self.model_storage.load_model(model_id, acceleration)
+            model = self.model_storage.load_model(
+                model_id, inference_attrs, acceleration
+            )
             logger.info(f"Model {model_id} loaded")
             return model
         except Exception as e:
diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
index 3c58916d908..b822357158b 100644
--- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
+++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
@@ -123,7 +123,9 @@ def get_model_attributes(model_type: BuiltInModelType):
     return attribute_map
 
 
-def fetch_built_in_model(model_type: BuiltInModelType, model_dir) -> Callable:
+def fetch_built_in_model(
+    model_type: BuiltInModelType, model_dir, inference_attrs: Dict[str, str]
+) -> Callable:
     """
     Fetch the built-in model according to its id and directory, not that this directory only contains model weights and config.
     Args:
@@ -132,7 +134,9 @@ def fetch_built_in_model(model_type: BuiltInModelType, model_dir) -> Callable:
     Returns:
         model: the built-in model
     """
-    attributes = get_model_attributes(model_type)
+    default_attributes = get_model_attributes(model_type)
+    # parse the attributes from inference_attrs
+    attributes = parse_attribute(inference_attrs, default_attributes)
 
     # build the built-in model
     if model_type == BuiltInModelType.ARIMA:
diff --git a/iotdb-core/ainode/ainode/core/model/model_storage.py b/iotdb-core/ainode/ainode/core/model/model_storage.py
index 76899d7cb48..3d745e7ce3b 100644
--- a/iotdb-core/ainode/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/ainode/core/model/model_storage.py
@@ -262,7 +262,9 @@ class ModelStorage(object):
             or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED
         )
 
-    def load_model(self, model_id: str, acceleration: bool) -> Callable:
+    def load_model(
+        self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool
+    ) -> Callable:
         """
         Load a model with automatic detection of .safetensors or .pt format
 
@@ -275,6 +277,7 @@ class ModelStorage(object):
                 return fetch_built_in_model(
                     get_built_in_model_type(self._model_info_map[model_id].model_type),
                     model_dir,
+                    inference_attrs,
                 )
             else:
                 # load the user-defined model

Reply via email to