This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new b04cd468c5a [AINode] Fix built-in model inference & support user
parameters (#15868)
b04cd468c5a is described below
commit b04cd468c5a6714ca629c9f493f196fd1cf25705
Author: Leo <[email protected]>
AuthorDate: Thu Jul 3 19:14:58 2025 +0800
[AINode] Fix built-in model inference & support user parameters (#15868)
---
.../ainode/core/manager/inference_manager.py | 27 +++++++++++-----------
.../ainode/ainode/core/manager/model_manager.py | 10 +++++---
.../ainode/core/model/built_in_model_factory.py | 8 +++++--
.../ainode/ainode/core/model/model_storage.py | 5 +++-
4 files changed, 30 insertions(+), 20 deletions(-)
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index ae55725eaf4..9d092c78c7c 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -84,7 +84,7 @@ class SundialStrategy(InferenceStrategy):
class BuiltInStrategy(InferenceStrategy):
- def infer(self, full_data, **_):
+ def infer(self, full_data):
data = pd.DataFrame(full_data[1]).T
output = self.model.inference(data)
df = pd.DataFrame(output)
@@ -121,20 +121,19 @@ class RegisteredStrategy(InferenceStrategy):
return [convert_to_binary(df) for df in results]
-def _get_strategy(model_id, model):
- if isinstance(model, TimerForPrediction):
- return TimerXLStrategy(model)
- if isinstance(model, SundialForPrediction):
- return SundialStrategy(model)
- if model_id.startswith("_"):
- return BuiltInStrategy(model)
- return RegisteredStrategy(model)
-
-
class InferenceManager:
def __init__(self, model_manager: ModelManager):
self.model_manager = model_manager
+ def _get_strategy(self, model_id, model):
+ if isinstance(model, TimerForPrediction):
+ return TimerXLStrategy(model)
+ if isinstance(model, SundialForPrediction):
+ return SundialStrategy(model)
+ if self.model_manager.model_storage._is_built_in(model_id):
+ return BuiltInStrategy(model)
+ return RegisteredStrategy(model)
+
def _run(
self,
req,
@@ -156,11 +155,11 @@ class InferenceManager:
# load model
accel = str(inference_attrs.get("acceleration", "")).lower() ==
"true"
- model = self.model_manager.load_model(model_id, accel)
+ model = self.model_manager.load_model(model_id, inference_attrs,
accel)
# inference by strategy
- strategy = _get_strategy(model_id, model)
- outputs = strategy.infer(full_data, **inference_attrs)
+ strategy = self._get_strategy(model_id, model)
+ outputs = strategy.infer(full_data)
# construct response
status = get_status(TSStatusCode.SUCCESS_STATUS)
diff --git a/iotdb-core/ainode/ainode/core/manager/model_manager.py
b/iotdb-core/ainode/ainode/core/manager/model_manager.py
index c63d0d30c97..4688edf04c3 100644
--- a/iotdb-core/ainode/ainode/core/manager/model_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/model_manager.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
#
-from typing import Callable
+from typing import Callable, Dict
from torch import nn
from yaml import YAMLError
@@ -97,13 +97,17 @@ class ModelManager:
logger.warning(e)
return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
- def load_model(self, model_id: str, acceleration: bool = False) ->
Callable:
+ def load_model(
+ self, model_id: str, inference_attrs: Dict[str, str], acceleration:
bool = False
+ ) -> Callable:
"""
Load the model with the given model_id.
"""
logger.info(f"Load model {model_id}")
try:
- model = self.model_storage.load_model(model_id, acceleration)
+ model = self.model_storage.load_model(
+ model_id, inference_attrs, acceleration
+ )
logger.info(f"Model {model_id} loaded")
return model
except Exception as e:
diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
index 3c58916d908..b822357158b 100644
--- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
+++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
@@ -123,7 +123,9 @@ def get_model_attributes(model_type: BuiltInModelType):
return attribute_map
-def fetch_built_in_model(model_type: BuiltInModelType, model_dir) -> Callable:
+def fetch_built_in_model(
+ model_type: BuiltInModelType, model_dir, inference_attrs: Dict[str, str]
+) -> Callable:
"""
Fetch the built-in model according to its id and directory; note that this
directory only contains model weights and config.
Args:
@@ -132,7 +134,9 @@ def fetch_built_in_model(model_type: BuiltInModelType,
model_dir) -> Callable:
Returns:
model: the built-in model
"""
- attributes = get_model_attributes(model_type)
+ default_attributes = get_model_attributes(model_type)
+ # parse the attributes from inference_attrs
+ attributes = parse_attribute(inference_attrs, default_attributes)
# build the built-in model
if model_type == BuiltInModelType.ARIMA:
diff --git a/iotdb-core/ainode/ainode/core/model/model_storage.py
b/iotdb-core/ainode/ainode/core/model/model_storage.py
index 76899d7cb48..3d745e7ce3b 100644
--- a/iotdb-core/ainode/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/ainode/core/model/model_storage.py
@@ -262,7 +262,9 @@ class ModelStorage(object):
or self._model_info_map[model_id].category ==
ModelCategory.FINE_TUNED
)
- def load_model(self, model_id: str, acceleration: bool) -> Callable:
+ def load_model(
+ self, model_id: str, inference_attrs: Dict[str, str], acceleration:
bool
+ ) -> Callable:
"""
Load a model with automatic detection of .safetensors or .pt format
@@ -275,6 +277,7 @@ class ModelStorage(object):
return fetch_built_in_model(
get_built_in_model_type(self._model_info_map[model_id].model_type),
model_dir,
+ inference_attrs,
)
else:
# load the user-defined model