This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 719c504d410 [AINode] Refactor the result columns of "show models" SQL
(#15831)
719c504d410 is described below
commit 719c504d4107cc5b5de084a63a3f14ebbe6b5711
Author: Li Yu Heng <[email protected]>
AuthorDate: Fri Jun 27 22:19:53 2025 +0800
[AINode] Refactor the result columns of "show models" SQL (#15831)
* Refactor the result columns of "show models" and download built-in models automatically
* Update AINodeBasicIT.java
---------
Co-authored-by: Yongzao <[email protected]>
---
.../org/apache/iotdb/ainode/it/AINodeBasicIT.java | 6 +-
.../java/org/apache/iotdb/rpc/TSStatusCode.java | 1 +
iotdb-core/ainode/ainode/core/constant.py | 45 +----
iotdb-core/ainode/ainode/core/handler.py | 5 +-
.../ainode/core/manager/inference_manager.py | 6 +-
.../ainode/ainode/core/manager/model_manager.py | 20 +--
.../ainode/core/model/built_in_model_factory.py | 128 +++++++-------
iotdb-core/ainode/ainode/core/model/model_info.py | 160 ++++++++++++++++++
.../ainode/ainode/core/model/model_storage.py | 183 ++++++++++++++++++---
.../consensus/response/model/ModelTableResp.java | 13 +-
.../iotdb/confignode/manager/ConfigManager.java | 2 +-
.../iotdb/confignode/manager/ModelManager.java | 53 +++---
.../iotdb/confignode/persistence/ModelInfo.java | 20 +--
.../InformationSchemaContentSupplierFactory.java | 121 ++++++++++----
.../config/executor/ClusterConfigTaskExecutor.java | 2 +-
.../config/metadata/ai/ShowModelsTask.java | 72 ++++----
.../iotdb/commons/client/ainode/AINodeClient.java | 13 ++
.../schema/column/ColumnHeaderConstant.java | 5 +-
.../thrift-ainode/src/main/thrift/ainode.thrift | 9 +
.../src/main/thrift/confignode.thrift | 5 +-
20 files changed, 614 insertions(+), 255 deletions(-)
diff --git
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java
index da4f8e9c536..84cb3576cee 100644
---
a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java
+++
b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeBasicIT.java
@@ -21,8 +21,6 @@ package org.apache.iotdb.ainode.it;
import org.apache.iotdb.it.env.EnvFactory;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
import org.junit.Test;
import java.io.File;
@@ -68,14 +66,14 @@ public class AINodeBasicIT {
"insert into root.AI.data(timestamp,s0,s1,s2,s3)
values(7,7.0,8.0,9.0,10.0)",
};
- @BeforeClass
+ // @BeforeClass
public static void setUp() throws Exception {
// Init 1C1D1M cluster environment
EnvFactory.getEnv().initClusterEnvironment(1, 1);
prepareData(sqls);
}
- @AfterClass
+ // @AfterClass
public static void tearDown() throws Exception {
EnvFactory.getEnv().cleanClusterEnvironment();
}
diff --git
a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
index 6359eafde5c..34385f7ddc2 100644
---
a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
+++
b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java
@@ -198,6 +198,7 @@ public enum TSStatusCode {
CAN_NOT_CONNECT_CONFIGNODE(1010),
CAN_NOT_CONNECT_AINODE(1011),
NO_AVAILABLE_REPLICA(1012),
+ NO_AVAILABLE_AINODE(1013),
// Sync, Load TsFile
LOAD_FILE_ERROR(1100),
diff --git a/iotdb-core/ainode/ainode/core/constant.py
b/iotdb-core/ainode/ainode/core/constant.py
index a2ce549ef2b..c0e021ceebf 100644
--- a/iotdb-core/ainode/ainode/core/constant.py
+++ b/iotdb-core/ainode/ainode/core/constant.py
@@ -62,6 +62,8 @@ AINODE_LOG_FILE_LEVELS = [logging.DEBUG, logging.INFO,
logging.WARNING, logging.
TRIAL_ID_PREFIX = "__trial_"
DEFAULT_TRIAL_ID = TRIAL_ID_PREFIX + "0"
+MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors"
+MODEL_CONFIG_FILE_IN_JSON = "config.json"
DEFAULT_MODEL_FILE_NAME = "model.pt"
DEFAULT_CONFIG_FILE_NAME = "config.yaml"
DEFAULT_CHUNK_SIZE = 8192
@@ -72,12 +74,6 @@ DEFAULT_RECONNECT_TIMES = 3
STD_LEVEL = logging.INFO
-TIMER_REPO_ID = {
- "_timerxl": "thuml/timer-base-84m",
- "_sundial": "thuml/sundial-base-128m",
-}
-
-
class TSStatusCode(Enum):
SUCCESS_STATUS = 200
REDIRECTION_RECOMMEND = 400
@@ -151,43 +147,6 @@ class ModelInputName(Enum):
DEC_INP = "dec_inp"
-class BuiltInModelType(Enum):
- # forecast models
- ARIMA = "_arima"
- HOLTWINTERS = "_holtwinters"
- EXPONENTIAL_SMOOTHING = "_exponentialsmoothing"
- NAIVE_FORECASTER = "_naiveforecaster"
- STL_FORECASTER = "_stlforecaster"
-
- # anomaly detection models
- GAUSSIAN_HMM = "_gaussianhmm"
- GMM_HMM = "_gmmhmm"
- STRAY = "_stray"
-
- # timerxl
- TIMER_XL = "_timerxl"
-
- # sundial
- SUNDIAL = "_sundial"
-
- @classmethod
- def values(cls) -> List[str]:
- values = []
- for item in list(cls):
- values.append(item.value)
- return values
-
- @staticmethod
- def is_built_in_model(model_id: str) -> bool:
- """
- Check if the model ID corresponds to a built-in model.
- """
- # TODO: Unify this ugly hard code
- if "timerxl" in model_id or "sundial" in model_id:
- return True
- return model_id in BuiltInModelType.values()
-
-
class AttributeName(Enum):
# forecast Attribute
PREDICT_LENGTH = "predict_length"
diff --git a/iotdb-core/ainode/ainode/core/handler.py
b/iotdb-core/ainode/ainode/core/handler.py
index b338286e97b..804be637934 100644
--- a/iotdb-core/ainode/ainode/core/handler.py
+++ b/iotdb-core/ainode/ainode/core/handler.py
@@ -16,7 +16,6 @@
# under the License.
#
-from ainode.core.constant import TSStatusCode
from ainode.core.log import Logger
from ainode.core.manager.cluster_manager import ClusterManager
from ainode.core.manager.inference_manager import InferenceManager
@@ -31,6 +30,7 @@ from ainode.thrift.ainode.ttypes import (
TInferenceResp,
TRegisterModelReq,
TRegisterModelResp,
+ TShowModelsResp,
TTrainingReq,
)
from ainode.thrift.common.ttypes import TSStatus
@@ -58,5 +58,8 @@ class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
return ClusterManager.get_heart_beat(req)
+ def showModels(self) -> TShowModelsResp:
+ return self._model_manager.show_models()
+
def createTrainingTask(self, req: TTrainingReq) -> TSStatus:
pass
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index dcbe8bd7359..d685970758a 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -21,7 +21,7 @@ import pandas as pd
import torch
from iotdb.tsfile.utils.tsblock_serde import deserialize
-from ainode.core.constant import BuiltInModelType, TSStatusCode
+from ainode.core.constant import TSStatusCode
from ainode.core.exception import (
InferenceModelInternalError,
InvalidWindowArgumentError,
@@ -150,9 +150,7 @@ class InferenceManager:
# load model
accel = str(inference_attrs.get("acceleration", "")).lower() ==
"true"
- model = self.model_manager.load_model(
- model_id, BuiltInModelType.is_built_in_model(model_id), accel
- )
+ model = self.model_manager.load_model(model_id, accel)
# inference by strategy
strategy = _get_strategy(model_id, model)
diff --git a/iotdb-core/ainode/ainode/core/manager/model_manager.py
b/iotdb-core/ainode/ainode/core/manager/model_manager.py
index b039f924023..9b965f0c711 100644
--- a/iotdb-core/ainode/ainode/core/manager/model_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/model_manager.py
@@ -32,6 +32,7 @@ from ainode.thrift.ainode.ttypes import (
TDeleteModelReq,
TRegisterModelReq,
TRegisterModelResp,
+ TShowModelsResp,
)
from ainode.thrift.common.ttypes import TSStatus
@@ -95,32 +96,28 @@ class ModelManager:
logger.warning(e)
return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
- def load_model(
- self, model_id: str, is_built_in: bool, acceleration: bool = False
- ) -> Callable:
+ def load_model(self, model_id: str, acceleration: bool = False) ->
Callable:
"""
Load the model with the given model_id.
"""
+ model_id = model_id.lower()
logger.info(f"Load model {model_id}")
try:
- model = self.model_storage.load_model(
- model_id.lower(), is_built_in, acceleration
- )
+ model = self.model_storage.load_model(model_id, acceleration)
logger.info(f"Model {model_id} loaded")
return model
except Exception as e:
logger.error(f"Failed to load model {model_id}: {e}")
raise
- def save_model(
- self, model_id: str, is_built_in: bool, model: nn.Module
- ) -> TSStatus:
+ def save_model(self, model_id: str, model: nn.Module) -> TSStatus:
"""
Save the model using save_pretrained
"""
+ model_id = model_id.lower()
logger.info(f"Saving model {model_id}")
try:
- self.model_storage.save_model(model_id, is_built_in, model)
+ self.model_storage.save_model(model_id, model)
logger.info(f"Saving model {model_id} successfully")
return get_status(
TSStatusCode.SUCCESS_STATUS, f"Model {model_id} saved
successfully"
@@ -140,3 +137,6 @@ class ModelManager:
str: The path to the checkpoint file for the model.
"""
return self.model_storage.get_ckpt_path(model_id)
+
+ def show_models(self) -> TShowModelsResp:
+ return self.model_storage.show_models()
diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
index 806ff0ac204..3c58916d908 100644
--- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
+++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
@@ -30,7 +30,11 @@ from sktime.forecasting.naive import NaiveForecaster
from sktime.forecasting.trend import STLForecaster
from ainode.core.config import AINodeDescriptor
-from ainode.core.constant import TIMER_REPO_ID, AttributeName, BuiltInModelType
+from ainode.core.constant import (
+ MODEL_CONFIG_FILE_IN_JSON,
+ MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
+ AttributeName,
+)
from ainode.core.exception import (
BuiltInModelNotSupportError,
InferenceModelInternalError,
@@ -40,108 +44,120 @@ from ainode.core.exception import (
WrongAttributeTypeError,
)
from ainode.core.log import Logger
+from ainode.core.model.model_info import TIMER_REPO_ID, BuiltInModelType
from ainode.core.model.sundial import modeling_sundial
from ainode.core.model.timerxl import modeling_timer
logger = Logger()
-def download_built_in_model_if_necessary(model_id: str, local_dir):
+def download_ltsm_if_necessary(model_type: BuiltInModelType, local_dir) ->
bool:
"""
- Download the built-in model from HuggingFace repository when necessary.
+ Download the built-in ltsm from HuggingFace repository when necessary.
+
+ Return:
+ bool: True if the model is existed or downloaded successfully, False
otherwise.
"""
- if "timer" in model_id or "sundial" in model_id:
- weights_path = os.path.join(local_dir, "model.safetensors")
- if not os.path.exists(weights_path):
- logger.info(
- f"Weight not found at {weights_path}, downloading from
HuggingFace..."
+ repo_id = TIMER_REPO_ID[model_type]
+ weights_path = os.path.join(local_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)
+ if not os.path.exists(weights_path):
+ logger.info(
+ f"Weight not found at {weights_path}, downloading from
HuggingFace..."
+ )
+ try:
+ hf_hub_download(
+ repo_id=repo_id,
+ filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
+ local_dir=local_dir,
)
- repo_id = TIMER_REPO_ID[model_id]
- try:
- hf_hub_download(
- repo_id=repo_id,
- filename="model.safetensors",
- local_dir=local_dir,
- )
- logger.info(f"Got weight to {weights_path}")
- config_path = os.path.join(local_dir, "config.json")
- hf_hub_download(
- repo_id=repo_id,
- filename="config.json",
- local_dir=local_dir,
- )
- logger.info(f"Got config to {config_path}")
- except Exception as e:
- logger.error(
- f"Failed to download huggingface model to {local_dir} due
to {e}"
- )
- raise e
+ logger.info(f"Got weight to {weights_path}")
+ except Exception as e:
+ logger.error(
+ f"Failed to download huggingface model weights to {local_dir}
due to {e}"
+ )
+ return False
+ config_path = os.path.join(local_dir, MODEL_CONFIG_FILE_IN_JSON)
+ if not os.path.exists(config_path):
+ logger.info(
+ f"Config not found at {config_path}, downloading from
HuggingFace..."
+ )
+ try:
+ hf_hub_download(
+ repo_id=repo_id,
+ filename=MODEL_CONFIG_FILE_IN_JSON,
+ local_dir=local_dir,
+ )
+ logger.info(f"Got config to {config_path}")
+ except Exception as e:
+ logger.error(
+ f"Failed to download huggingface model config to {local_dir}
due to {e}"
+ )
+ return False
+ return True
-def get_model_attributes(model_id: str):
-    if model_id == BuiltInModelType.ARIMA.value:
+def get_model_attributes(model_type: BuiltInModelType):
+    """
+    Return the attribute map registered for the given built-in model type.
+
+    Args:
+        model_type: the type of the built-in model
+
+    Returns:
+        the attribute map associated with ``model_type``
+
+    Raises:
+        BuiltInModelNotSupportError: if no attribute map exists for ``model_type``
+    """
+    # Always compare enum members with enum members: a BuiltInModelType member
+    # never equals its ``.value`` string, so mixing the two silently skips a branch.
+    if model_type == BuiltInModelType.ARIMA:
         attribute_map = arima_attribute_map
-    elif model_id == BuiltInModelType.NAIVE_FORECASTER.value:
+    elif model_type == BuiltInModelType.NAIVE_FORECASTER:
         attribute_map = naive_forecaster_attribute_map
     elif (
-        model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value
-        or model_id == BuiltInModelType.HOLTWINTERS.value
+        model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING
+        or model_type == BuiltInModelType.HOLTWINTERS
     ):
         attribute_map = exponential_smoothing_attribute_map
-    elif model_id == BuiltInModelType.STL_FORECASTER.value:
+    elif model_type == BuiltInModelType.STL_FORECASTER:
         attribute_map = stl_forecaster_attribute_map
-    elif model_id == BuiltInModelType.GMM_HMM.value:
+    elif model_type == BuiltInModelType.GMM_HMM:
         attribute_map = gmmhmm_attribute_map
-    elif model_id == BuiltInModelType.GAUSSIAN_HMM.value:
+    elif model_type == BuiltInModelType.GAUSSIAN_HMM:
         attribute_map = gaussian_hmm_attribute_map
-    elif model_id == BuiltInModelType.STRAY.value:
+    elif model_type == BuiltInModelType.STRAY:
         attribute_map = stray_attribute_map
-    # TODO: The model type should be judged before enter this file
-    elif "timerxl" in model_id:
+    elif model_type == BuiltInModelType.TIMER_XL:
         attribute_map = timerxl_attribute_map
-    elif "sundial" in model_id:
+    elif model_type == BuiltInModelType.SUNDIAL:
         attribute_map = sundial_attribute_map
     else:
-        raise BuiltInModelNotSupportError(model_id)
+        raise BuiltInModelNotSupportError(model_type.value)
     return attribute_map
-def fetch_built_in_model(model_id: str, model_dir) -> Callable:
+def fetch_built_in_model(model_type: BuiltInModelType, model_dir) -> Callable:
"""
Fetch the built-in model according to its id and directory, not that this
directory only contains model weights and config.
Args:
- model_id: the unique id of the model
+ model_type: the type of the built-in model
model_dir: for huggingface models only, the directory where the model
is stored
Returns:
model: the built-in model
"""
- attributes = get_model_attributes(model_id)
+ attributes = get_model_attributes(model_type)
# build the built-in model
- if model_id == BuiltInModelType.ARIMA.value:
+ if model_type == BuiltInModelType.ARIMA:
model = ArimaModel(attributes)
elif (
- model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value
- or model_id == BuiltInModelType.HOLTWINTERS.value
+ model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING
+ or model_type == BuiltInModelType.HOLTWINTERS
):
model = ExponentialSmoothingModel(attributes)
- elif model_id == BuiltInModelType.NAIVE_FORECASTER.value:
+ elif model_type == BuiltInModelType.NAIVE_FORECASTER:
model = NaiveForecasterModel(attributes)
- elif model_id == BuiltInModelType.STL_FORECASTER.value:
+ elif model_type == BuiltInModelType.STL_FORECASTER:
model = STLForecasterModel(attributes)
- elif model_id == BuiltInModelType.GMM_HMM.value:
+ elif model_type == BuiltInModelType.GMM_HMM:
model = GMMHMMModel(attributes)
- elif model_id == BuiltInModelType.GAUSSIAN_HMM.value:
+ elif model_type == BuiltInModelType.GAUSSIAN_HMM:
model = GaussianHmmModel(attributes)
- elif model_id == BuiltInModelType.STRAY.value:
+ elif model_type == BuiltInModelType.STRAY:
model = STRAYModel(attributes)
- # TODO: The model type should be judged before enter this file
- elif "timerxl" in model_id:
+ elif model_type == BuiltInModelType.TIMER_XL:
model = modeling_timer.TimerForPrediction.from_pretrained(model_dir)
- elif "sundial" in model_id:
+ elif model_type == BuiltInModelType.SUNDIAL:
model =
modeling_sundial.SundialForPrediction.from_pretrained(model_dir)
else:
- raise BuiltInModelNotSupportError(model_id)
+ raise BuiltInModelNotSupportError(model_type.value)
return model
diff --git a/iotdb-core/ainode/ainode/core/model/model_info.py
b/iotdb-core/ainode/ainode/core/model/model_info.py
new file mode 100644
index 00000000000..7fd41cba763
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/model_info.py
@@ -0,0 +1,160 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from enum import Enum
+from typing import List
+
+
+class BuiltInModelType(Enum):
+ # forecast models
+ ARIMA = "Arima"
+ HOLTWINTERS = "HoltWinters"
+ EXPONENTIAL_SMOOTHING = "ExponentialSmoothing"
+ NAIVE_FORECASTER = "NaiveForecaster"
+ STL_FORECASTER = "StlForecaster"
+
+ # anomaly detection models
+ GAUSSIAN_HMM = "GaussianHmm"
+ GMM_HMM = "GmmHmm"
+ STRAY = "Stray"
+
+ # large time series models (LTSM)
+ TIMER_XL = "Timer-XL"
+ # sundial
+ SUNDIAL = "Timer-Sundial"
+
+ @classmethod
+ def values(cls) -> List[str]:
+ return [item.value for item in cls]
+
+ @staticmethod
+ def is_built_in_model(model_type: str) -> bool:
+ """
+ Check if the given model type corresponds to a built-in model.
+ """
+ return model_type in BuiltInModelType.values()
+
+
+def get_built_in_model_type(model_type: str) -> BuiltInModelType:
+ if not BuiltInModelType.is_built_in_model(model_type):
+ raise ValueError(f"Invalid built-in model type: {model_type}")
+ return BuiltInModelType(model_type)
+
+
+class ModelCategory(Enum):
+ BUILT_IN = "BUILT-IN"
+ FINE_TUNED = "FINE-TUNED"
+ USER_DEFINED = "USER-DEFINED"
+
+
+class ModelStates(Enum):
+ ACTIVE = "ACTIVE"
+ INACTIVE = "INACTIVE"
+ LOADING = "LOADING"
+ DROPPING = "DROPPING"
+ TRAINING = "TRAINING"
+ FAILED = "FAILED"
+
+
+class ModelInfo:
+ def __init__(
+ self,
+ model_id: str,
+ model_type: str,
+ category: ModelCategory,
+ state: ModelStates,
+ ):
+ self.model_id = model_id
+ self.model_type = model_type
+ self.category = category
+ self.state = state
+
+
+TIMER_REPO_ID = {
+ BuiltInModelType.TIMER_XL: "thuml/timer-base-84m",
+ BuiltInModelType.SUNDIAL: "thuml/sundial-base-128m",
+}
+
+# Built-in machine learning models, they can be employed directly
+BUILT_IN_MACHINE_LEARNING_MODEL_MAP = {
+ # forecast models
+ "arima": ModelInfo(
+ model_id="arima",
+ model_type=BuiltInModelType.ARIMA.value,
+ category=ModelCategory.BUILT_IN,
+ state=ModelStates.ACTIVE,
+ ),
+ "holtwinters": ModelInfo(
+ model_id="holtwinters",
+ model_type=BuiltInModelType.HOLTWINTERS.value,
+ category=ModelCategory.BUILT_IN,
+ state=ModelStates.ACTIVE,
+ ),
+ "exponential_smoothing": ModelInfo(
+ model_id="exponential_smoothing",
+ model_type=BuiltInModelType.EXPONENTIAL_SMOOTHING.value,
+ category=ModelCategory.BUILT_IN,
+ state=ModelStates.ACTIVE,
+ ),
+ "naive_forecaster": ModelInfo(
+ model_id="naive_forecaster",
+ model_type=BuiltInModelType.NAIVE_FORECASTER.value,
+ category=ModelCategory.BUILT_IN,
+ state=ModelStates.ACTIVE,
+ ),
+ "stl_forecaster": ModelInfo(
+ model_id="stl_forecaster",
+ model_type=BuiltInModelType.STL_FORECASTER.value,
+ category=ModelCategory.BUILT_IN,
+ state=ModelStates.ACTIVE,
+ ),
+ # anomaly detection models
+ "gaussian_hmm": ModelInfo(
+ model_id="gaussian_hmm",
+ model_type=BuiltInModelType.GAUSSIAN_HMM.value,
+ category=ModelCategory.BUILT_IN,
+ state=ModelStates.ACTIVE,
+ ),
+ "gmm_hmm": ModelInfo(
+ model_id="gmm_hmm",
+ model_type=BuiltInModelType.GMM_HMM.value,
+ category=ModelCategory.BUILT_IN,
+ state=ModelStates.ACTIVE,
+ ),
+ "stray": ModelInfo(
+ model_id="stray",
+ model_type=BuiltInModelType.STRAY.value,
+ category=ModelCategory.BUILT_IN,
+ state=ModelStates.ACTIVE,
+ ),
+}
+
+# Built-in large time series models (LTSM), their weights are not included in
AINode by default
+BUILT_IN_LTSM_MAP = {
+ "timer_xl": ModelInfo(
+ model_id="timer_xl",
+ model_type=BuiltInModelType.TIMER_XL.value,
+ category=ModelCategory.BUILT_IN,
+ state=ModelStates.LOADING,
+ ),
+ "sundial": ModelInfo(
+ model_id="sundial",
+ model_type=BuiltInModelType.SUNDIAL.value,
+ category=ModelCategory.BUILT_IN,
+ state=ModelStates.LOADING,
+ ),
+}
diff --git a/iotdb-core/ainode/ainode/core/model/model_storage.py
b/iotdb-core/ainode/ainode/core/model/model_storage.py
index b8cf4a82bec..2682a139e6c 100644
--- a/iotdb-core/ainode/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/ainode/core/model/model_storage.py
@@ -16,27 +16,40 @@
# under the License.
#
+import concurrent.futures
+import json
import os
import shutil
from collections.abc import Callable
+from typing import Dict
-from pylru import lrucache
from torch import nn
from ainode.core.config import AINodeDescriptor
from ainode.core.constant import (
DEFAULT_CONFIG_FILE_NAME,
DEFAULT_MODEL_FILE_NAME,
- BuiltInModelType,
+ MODEL_CONFIG_FILE_IN_JSON,
+ TSStatusCode,
)
-from ainode.core.exception import BuiltInModelNotSupportError
from ainode.core.log import Logger
from ainode.core.model.built_in_model_factory import (
- download_built_in_model_if_necessary,
+ download_ltsm_if_necessary,
fetch_built_in_model,
)
from ainode.core.model.model_factory import fetch_model_by_uri
+from ainode.core.model.model_info import (
+ BUILT_IN_LTSM_MAP,
+ BUILT_IN_MACHINE_LEARNING_MODEL_MAP,
+ BuiltInModelType,
+ ModelCategory,
+ ModelInfo,
+ ModelStates,
+ get_built_in_model_type,
+)
from ainode.core.util.lock import ModelLockPool
+from ainode.thrift.ainode.ttypes import TShowModelsResp
+from ainode.thrift.common.ttypes import TSStatus
logger = Logger()
@@ -62,9 +75,104 @@ class ModelStorage(object):
logger.error(e)
raise e
self._lock_pool = ModelLockPool()
- self._model_cache = lrucache(
- AINodeDescriptor().get_config().get_ain_model_storage_cache_size()
- )
+ self._executor = concurrent.futures.ThreadPoolExecutor(
+ max_workers=1
+ ) # TODO: Here we set the work_num=1 cause we found that the hf
download interface is not stable for concurrent downloading.
+ self._model_info_map: Dict[str, ModelInfo] = {}
+ self._init_model_info_map()
+
+ def _init_model_info_map(self):
+ """
+ Initialize the model info map from three sources, in this order:
+ 1. the built-in machine learning models, which are ready to use;
+ 2. the model directories found under the built-in model directory
+ (see _retrieve_fine_tuned_models), which may override entries from 1;
+ 3. the built-in LTSMs that are not present yet. They are registered with
+ their initial (LOADING) state and their weights are downloaded in the
+ background on self._executor; _callback_model_download_result then
+ switches them to ACTIVE or INACTIVE.
+ """
+ # 1. initialize built-in and ready-to-use models
+ for model_id in BUILT_IN_MACHINE_LEARNING_MODEL_MAP:
+ self._model_info_map[model_id] =
BUILT_IN_MACHINE_LEARNING_MODEL_MAP[
+ model_id
+ ]
+ # 2. retrieve fine-tuned models from the built-in model directory
+ fine_tuned_models = self._retrieve_fine_tuned_models()
+ for model_id in fine_tuned_models:
+ self._model_info_map[model_id] = fine_tuned_models[model_id]
+ # 3. automatically download the weights of built-in LTSM models when necessary
+ for model_id in BUILT_IN_LTSM_MAP:
+ # LTSMs already found on disk in step 2 are skipped, so no download is scheduled for them
+ if model_id not in self._model_info_map:
+ # NOTE(review): this stores the shared module-level ModelInfo object, so the
+ # download callback also mutates BUILT_IN_LTSM_MAP[model_id].state; confirm
+ # that only one ModelStorage instance is ever created.
+ self._model_info_map[model_id] = BUILT_IN_LTSM_MAP[model_id]
+ future = self._executor.submit(
+ self._download_built_in_model_if_necessary, model_id
+ )
+ # bind model_id as a default argument so that each callback receives
+ # its own model_id instead of the last value of the loop variable
+ future.add_done_callback(
+ lambda f, mid=model_id:
self._callback_model_download_result(f, mid)
+ )
+ # TODO: retrieve user-defined models
+
+    def _retrieve_fine_tuned_models(self):
+        """
+        Retrieve fine-tuned models from the built-in model directory.
+
+        Every sub-directory of the built-in model directory whose config.json
+        is a JSON object containing a "model_type" entry is reported as a model.
+        Sub-directories without a config.json, or whose config.json has no
+        "model_type" entry, are skipped. A config.json that cannot be read or
+        parsed is logged as a warning and skipped as well, so that one corrupted
+        model directory cannot abort the initialization of the model storage.
+
+        Returns:
+            {"model_id": ModelInfo}
+        """
+        result = {}
+        built_in_dirs = [
+            d
+            for d in os.listdir(self._builtin_model_dir)
+            if os.path.isdir(os.path.join(self._builtin_model_dir, d))
+        ]
+        for model_id in built_in_dirs:
+            config_file_path = os.path.join(
+                self._builtin_model_dir, model_id, MODEL_CONFIG_FILE_IN_JSON
+            )
+            if not os.path.isfile(config_file_path):
+                continue
+            try:
+                with open(config_file_path, "r") as f:
+                    model_config = json.load(f)
+            except (OSError, ValueError) as e:
+                # json.JSONDecodeError is a subclass of ValueError
+                logger.warning(
+                    f"Skip model {model_id}: failed to read {config_file_path} due to {e}"
+                )
+                continue
+            # A valid JSON document that is not an object (e.g. a list or a
+            # string) cannot carry a "model_type" entry.
+            if not isinstance(model_config, dict) or "model_type" not in model_config:
+                continue
+            model_type = model_config["model_type"]
+            model_info = ModelInfo(
+                model_id=model_id,
+                model_type=model_type,
+                category=ModelCategory.FINE_TUNED,
+                state=ModelStates.ACTIVE,
+            )
+            # The directories of the built-in LTSMs are reported as BUILT-IN
+            # rather than FINE-TUNED.
+            if model_id in ("timer_xl", "sundial"):
+                model_info.category = ModelCategory.BUILT_IN
+            # The HuggingFace configs of these models use "timer" / "sundial"
+            # as model_type; map them onto the BuiltInModelType values.
+            if "timer" == model_type:
+                model_info.model_type = BuiltInModelType.TIMER_XL.value
+            if "sundial" == model_type:
+                model_info.model_type = BuiltInModelType.SUNDIAL.value
+            result[model_id] = model_info
+        return result
+
+ def _download_built_in_model_if_necessary(self, model_id: str) -> bool:
+ """
+ Download the built-in model if it is not already downloaded.
+
+ Args:
+ model_id (str): The ID of the model to download.
+
+ Return:
+ bool: True if the model is existed or downloaded successfully,
False otherwise.
+ """
+ with self._lock_pool.get_lock(model_id).write_lock():
+ local_dir = os.path.join(self._builtin_model_dir, model_id)
+ return download_ltsm_if_necessary(
+
get_built_in_model_type(self._model_info_map[model_id].model_type),
+ local_dir,
+ )
+
+    def _callback_model_download_result(self, future, model_id: str):
+        """
+        Update the state of a built-in model once its download task finishes.
+
+        Args:
+            future: the finished future returned by ``self._executor.submit``
+                for ``_download_built_in_model_if_necessary``
+            model_id (str): the ID of the model whose weights were downloaded
+        """
+        # ``future.result()`` re-raises any exception thrown by the download task.
+        # Exceptions escaping a done-callback are only logged and ignored by
+        # concurrent.futures, which would leave the model stuck in LOADING, so an
+        # exception is treated as a failed download instead.
+        try:
+            downloaded = future.result()
+        except Exception as e:
+            logger.error(f"Failed to download the built-in model: {model_id} due to {e}")
+            downloaded = False
+        with self._lock_pool.get_lock(model_id).write_lock():
+            if downloaded:
+                self._model_info_map[model_id].state = ModelStates.ACTIVE
+                logger.info(
+                    f"The built-in model: {model_id} is active and ready to use."
+                )
+            else:
+                self._model_info_map[model_id].state = ModelStates.INACTIVE
+                logger.warning(
+                    f"The built-in model: {model_id} is inactive since its weights could not be downloaded."
+                )
def register_model(self, model_id: str, uri: str):
"""
@@ -94,17 +202,28 @@ class ModelStorage(object):
storage_path = os.path.join(self._model_dir, f"{model_id}")
with self._lock_pool.get_lock(model_id).write_lock():
if os.path.exists(storage_path):
- for file_name in os.listdir(storage_path):
- self._remove_from_cache(os.path.join(storage_path,
file_name))
+ shutil.rmtree(storage_path)
+ storage_path = os.path.join(self._builtin_model_dir, f"{model_id}")
+ with self._lock_pool.get_lock(model_id).write_lock():
+ if os.path.exists(storage_path):
shutil.rmtree(storage_path)
- def _remove_from_cache(self, file_path: str) -> None:
- if file_path in self._model_cache:
- del self._model_cache[file_path]
+ def _is_built_in(self, model_id: str) -> bool:
+ """
+ Check if the model_id corresponds to a built-in model.
- def load_model(
- self, model_id: str, is_built_in: bool, acceleration: bool
- ) -> Callable:
+ Args:
+ model_id (str): The ID of the model.
+
+ Returns:
+ bool: True if the model is built-in, False otherwise.
+ """
+ return (
+ model_id in self._model_info_map
+ and self._model_info_map[model_id].category ==
ModelCategory.BUILT_IN
+ )
+
+ def load_model(self, model_id: str, acceleration: bool) -> Callable:
"""
Load a model with automatic detection of .safetensors or .pt format
@@ -112,17 +231,18 @@ class ModelStorage(object):
model: The model instance corresponding to specific model_id
"""
with self._lock_pool.get_lock(model_id).read_lock():
- if is_built_in:
- # For built-in models, we support auto download
+ if self._is_built_in(model_id):
model_dir = os.path.join(self._builtin_model_dir,
f"{model_id}")
- download_built_in_model_if_necessary(model_id, model_dir)
- return fetch_built_in_model(model_id, model_dir)
+ return fetch_built_in_model(
+
get_built_in_model_type(self._model_info_map[model_id].model_type),
+ model_dir,
+ )
else:
# TODO: support load the user-defined model
# model_dir = os.path.join(self._model_dir, f"{model_id}")
raise NotImplementedError
- def save_model(self, model_id: str, is_built_in: bool, model: nn.Module):
+ def save_model(self, model_id: str, model: nn.Module):
"""
Save the model using save_pretrained
@@ -130,7 +250,7 @@ class ModelStorage(object):
Whether saving succeeded
"""
with self._lock_pool.get_lock(model_id).write_lock():
- if is_built_in:
+ if self._is_built_in(model_id):
model_dir = os.path.join(self._builtin_model_dir,
f"{model_id}")
model.save_pretrained(model_dir)
else:
@@ -150,3 +270,24 @@ class ModelStorage(object):
"""
# Only support built-in models for now
return os.path.join(self._builtin_model_dir, f"{model_id}")
+
+ def show_models(self) -> TShowModelsResp:
+ return TShowModelsResp(
+ status=TSStatus(
+ code=TSStatusCode.SUCCESS_STATUS.value,
+ message="Show models successfully",
+ ),
+ modelIdList=list(self._model_info_map.keys()),
+ modelTypeMap=dict(
+ (model_id, model_info.model_type)
+ for model_id, model_info in self._model_info_map.items()
+ ),
+ categoryMap=dict(
+ (model_id, model_info.category.value)
+ for model_id, model_info in self._model_info_map.items()
+ ),
+ stateMap=dict(
+ (model_id, model_info.state.value)
+ for model_id, model_info in self._model_info_map.items()
+ ),
+ )
diff --git
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java
index 9a23d9ed713..7490a53a01c 100644
---
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java
+++
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java
@@ -21,18 +21,21 @@ package
org.apache.iotdb.confignode.consensus.response.model;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
import org.apache.iotdb.commons.model.ModelInformation;
-import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp;
import org.apache.iotdb.consensus.common.DataSet;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
+import java.util.Map;
+// TODO: Will be removed in the future
public class ModelTableResp implements DataSet {
private final TSStatus status;
private final List<ByteBuffer> serializedAllModelInformation;
+ private Map<String, String> modelTypeMap;
+ private Map<String, String> algorithmMap;
public ModelTableResp(TSStatus status) {
this.status = status;
@@ -49,7 +52,11 @@ public class ModelTableResp implements DataSet {
this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult());
}
- public TShowModelResp convertToThriftResponse() throws IOException {
- return new TShowModelResp(status, serializedAllModelInformation);
+ public void setModelTypeMap(Map<String, String> modelTypeMap) {
+ this.modelTypeMap = modelTypeMap;
+ }
+
+ public void setAlgorithmMap(Map<String, String> algorithmMap) {
+ this.algorithmMap = algorithmMap;
}
}
diff --git
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
index a6462a72a2a..82fd508c5c1 100644
---
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
+++
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
@@ -2720,7 +2720,7 @@ public class ConfigManager implements IManager {
TSStatus status = confirmLeader();
return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
? modelManager.showModel(req)
- : new TShowModelResp(status, Collections.emptyList());
+ : new TShowModelResp(status);
}
@Override
diff --git
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
index 2ebc4d71dfc..e8ee032d0e5 100644
---
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
+++
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
@@ -19,17 +19,20 @@
package org.apache.iotdb.confignode.manager;
+import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
+import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.common.rpc.thrift.TSStatus;
+import org.apache.iotdb.commons.client.ainode.AINodeClient;
+import org.apache.iotdb.commons.client.ainode.AINodeClientManager;
import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.commons.model.ModelStatus;
import org.apache.iotdb.commons.model.ModelType;
import
org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan;
-import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan;
import
org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
import
org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
import org.apache.iotdb.confignode.consensus.response.model.GetModelInfoResp;
-import org.apache.iotdb.confignode.consensus.response.model.ModelTableResp;
import org.apache.iotdb.confignode.persistence.ModelInfo;
+import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo;
import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
@@ -37,15 +40,12 @@ import
org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TShowModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp;
import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq;
-import org.apache.iotdb.consensus.common.DataSet;
import org.apache.iotdb.consensus.exception.ConsensusException;
import org.apache.iotdb.rpc.TSStatusCode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.IOException;
-import java.util.Collections;
import java.util.List;
public class ModelManager {
@@ -92,22 +92,33 @@ public class ModelManager {
}
public TShowModelResp showModel(final TShowModelReq req) {
- try {
- final DataSet response = configManager.getConsensusManager().read(new
ShowModelPlan(req));
- return ((ModelTableResp) response).convertToThriftResponse();
- } catch (final ConsensusException e) {
- LOGGER.warn(
- String.format("Unexpected error happened while showing model %s: ",
req.getModelId()), e);
- // consensus layer related errors
- final TSStatus res = new
TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode());
- res.setMessage(e.getMessage());
- return new TShowModelResp(res, Collections.emptyList());
- } catch (final IOException e) {
- LOGGER.warn("Fail to get ModelTable", e);
- return new TShowModelResp(
- new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
- .setMessage(e.getMessage()),
- Collections.emptyList());
+ List<TAINodeInfo> registeredAINodes =
+ configManager.getNodeManager().getRegisteredAINodeInfoList();
+ if (registeredAINodes.isEmpty()) {
+ return new TShowModelResp()
+ .setStatus(
+ new TSStatus(TSStatusCode.NO_AVAILABLE_AINODE.getStatusCode())
+ .setMessage("Show models failed due to there is no AINode
available"));
+ }
+ TAINodeInfo registeredAINode = registeredAINodes.get(0);
+ TEndPoint targetAINodeEndPoint =
+ new TEndPoint(registeredAINode.getInternalAddress(),
registeredAINode.getInternalPort());
+ try (AINodeClient client =
+ AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) {
+ TShowModelsResp resp = client.showModels();
+ TShowModelResp res =
+ new TShowModelResp().setStatus(new
TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
+ res.setModelIdList(resp.getModelIdList());
+ res.setModelTypeMap(resp.getModelTypeMap());
+ res.setCategoryMap(resp.getCategoryMap());
+ res.setStateMap(resp.getStateMap());
+ return res;
+ } catch (Exception e) {
+ LOGGER.warn("Failed to show models due to", e);
+ return new TShowModelResp()
+ .setStatus(
+ new TSStatus(TSStatusCode.CAN_NOT_CONNECT_AINODE.getStatusCode())
+ .setMessage(e.getMessage()));
}
}
diff --git
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
index fb9076626ab..e96f6fb7bf6 100644
---
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
@@ -75,16 +75,16 @@ public class ModelInfo implements SnapshotProcessor {
private static final Set<String> builtInAnomalyDetectionModel = new
HashSet<>();
static {
- builtInForecastModel.add("_ARIMA");
- builtInForecastModel.add("_NaiveForecaster");
- builtInForecastModel.add("_STLForecaster");
- builtInForecastModel.add("_HoltWinters");
- builtInForecastModel.add("_ExponentialSmoothing");
- builtInForecastModel.add("_timerxl");
- builtInForecastModel.add("_sundial");
- builtInAnomalyDetectionModel.add("_GaussianHMM");
- builtInAnomalyDetectionModel.add("_GMMHMM");
- builtInAnomalyDetectionModel.add("_Stray");
+ builtInForecastModel.add("arima");
+ builtInForecastModel.add("naive_forecaster");
+ builtInForecastModel.add("stl_forecaster");
+ builtInForecastModel.add("holtwinters");
+ builtInForecastModel.add("exponential_smoothing");
+ builtInForecastModel.add("timer_xl");
+ builtInForecastModel.add("sundial");
+ builtInAnomalyDetectionModel.add("gaussian_hmm");
+ builtInAnomalyDetectionModel.add("gmm_hmm");
+ builtInAnomalyDetectionModel.add("stray");
}
public ModelInfo() {
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java
index efc8d405459..85854bdb2dd 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java
@@ -26,7 +26,6 @@ import org.apache.iotdb.common.rpc.thrift.TConsensusGroupType;
import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation;
import org.apache.iotdb.commons.conf.IoTDBConstant;
import org.apache.iotdb.commons.exception.auth.AccessDeniedException;
-import org.apache.iotdb.commons.model.ModelType;
import org.apache.iotdb.commons.pipe.agent.plugin.builtin.BuiltinPipePlugin;
import org.apache.iotdb.commons.pipe.agent.plugin.meta.PipePluginMeta;
import org.apache.iotdb.commons.schema.column.ColumnHeaderConstant;
@@ -51,6 +50,7 @@ import
org.apache.iotdb.confignode.rpc.thrift.TNodeVersionInfo;
import org.apache.iotdb.confignode.rpc.thrift.TRegionInfo;
import org.apache.iotdb.confignode.rpc.thrift.TShowClusterResp;
import org.apache.iotdb.confignode.rpc.thrift.TShowModelReq;
+import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp;
import org.apache.iotdb.confignode.rpc.thrift.TShowPipeInfo;
import org.apache.iotdb.confignode.rpc.thrift.TShowPipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TShowRegionReq;
@@ -83,9 +83,7 @@ import
org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.BytesUtils;
import org.apache.tsfile.utils.Pair;
-import org.apache.tsfile.utils.ReadWriteIOUtils;
-import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
@@ -116,10 +114,6 @@ import static
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.Sho
import static
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowFunctionsTask.getFunctionType;
import static
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowPipePluginsTask.PIPE_PLUGIN_TYPE_BUILTIN;
import static
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowPipePluginsTask.PIPE_PLUGIN_TYPE_EXTERNAL;
-import static
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.ShowModelsTask.INPUT_DATA_TYPE;
-import static
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.ShowModelsTask.INPUT_SHAPE;
-import static
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.ShowModelsTask.OUTPUT_DATA_TYPE;
-import static
org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.ShowModelsTask.OUTPUT_SHAPE;
public class InformationSchemaContentSupplierFactory {
private InformationSchemaContentSupplierFactory() {}
@@ -778,49 +772,104 @@ public class InformationSchemaContentSupplierFactory {
}
private static class ModelsSupplier extends TsBlockSupplier {
- private Iterator<ByteBuffer> iterator;
+ private ModelIterator iterator;
private ModelsSupplier(final List<TSDataType> dataTypes) {
super(dataTypes);
try (final ConfigNodeClient client =
ConfigNodeClientManager.getInstance().borrowClient(ConfigNodeInfo.CONFIG_REGION_ID))
{
- iterator = client.showModel(new
TShowModelReq()).getModelInfoList().iterator();
+ iterator = new ModelIterator(client.showModel(new TShowModelReq()));
} catch (final Exception e) {
lastException = e;
}
}
- @Override
- protected void constructLine() {
- final ByteBuffer modelInfo = iterator.next();
- columnBuilders[0].writeBinary(
- new Binary(ReadWriteIOUtils.readString(modelInfo),
TSFileConfig.STRING_CHARSET));
+ private static class ModelIterator implements Iterator<ModelInfoInString> {
- final String modelType = ReadWriteIOUtils.readString(modelInfo);
- columnBuilders[1].writeBinary(new Binary(modelType,
TSFileConfig.STRING_CHARSET));
- columnBuilders[2].writeBinary(
- new Binary(ReadWriteIOUtils.readString(modelInfo),
TSFileConfig.STRING_CHARSET));
+ private int index = 0;
+ private final TShowModelResp resp;
- if (Objects.equals(modelType, ModelType.USER_DEFINED.toString())) {
- columnBuilders[3].writeBinary(
- new Binary(
- INPUT_SHAPE
- + ReadWriteIOUtils.readString(modelInfo)
- + OUTPUT_SHAPE
- + ReadWriteIOUtils.readString(modelInfo)
- + INPUT_DATA_TYPE
- + ReadWriteIOUtils.readString(modelInfo)
- + OUTPUT_DATA_TYPE
- + ReadWriteIOUtils.readString(modelInfo),
- TSFileConfig.STRING_CHARSET));
- columnBuilders[4].writeBinary(
- new Binary(ReadWriteIOUtils.readString(modelInfo),
TSFileConfig.STRING_CHARSET));
- } else {
- columnBuilders[3].appendNull();
- columnBuilders[4].writeBinary(
- new Binary("Built-in model in IoTDB",
TSFileConfig.STRING_CHARSET));
+ private ModelIterator(TShowModelResp resp) {
+ this.resp = resp;
+ }
+
+ @Override
+ public boolean hasNext() {
+ return index < resp.getModelIdListSize();
+ }
+
+ @Override
+ public ModelInfoInString next() {
+ String modelId = resp.getModelIdList().get(index++);
+ return new ModelInfoInString(
+ modelId,
+ resp.getModelTypeMap().get(modelId),
+ resp.getCategoryMap().get(modelId),
+ resp.getStateMap().get(modelId));
+ }
+ }
+
+ private static class ModelInfoInString {
+
+ private final String modelId;
+ private final String modelType;
+ private final String category;
+ private final String state;
+
+ public ModelInfoInString(String modelId, String modelType, String
category, String state) {
+ this.modelId = modelId;
+ this.modelType = modelType;
+ this.category = category;
+ this.state = state;
+ }
+
+ public String getModelId() {
+ return modelId;
+ }
+
+ public String getModelType() {
+ return modelType;
+ }
+
+ public String getCategory() {
+ return category;
}
+ public String getState() {
+ return state;
+ }
+ }
+
+ @Override
+ protected void constructLine() {
+ final ModelInfoInString modelInfo = iterator.next();
+ columnBuilders[0].writeBinary(
+ new Binary(modelInfo.getModelId(), TSFileConfig.STRING_CHARSET));
+ columnBuilders[1].writeBinary(
+ new Binary(modelInfo.getModelType(), TSFileConfig.STRING_CHARSET));
+ columnBuilders[2].writeBinary(
+ new Binary(modelInfo.getCategory(), TSFileConfig.STRING_CHARSET));
+ columnBuilders[3].writeBinary(new Binary(modelInfo.getState(),
TSFileConfig.STRING_CHARSET));
+ // if (Objects.equals(modelType,
ModelType.USER_DEFINED.toString())) {
+ // columnBuilders[3].writeBinary(
+ // new Binary(
+ // INPUT_SHAPE
+ // + ReadWriteIOUtils.readString(modelInfo)
+ // + OUTPUT_SHAPE
+ // + ReadWriteIOUtils.readString(modelInfo)
+ // + INPUT_DATA_TYPE
+ // + ReadWriteIOUtils.readString(modelInfo)
+ // + OUTPUT_DATA_TYPE
+ // + ReadWriteIOUtils.readString(modelInfo),
+ // TSFileConfig.STRING_CHARSET));
+ // columnBuilders[4].writeBinary(
+ // new Binary(ReadWriteIOUtils.readString(modelInfo),
+ // TSFileConfig.STRING_CHARSET));
+ // } else {
+ // columnBuilders[3].appendNull();
+ // columnBuilders[4].writeBinary(
+ // new Binary("Built-in model in IoTDB",
TSFileConfig.STRING_CHARSET));
+ // }
resultBuilder.declarePosition();
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
index f45c3ce187b..6d57f1ac12b 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
@@ -3321,7 +3321,7 @@ public class ClusterConfigTaskExecutor implements
IConfigTaskExecutor {
return future;
}
// convert model info list and buildTsBlock
- ShowModelsTask.buildTsBlock(showModelResp.getModelInfoList(), future);
+ ShowModelsTask.buildTsBlock(showModelResp, future);
} catch (final ClientManagerException | TException e) {
future.setException(e);
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowModelsTask.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowModelsTask.java
index 78ded3da6e9..73c5b1b6f7d 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowModelsTask.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/ShowModelsTask.java
@@ -19,9 +19,9 @@
package org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai;
-import org.apache.iotdb.commons.model.ModelType;
import org.apache.iotdb.commons.schema.column.ColumnHeader;
import org.apache.iotdb.commons.schema.column.ColumnHeaderConstant;
+import org.apache.iotdb.confignode.rpc.thrift.TShowModelResp;
import org.apache.iotdb.db.queryengine.common.header.DatasetHeader;
import org.apache.iotdb.db.queryengine.common.header.DatasetHeaderFactory;
import org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult;
@@ -34,11 +34,9 @@ import com.google.common.util.concurrent.SettableFuture;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.read.common.block.TsBlockBuilder;
import org.apache.tsfile.utils.BytesUtils;
-import org.apache.tsfile.utils.ReadWriteIOUtils;
-import java.nio.ByteBuffer;
import java.util.List;
-import java.util.Objects;
+import java.util.Map;
import java.util.stream.Collectors;
public class ShowModelsTask implements IConfigTask {
@@ -61,49 +59,43 @@ public class ShowModelsTask implements IConfigTask {
return configTaskExecutor.showModels(modelName);
}
- public static void buildTsBlock(
- List<ByteBuffer> modelInfoList, SettableFuture<ConfigTaskResult> future)
{
+ public static void buildTsBlock(TShowModelResp resp,
SettableFuture<ConfigTaskResult> future) {
+ List<String> modelIdList = resp.getModelIdList();
+ Map<String, String> modelTypeMap = resp.getModelTypeMap();
+ Map<String, String> categoryMap = resp.getCategoryMap();
+ Map<String, String> stateMap = resp.getStateMap();
List<TSDataType> outputDataTypes =
ColumnHeaderConstant.showModelsColumnHeaders.stream()
.map(ColumnHeader::getColumnType)
.collect(Collectors.toList());
TsBlockBuilder builder = new TsBlockBuilder(outputDataTypes);
- for (ByteBuffer modelInfo : modelInfoList) {
- String modelId = ReadWriteIOUtils.readString(modelInfo);
- String modelType = ReadWriteIOUtils.readString(modelInfo);
- String state = ReadWriteIOUtils.readString(modelInfo);
- String note;
- String config;
- if (Objects.equals(modelType, ModelType.USER_DEFINED.toString())) {
- String inputShape = ReadWriteIOUtils.readString(modelInfo);
- String outputShape = ReadWriteIOUtils.readString(modelInfo);
- String inputTypes = ReadWriteIOUtils.readString(modelInfo);
- String outputTypes = ReadWriteIOUtils.readString(modelInfo);
- note = ReadWriteIOUtils.readString(modelInfo);
- config =
- INPUT_SHAPE
- + inputShape
- + OUTPUT_SHAPE
- + outputShape
- + INPUT_DATA_TYPE
- + inputTypes
- + OUTPUT_DATA_TYPE
- + outputTypes;
- } else {
- config = EMPTY_STRING;
- note = "Built-in model in IoTDB";
- }
-
+ for (String modelId : modelIdList) {
+ // String note;
+ // String config;
+ // if (Objects.equals(modelType,
ModelType.USER_DEFINED.toString())) {
+ // String inputShape = ReadWriteIOUtils.readString(modelInfo);
+ // String outputShape = ReadWriteIOUtils.readString(modelInfo);
+ // String inputTypes = ReadWriteIOUtils.readString(modelInfo);
+ // String outputTypes = ReadWriteIOUtils.readString(modelInfo);
+ // note = ReadWriteIOUtils.readString(modelInfo);
+ // config =
+ // INPUT_SHAPE
+ // + inputShape
+ // + OUTPUT_SHAPE
+ // + outputShape
+ // + INPUT_DATA_TYPE
+ // + inputTypes
+ // + OUTPUT_DATA_TYPE
+ // + outputTypes;
+ // } else {
+ // config = EMPTY_STRING;
+ // note = "Built-in model in IoTDB";
+ // }
builder.getTimeColumnBuilder().writeLong(0L);
builder.getColumnBuilder(0).writeBinary(BytesUtils.valueOf(modelId));
- builder.getColumnBuilder(1).writeBinary(BytesUtils.valueOf(modelType));
- builder.getColumnBuilder(2).writeBinary(BytesUtils.valueOf(state));
- builder.getColumnBuilder(3).writeBinary(BytesUtils.valueOf(config));
- if (note != null) {
- builder.getColumnBuilder(4).writeBinary(BytesUtils.valueOf(note));
- } else {
- builder.getColumnBuilder(4).writeBinary(BytesUtils.valueOf(""));
- }
+
builder.getColumnBuilder(1).writeBinary(BytesUtils.valueOf(modelTypeMap.get(modelId)));
+
builder.getColumnBuilder(2).writeBinary(BytesUtils.valueOf(categoryMap.get(modelId)));
+
builder.getColumnBuilder(3).writeBinary(BytesUtils.valueOf(stateMap.get(modelId)));
builder.declarePosition();
}
DatasetHeader datasetHeader = DatasetHeaderFactory.getShowModelsHeader();
diff --git
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
index 346a459136a..f573ea03b6d 100644
---
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
+++
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ainode/AINodeClient.java
@@ -28,6 +28,7 @@ import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq;
import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp;
import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq;
import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp;
+import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq;
import org.apache.iotdb.ainode.rpc.thrift.TWindowParams;
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
@@ -158,6 +159,18 @@ public class AINodeClient implements AutoCloseable,
ThriftClient {
}
}
+ public TShowModelsResp showModels() throws TException {
+ try {
+ return client.showModels();
+ } catch (TException e) {
+ logger.warn(
+ "Failed to connect to AINode from ConfigNode when executing {}: {}",
+ Thread.currentThread().getStackTrace()[1].getMethodName(),
+ e.getMessage());
+ throw new TException(MSG_CONNECTION_FAIL);
+ }
+ }
+
public TInferenceResp inference(
String modelId,
TsBlock inputTsBlock,
diff --git
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java
index 72fe1ce0236..a12fd9ceb00 100644
---
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java
+++
b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java
@@ -612,9 +612,8 @@ public class ColumnHeaderConstant {
ImmutableList.of(
new ColumnHeader(MODEL_ID, TSDataType.TEXT),
new ColumnHeader(MODEL_TYPE, TSDataType.TEXT),
- new ColumnHeader(STATE, TSDataType.TEXT),
- new ColumnHeader(CONFIGS, TSDataType.TEXT),
- new ColumnHeader(NOTES, TSDataType.TEXT));
+ new ColumnHeader(COLUMN_CATEGORY, TSDataType.TEXT),
+ new ColumnHeader(STATE, TSDataType.TEXT));
public static final List<ColumnHeader> showLogicalViewColumnHeaders =
ImmutableList.of(
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index db1a15c2460..52e665a532e 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -98,9 +98,18 @@ struct TForecastResp {
2: required binary forecastResult
}
+struct TShowModelsResp {
+ 1: required common.TSStatus status
+ 2: optional list<string> modelIdList
+ 3: optional map<string, string> modelTypeMap
+ 4: optional map<string, string> categoryMap
+ 5: optional map<string, string> stateMap
+}
+
service IAINodeRPCService {
// -------------- For Config Node --------------
+ TShowModelsResp showModels()
common.TSStatus deleteModel(TDeleteModelReq req)
diff --git a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
index e5d599ab632..59effaccae8 100644
--- a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
+++ b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
@@ -1063,7 +1063,10 @@ struct TShowModelReq {
struct TShowModelResp {
1: required common.TSStatus status
- 2: required list<binary> modelInfoList
+ 2: optional list<string> modelIdList
+ 3: optional map<string, string> modelTypeMap
+ 4: optional map<string, string> categoryMap
+ 5: optional map<string, string> stateMap
}
struct TGetModelInfoReq {