This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 333b385a08a [AINode] Fix bugs and sync code (#17032)
333b385a08a is described below
commit 333b385a08a0c341dd33cd19bff7f0c76db66228
Author: Leo <[email protected]>
AuthorDate: Sun Jan 18 17:42:51 2026 +0800
[AINode] Fix bugs and sync code (#17032)
---
iotdb-core/ainode/iotdb/ainode/core/constant.py | 2 +-
.../ainode/core/inference/batcher/basic_batcher.py | 27 ++++++++++++++++------
.../ainode/core/inference/inference_request.py | 15 ++++++------
.../core/inference/inference_request_pool.py | 2 +-
.../pool_scheduler/basic_pool_scheduler.py | 11 ++++++++-
.../iotdb/ainode/core/manager/inference_manager.py | 2 +-
.../ainode/iotdb/ainode/core/manager/utils.py | 4 +---
.../function/tvf/ClassifyTableFunction.java | 2 +-
.../function/tvf/ForecastTableFunction.java | 2 +-
.../db/queryengine/plan/udf/UDTFForecast.java | 3 +--
.../thrift-ainode/src/main/thrift/ainode.thrift | 2 +-
11 files changed, 45 insertions(+), 27 deletions(-)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py
index 8a83c981437..68f64a79afc 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/constant.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py
@@ -53,7 +53,7 @@ AINODE_INFERENCE_MAX_OUTPUT_LENGTH = 2880
# TODO: Should be optimized
AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = {
"sundial": 1036 * 1024**2, # 1036 MiB
- "timer": 856 * 1024**2, # 856 MiB
+ "timer_xl": 856 * 1024**2, # 856 MiB
} # the memory usage of each model in bytes
AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.2 # the device space allocated for inference
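[Note] The key rename matters because downstream code in this commit indexes the map by model_id rather than model_type (see basic_pool_scheduler.py and manager/utils.py below). A minimal sketch of the guarded lookup pattern, with a hypothetical helper name:

    AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = {
        "sundial": 1036 * 1024**2,  # 1036 MiB
        "timer_xl": 856 * 1024**2,  # 856 MiB
    }

    def lookup_mem_usage(model_id: str) -> int:
        # Fail early with a clear message instead of a KeyError deep in the scheduler.
        if model_id not in AINODE_INFERENCE_MODEL_MEM_USAGE_MAP:
            raise KeyError(f"No memory estimate registered for model: {model_id}")
        return AINODE_INFERENCE_MODEL_MEM_USAGE_MAP[model_id]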
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/batcher/basic_batcher.py b/iotdb-core/ainode/iotdb/ainode/core/inference/batcher/basic_batcher.py
index 591a0d7c1dd..80ff683d9d1 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/batcher/basic_batcher.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/batcher/basic_batcher.py
@@ -34,6 +34,7 @@ class BasicBatcher(AbstractBatcher):
Args:
"""
+ super().__init__()
def batch_request(self, reqs: List[InferenceRequest]) -> torch.Tensor:
"""
@@ -46,17 +47,29 @@ class BasicBatcher(AbstractBatcher):
Returns:
torch.Tensor: Concatenated input tensor of shape
- [sum(req.batch_size), length].
+ [sum(req.batch_size), target_count, input_length].
"""
if not reqs:
raise ValueError("No requests provided to batch_request.")
- # Ensure length consistency
- length_set = {req.inputs.shape[1] for req in reqs}
- if len(length_set) != 1:
- raise ValueError(
- f"All requests must have the same length, " f"but got
{length_set}"
- )
+ # Ensure shape consistency
+ first_target_count = reqs[0].target_count
+ first_input_length = reqs[0].input_length
+
+ for i, req in enumerate(reqs):
+ if req.target_count != first_target_count:
+ raise ValueError(
+ f"All requests must have the same target_count, "
+ f"but request 0 has {first_target_count} "
+ f"and request {i} has {req.target_count}"
+ )
+
+ if req.input_length != first_input_length:
+ raise ValueError(
+ f"All requests must have the same input_length, "
+ f"but request 0 has {first_input_length} "
+ f"and request {i} has {req.input_length}"
+ )
batch_inputs = torch.cat([req.inputs for req in reqs], dim=0)
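[Note] The batcher now validates both trailing dimensions of the 3-D inputs before concatenation. A self-contained sketch of the resulting behavior (illustrative tensors only, not from this commit):

    import torch

    # Two requests with matching (target_count=2, input_length=96) batch fine:
    a = torch.zeros(4, 2, 96)  # batch_size=4
    b = torch.zeros(3, 2, 96)  # batch_size=3
    batched = torch.cat([a, b], dim=0)
    assert batched.shape == (7, 2, 96)  # sum of batch sizes, shared tail dims

    # A request with input_length=48 would now be rejected by the explicit
    # check above (previously only dim 1 of a 2-D input was compared):
    c = torch.zeros(3, 2, 48)
    assert c.shape[2] != a.shape[2]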
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
index 93887477aa5..43380aa1a08 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
@@ -49,21 +49,20 @@ class InferenceRequest:
self.model_id = model_id
self.inputs = inputs
self.infer_kwargs = infer_kwargs
- self.output_length = (
- output_length # Number of time series data points to generate
- )
+ self.output_length = output_length
self.batch_size = inputs.size(0)
- self.variable_size = inputs.size(1)
+ self.target_count = inputs.size(1)
+ self.input_length = inputs.size(2)
self.state = InferenceRequestState.WAITING
self.cur_step_idx = 0 # Current write position in the output step index
self.assigned_pool_id = -1 # The pool handling this request
self.assigned_device_id = -1 # The device handling this request
- # Preallocate output buffer [batch_size, max_new_tokens]
+ # Preallocate output buffer [batch_size, target_count, output_length]
self.output_tensor = torch.zeros(
- self.batch_size, self.variable_size, output_length, device="cpu"
- ) # shape: [batch_size, target_count, predict_length]
+ self.batch_size, self.target_count, output_length, device="cpu"
+ )
def mark_running(self):
self.state = InferenceRequestState.RUNNING
@@ -81,7 +80,7 @@ class InferenceRequest:
while step_output.ndim < 3:
step_output = step_output.unsqueeze(0)
- batch_size, variable_size, step_size = step_output.shape
+ batch_size, target_count, step_size = step_output.shape
end_idx = self.cur_step_idx + step_size
if end_idx > self.output_length:
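[Note] Step outputs are written in place into the preallocated [batch_size, target_count, output_length] buffer. A minimal sketch of that write pattern, assuming hypothetical sizes:

    import torch

    batch_size, target_count, output_length = 2, 3, 8
    output_tensor = torch.zeros(batch_size, target_count, output_length)
    cur_step_idx = 0

    step_output = torch.ones(batch_size, target_count, 4)  # one decode step
    end_idx = cur_step_idx + step_output.shape[2]
    assert end_idx <= output_length, "step would overflow the output buffer"
    output_tensor[:, :, cur_step_idx:end_idx] = step_output
    cur_step_idx = end_idx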
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
index 516c1d07c2c..dcfa4528fce 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
@@ -115,7 +115,7 @@ class InferenceRequestPool(mp.Process):
grouped_requests = defaultdict(list)
for req in all_requests:
- key = (req.inputs.shape[1], req.output_length)
+ key = (req.target_count, req.input_length, req.output_length)
grouped_requests[key].append(req)
grouped_requests = list(grouped_requests.values())
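[Note] Grouping by the full (target_count, input_length, output_length) signature guarantees the batcher's shape check never fires within a group. A minimal sketch of the grouping, with plain tuples standing in for InferenceRequest objects:

    from collections import defaultdict

    # (target_count, input_length, output_length) per request:
    reqs = [(1, 96, 24), (1, 96, 24), (2, 96, 24), (1, 48, 24)]

    grouped = defaultdict(list)
    for req in reqs:
        grouped[req].append(req)  # the key is the shape signature itself here

    assert len(grouped) == 3  # four requests collapse into three batchable groups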
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
index 65aa7714393..591577785fd 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
@@ -20,6 +20,7 @@ from typing import Dict, List, Optional
import torch
+from iotdb.ainode.core.exception import ModelNotExistException
from iotdb.ainode.core.inference.pool_group import PoolGroup
from iotdb.ainode.core.inference.pool_scheduler.abstract_pool_scheduler import (
AbstractPoolScheduler,
@@ -51,6 +52,14 @@ def _estimate_shared_pool_size_by_total_mem(
Returns:
mapping {model_id: pool_num}
"""
+
+ # Check if the model supports concurrent forecasting
+ if new_model_info and new_model_info.model_id not in MODEL_MEM_USAGE_MAP:
+ logger.error(
+ f"[Inference] Cannot estimate inference pool size on device:
{device}, because model: {new_model_info.model_id} does not support concurrent
forecasting."
+ )
+ raise ModelNotExistException(new_model_info.model_id)
+
# Extract unique model IDs
all_models = existing_model_infos + (
[new_model_info] if new_model_info is not None else []
@@ -60,7 +69,7 @@ def _estimate_shared_pool_size_by_total_mem(
mem_usages: Dict[str, float] = {}
for model_info in all_models:
mem_usages[model_info.model_id] = (
- MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO
+ MODEL_MEM_USAGE_MAP[model_info.model_id] * INFERENCE_EXTRA_MEMORY_RATIO
)
# Evaluate system resources and get TOTAL memory
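[Note] Pool sizing divides the inference share of memory by each model's inflated footprint. A simplified sketch of that arithmetic, with assumed ratio values standing in for the real constants:

    MODEL_MEM_USAGE_MAP = {"timer_xl": 856 * 1024**2}
    INFERENCE_EXTRA_MEMORY_RATIO = 1.2  # assumed safety margin
    INFERENCE_MEMORY_USAGE_RATIO = 0.2  # fraction of memory given to inference

    def estimate_pools(total_mem_bytes: int, model_id: str) -> int:
        mem_usage = MODEL_MEM_USAGE_MAP[model_id] * INFERENCE_EXTRA_MEMORY_RATIO
        return int((total_mem_bytes * INFERENCE_MEMORY_USAGE_RATIO) // mem_usage)

    pools = estimate_pools(32 * 1024**3, "timer_xl")  # e.g. 32 GiB device -> 6 pools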
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index addcfad6cfb..ebbb036a9dc 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -253,7 +253,7 @@ class InferenceManager:
return resp_cls(
get_status(TSStatusCode.SUCCESS_STATUS),
- output_list[0] if single_batch else output_list,
+ [output_list[0]] if single_batch else output_list,
)
except Exception as e:
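[Note] Wrapping the single-batch result in a list keeps the response shape consistent with the new list<binary> forecastResult thrift field below, so clients can always read element 0. An illustrative check with hypothetical values:

    output_list = [b"serialized-tsblock"]
    single_batch = True
    payload = [output_list[0]] if single_batch else output_list
    assert isinstance(payload, list) and payload[0] == b"serialized-tsblock"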
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
index 45db66e0186..892d4650e15 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
@@ -67,9 +67,7 @@ def estimate_pool_size(device: torch.device, model_id: str) -> int:
system_res = evaluate_system_resources(device)
free_mem = system_res["free_mem"]
- mem_usage = (
- MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO
- )
+ mem_usage = MODEL_MEM_USAGE_MAP[model_info.model_id] * INFERENCE_EXTRA_MEMORY_RATIO
size = int((free_mem * INFERENCE_MEMORY_USAGE_RATIO) // mem_usage)
if size <= 0:
logger.error(
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java
index 670e019a4b6..34a1a6b2239 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java
@@ -377,7 +377,7 @@ public class ClassifyTableFunction implements TableFunction {
"Error occurred while executing classify:[%s]",
resp.getStatus().getMessage());
throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
}
- return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+ return SERDE.deserialize(resp.forecastResult.get(0));
}
}
}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
index dcb27825e31..579802c542a 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
@@ -574,7 +574,7 @@ public class ForecastTableFunction implements TableFunction {
throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
}
- TsBlock res = SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+ TsBlock res = SERDE.deserialize(resp.forecastResult.get(0));
if (res.getValueColumnCount() != inputData.getValueColumnCount()) {
throw new IoTDBRuntimeException(
String.format(
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
index ebecf79f5b7..a6794a5896f 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
@@ -40,7 +40,6 @@ import org.apache.tsfile.read.common.block.TsBlockBuilder;
import org.apache.tsfile.read.common.block.column.TsBlockSerde;
import java.io.IOException;
-import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
@@ -226,7 +225,7 @@ public class UDTFForecast implements UDTF {
resp.getStatus().getCode(), resp.getStatus().getMessage()),
resp.getStatus().getCode());
}
- return serde.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+ return serde.deserialize(resp.forecastResult.get(0));
}
@Override
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 1cb585f0323..68347b89203 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -92,7 +92,7 @@ struct TForecastReq {
struct TForecastResp {
1: required common.TSStatus status
- 2: optional binary forecastResult
+ 2: optional list<binary> forecastResult
}
struct TShowModelsReq {
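[Note] TForecastResp now carries a list of serialized result blocks instead of a single binary, which is why the Java call sites above switch from ByteBuffer.wrap(resp.getForecastResult()) to resp.forecastResult.get(0). A minimal Python sketch with a hypothetical stand-in for the thrift-generated class:

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class TForecastResp:  # stand-in; the real class is generated from ainode.thrift
        status: str
        forecastResult: Optional[List[bytes]] = None

    resp = TForecastResp(status="SUCCESS", forecastResult=[b"tsblock-0"])
    first = resp.forecastResult[0]  # what the call sites now read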