This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch ain-bug-fix
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/ain-bug-fix by this push:
new 90ac5618711 finish
90ac5618711 is described below
commit 90ac56187113927cfe290aa5799d66fa92e84700
Author: Yongzao <[email protected]>
AuthorDate: Sat Jun 28 02:57:10 2025 +0800
finish
---
.../ainode/ainode/core/manager/model_manager.py | 13 ++++++++
.../ainode/ainode/core/model/model_storage.py | 38 ++++++++++++++++++++--
.../iotdb/confignode/persistence/ModelInfo.java | 9 ++---
.../operator/process/ai/InferenceOperator.java | 3 +-
.../queryengine/plan/analyze/AnalyzeVisitor.java | 7 ----
.../iotdb/commons/model/ModelInformation.java | 2 +-
6 files changed, 53 insertions(+), 19 deletions(-)
diff --git a/iotdb-core/ainode/ainode/core/manager/model_manager.py b/iotdb-core/ainode/ainode/core/manager/model_manager.py
index 9b965f0c711..46177378035 100644
--- a/iotdb-core/ainode/ainode/core/manager/model_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/model_manager.py
@@ -26,6 +26,7 @@ from ainode.core.exception import (
InvalidUriError,
)
from ainode.core.log import Logger
+from ainode.core.model.model_info import BuiltInModelType, ModelInfo, ModelStates
from ainode.core.model.model_storage import ModelStorage
from ainode.core.util.status import get_status
from ainode.thrift.ainode.ttypes import (
@@ -140,3 +141,15 @@ class ModelManager:
def show_models(self) -> TShowModelsResp:
return self.model_storage.show_models()
+
+ def register_built_in_model(self, model_info: ModelInfo):
+ self.model_storage.register_built_in_model(model_info)
+
+ def update_model_state(self, model_id: str, state: ModelStates):
+ self.model_storage.update_model_state(model_id, state)
+
+ def get_built_in_model_type(self, model_id: str) -> BuiltInModelType:
+ """
+ Get the type of the model with the given model_id.
+ """
+ return self.model_storage.get_built_in_model_type(model_id.lower())
diff --git a/iotdb-core/ainode/ainode/core/model/model_storage.py b/iotdb-core/ainode/ainode/core/model/model_storage.py
index 2682a139e6c..92f2528845c 100644
--- a/iotdb-core/ainode/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/ainode/core/model/model_storage.py
@@ -207,6 +207,9 @@ class ModelStorage(object):
with self._lock_pool.get_lock(model_id).write_lock():
if os.path.exists(storage_path):
shutil.rmtree(storage_path)
+ if model_id in self._model_info_map:
+ del self._model_info_map[model_id]
+ logger.info(f"Model {model_id} deleted successfully.")
def _is_built_in(self, model_id: str) -> bool:
"""
@@ -218,9 +221,9 @@ class ModelStorage(object):
Returns:
bool: True if the model is built-in, False otherwise.
"""
- return (
- model_id in self._model_info_map
- and self._model_info_map[model_id].category == ModelCategory.BUILT_IN
+ return model_id in self._model_info_map and (
+ self._model_info_map[model_id].category == ModelCategory.BUILT_IN
+ or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED
)
def load_model(self, model_id: str, acceleration: bool) -> Callable:
@@ -291,3 +294,32 @@ class ModelStorage(object):
for model_id, model_info in self._model_info_map.items()
),
)
+
+ def register_built_in_model(self, model_info: ModelInfo):
+ with self._lock_pool.get_lock(model_info.model_id).write_lock():
+ self._model_info_map[model_info.model_id] = model_info
+
+ def update_model_state(self, model_id: str, state: ModelStates):
+ with self._lock_pool.get_lock(model_id).write_lock():
+ if model_id in self._model_info_map:
+ self._model_info_map[model_id].state = state
+ else:
+ raise ValueError(f"Model {model_id} does not exist.")
+
+ def get_built_in_model_type(self, model_id: str) -> BuiltInModelType:
+ """
+ Get the type of the model with the given model_id.
+
+ Args:
+ model_id (str): The ID of the model.
+
+ Returns:
+ str: The type of the model.
+ """
+ with self._lock_pool.get_lock(model_id).read_lock():
+ if model_id in self._model_info_map:
+ return get_built_in_model_type(
+ self._model_info_map[model_id].model_type
+ )
+ else:
+ raise ValueError(f"Model {model_id} does not exist.")
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index e96f6fb7bf6..7f0eb6b4e88 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -121,13 +121,8 @@ public class ModelInfo implements SnapshotProcessor {
try {
acquireModelTableWriteLock();
String modelName = plan.getModelName();
- if (modelTable.containsModel(modelName)) {
- return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode())
- .setMessage(String.format("model [%s] has already been created.", modelName));
- } else {
- modelTable.addModel(new ModelInformation(modelName, ModelStatus.LOADING));
- return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
- }
+ modelTable.addModel(new ModelInformation(modelName, ModelStatus.LOADING));
+ return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
} catch (Exception e) {
final String errorMessage =
String.format(
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
index a384be3ad24..fd51ced46e8 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
@@ -254,7 +254,8 @@ public class InferenceOperator implements ProcessOperator {
}
private TsBlock preProcess(TsBlock inputTsBlock) {
- boolean notBuiltIn = !modelInferenceDescriptor.getModelInformation().isBuiltIn();
+ // boolean notBuiltIn = !modelInferenceDescriptor.getModelInformation().isBuiltIn();
+ boolean notBuiltIn = false;
if (windowType == null || windowType == InferenceWindowType.HEAD) {
if (notBuiltIn
&& totalRow != modelInferenceDescriptor.getModelInformation().getInputShape()[0]) {
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java
index b71a9770126..a1033250f68 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java
@@ -481,13 +481,6 @@ public class AnalyzeVisitor extends StatementVisitor<Analysis, MPPQueryContext>
if (modelInformation.isBuiltIn()) {
return;
}
-
- if (modelInformation.getInputShape()[0] != windowSize) {
- throw new SemanticException(
- String.format(
- "Window output %d is not equal to input size of model %d",
- windowSize, modelInformation.getInputShape()[0]));
- }
}
private ISchemaTree analyzeSchema(
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
index 9e84c92a311..3fa10768543 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java
@@ -84,7 +84,7 @@ public class ModelInformation {
}
public ModelInformation(String modelName, ModelStatus status) {
- this.modelType = ModelType.USER_DEFINED;
+ this.modelType = ModelType.BUILT_IN_FORECAST;
this.modelName = modelName;
this.inputShape = new int[0];
this.outputShape = new int[0];