This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new f1870cb9b06 [AINode] Limit max inference length (#15982)
f1870cb9b06 is described below
commit f1870cb9b06c2fe9cedaa2a3586fd706e256791d
Author: Yongzao <[email protected]>
AuthorDate: Wed Jul 23 12:05:42 2025 +0800
[AINode] Limit max inference length (#15982)
---
iotdb-core/ainode/ainode/core/ainode.py | 2 +-
iotdb-core/ainode/ainode/core/config.py | 12 ++++++++++++
iotdb-core/ainode/ainode/core/constant.py | 1 +
.../ainode/core/inference/inference_request_pool.py | 5 +++--
iotdb-core/ainode/ainode/core/inference/utils.py | 2 +-
.../ainode/ainode/core/manager/inference_manager.py | 21 ++++++++++++++++++---
iotdb-core/ainode/ainode/core/rpc/handler.py | 6 +++---
7 files changed, 39 insertions(+), 10 deletions(-)
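In short: a new constant AINODE_INFERENCE_MAX_PREDICT_LENGTH (2880 by default) is added in constant.py, exposed through AINodeConfig with a getter/setter, and enforced in InferenceManager, which now rejects requests whose predict_length exceeds the configured maximum before any inference work is scheduled. A minimal sketch of the new check, condensed from the inference_manager.py hunk below (all names come from this diff):

    predict_length = inference_attrs.get("predict_length", 96)
    max_length = AINodeDescriptor().get_config().get_ain_inference_max_predict_length()
    if predict_length > max_length:
        # reject the request up front instead of scheduling it
        raise NumericalRangeException("output_length", 1, max_length, predict_length)
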
diff --git a/iotdb-core/ainode/ainode/core/ainode.py b/iotdb-core/ainode/ainode/core/ainode.py
index d18ac21464e..82c6f6988e2 100644
--- a/iotdb-core/ainode/ainode/core/ainode.py
+++ b/iotdb-core/ainode/ainode/core/ainode.py
@@ -134,7 +134,7 @@ class AINode:
raise e
# Start the RPC service
- self._rpc_handler = AINodeRPCServiceHandler(aiNode=self)
+ self._rpc_handler = AINodeRPCServiceHandler(ainode=self)
self._rpc_service = AINodeRPCService(self._rpc_handler)
self._rpc_service.start()
self._rpc_service.join(1)
diff --git a/iotdb-core/ainode/ainode/core/config.py b/iotdb-core/ainode/ainode/core/config.py
index d347379570b..5126d0e53e3 100644
--- a/iotdb-core/ainode/ainode/core/config.py
+++ b/iotdb-core/ainode/ainode/core/config.py
@@ -31,6 +31,7 @@ from ainode.core.constant import (
AINODE_CONF_GIT_FILE_NAME,
AINODE_CONF_POM_FILE_NAME,
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS,
+ AINODE_INFERENCE_MAX_PREDICT_LENGTH,
AINODE_LOG_DIR,
AINODE_MODELS_DIR,
AINODE_ROOT_CONF_DIRECTORY_NAME,
@@ -72,6 +73,9 @@ class AINodeConfig(object):
self._ain_inference_batch_interval_in_ms: int = (
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS
)
+ self._ain_inference_max_predict_length: int = (
+ AINODE_INFERENCE_MAX_PREDICT_LENGTH
+ )
# log directory
self._ain_logs_dir: str = AINODE_LOG_DIR
@@ -140,6 +144,14 @@ class AINodeConfig(object):
) -> None:
self._ain_inference_batch_interval_in_ms = ain_inference_batch_interval_in_ms
+ def get_ain_inference_max_predict_length(self) -> int:
+ return self._ain_inference_max_predict_length
+
+ def set_ain_inference_max_predict_length(
+ self, ain_inference_max_predict_length: int
+ ) -> None:
+ self._ain_inference_max_predict_length = ain_inference_max_predict_length
+
def get_ain_logs_dir(self) -> str:
return self._ain_logs_dir
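
Callers read (and can override) the new limit through the same descriptor pattern already used for the batch interval. A hedged sketch, assuming the AINodeDescriptor singleton accessor shown in inference_manager.py further down; the 4096 override is purely hypothetical:

    config = AINodeDescriptor().get_config()
    limit = config.get_ain_inference_max_predict_length()  # 2880 by default
    config.set_ain_inference_max_predict_length(4096)      # hypothetical override
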
diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py
index bd414ce1253..bd5646b3513 100644
--- a/iotdb-core/ainode/ainode/core/constant.py
+++ b/iotdb-core/ainode/ainode/core/constant.py
@@ -50,6 +50,7 @@ DEFAULT_RECONNECT_TIMES = 3
# AINode inference configuration
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15
+AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880
# AINode folder structure
AINODE_ROOT_DIR = os.path.dirname(
diff --git a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
index ce40bd16859..d0cac2760ab 100644
--- a/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/ainode/core/inference/inference_request_pool.py
@@ -23,7 +23,7 @@ import time
import numpy as np
import torch
import torch.multiprocessing as mp
-from transformers import PretrainedConfig, PreTrainedModel
+from transformers import PretrainedConfig
from ainode.core.config import AINodeDescriptor
from ainode.core.inference.inference_request import InferenceRequest
@@ -46,7 +46,7 @@ class InferenceRequestPool(mp.Process):
def __init__(
self,
pool_id: int,
- model_id: int,
+ model_id: str,
config: PretrainedConfig,
request_queue: mp.Queue,
result_queue: mp.Queue,
@@ -58,6 +58,7 @@ class InferenceRequestPool(mp.Process):
self.config = config
self.pool_kwargs = pool_kwargs
self.model = None
+ self._model_manager = None
self.device = None
# TODO: A scheduler is necessary for better handling following queues
diff --git a/iotdb-core/ainode/ainode/core/inference/utils.py b/iotdb-core/ainode/ainode/core/inference/utils.py
index c2a618d716c..cf10b5b2cd4 100644
--- a/iotdb-core/ainode/ainode/core/inference/utils.py
+++ b/iotdb-core/ainode/ainode/core/inference/utils.py
@@ -22,7 +22,7 @@ import torch
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
-def _generate_req_id(length=10, charset=string.ascii_letters + string.digits) -> str:
+def generate_req_id(length=10, charset=string.ascii_letters + string.digits) -> str:
"""
Generate a random req_id string of specified length.
The length is 10 by default, with 10^{17} possible combinations.
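
The helper is now public (leading underscore dropped) so that inference_manager.py can import it without touching a private name. A plausible body consistent with the signature and docstring above, offered only as a sketch since the implementation itself is not part of this diff:

    import random
    import string

    def generate_req_id(length=10, charset=string.ascii_letters + string.digits) -> str:
        # 62^10 is roughly 8 * 10^17 distinct identifiers
        return "".join(random.choices(charset, k=length))
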
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index bed39c48d08..5a853ac4e72 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -30,6 +30,7 @@ from ainode.core.constant import TSStatusCode
from ainode.core.exception import (
InferenceModelInternalError,
InvalidWindowArgumentError,
+ NumericalRangeException,
runtime_error_extractor,
)
from ainode.core.inference.inference_request import (
@@ -40,7 +41,7 @@ from ainode.core.inference.inference_request_pool import InferenceRequestPool
from ainode.core.inference.strategy.timer_sundial_inference_pipeline import (
TimerSundialInferencePipeline,
)
-from ainode.core.inference.utils import _generate_req_id
+from ainode.core.inference.utils import generate_req_id
from ainode.core.log import Logger
from ainode.core.manager.model_manager import ModelManager
from ainode.core.model.sundial.configuration_sundial import SundialConfig
@@ -214,6 +215,20 @@ class InferenceManager:
full_data = deserializer(raw)
inference_attrs = extract_attrs(req)
+ predict_length = inference_attrs.get("predict_length", 96)
+ if (
+ predict_length
+ > AINodeDescriptor().get_config().get_ain_inference_max_predict_length()
+ ):
+ raise NumericalRangeException(
+ "output_length",
+ 1,
+ AINodeDescriptor()
+ .get_config()
+ .get_ain_inference_max_predict_length(),
+ predict_length,
+ )
+
if model_id == self.ACCELERATE_MODEL_ID and self.DEFAULT_POOL_SIZE > 0:
# TODO: Logic in this branch shall handle all LTSM inferences
# TODO: TSBlock -> Tensor codes should be unified
@@ -223,10 +238,10 @@ class InferenceManager:
# the inputs should be on CPU before passing to the inference request
inputs = torch.tensor(data).unsqueeze(0).float().to("cpu")
infer_req = InferenceRequest(
- req_id=_generate_req_id(),
+ req_id=generate_req_id(),
inputs=inputs,
inference_pipeline=TimerSundialInferencePipeline(SundialConfig()),
- max_new_tokens=inference_attrs.get("predict_length", 96),
+ max_new_tokens=predict_length,
)
infer_proxy = InferenceRequestProxy(infer_req.req_id)
with self._result_wrapper_lock:
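
For concreteness, with the defaults introduced here (cap 2880, fallback predict_length 96), the check behaves as sketched below; the 5000-step request is a made-up example, not something from this commit:

    max_length = 2880
    for predict_length in (96, 2880, 5000):
        # 96 and 2880 pass (the cap is inclusive, since the check uses a strict >);
        # 5000 raises NumericalRangeException in the code above
        print(predict_length, predict_length <= max_length)
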
diff --git a/iotdb-core/ainode/ainode/core/rpc/handler.py b/iotdb-core/ainode/ainode/core/rpc/handler.py
index cb25420ae00..d3948020ab3 100644
--- a/iotdb-core/ainode/ainode/core/rpc/handler.py
+++ b/iotdb-core/ainode/ainode/core/rpc/handler.py
@@ -41,13 +41,13 @@ logger = Logger()
class AINodeRPCServiceHandler(IAINodeRPCService.Iface):
- def __init__(self, aiNode):
- self._aiNode = aiNode
+ def __init__(self, ainode):
+ self._ainode = ainode
self._model_manager = ModelManager()
self._inference_manager = InferenceManager()
def stopAINode(self) -> TSStatus:
- self._aiNode.stop()
+ self._ainode.stop()
return get_status(TSStatusCode.SUCCESS_STATUS, "AINode stopped successfully.")
def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp: