This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 333b385a08a [AINode] Fix bugs and sync code (#17032)
333b385a08a is described below

commit 333b385a08a0c341dd33cd19bff7f0c76db66228
Author: Leo <[email protected]>
AuthorDate: Sun Jan 18 17:42:51 2026 +0800

    [AINode] Fix bugs and sync code (#17032)
---
 iotdb-core/ainode/iotdb/ainode/core/constant.py    |  2 +-
 .../ainode/core/inference/batcher/basic_batcher.py | 27 ++++++++++++++++------
 .../ainode/core/inference/inference_request.py     | 15 ++++++------
 .../core/inference/inference_request_pool.py       |  2 +-
 .../pool_scheduler/basic_pool_scheduler.py         | 11 ++++++++-
 .../iotdb/ainode/core/manager/inference_manager.py |  2 +-
 .../ainode/iotdb/ainode/core/manager/utils.py      |  4 +---
 .../function/tvf/ClassifyTableFunction.java        |  2 +-
 .../function/tvf/ForecastTableFunction.java        |  2 +-
 .../db/queryengine/plan/udf/UDTFForecast.java      |  3 +--
 .../thrift-ainode/src/main/thrift/ainode.thrift    |  2 +-
 11 files changed, 45 insertions(+), 27 deletions(-)

diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py
index 8a83c981437..68f64a79afc 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/constant.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py
@@ -53,7 +53,7 @@ AINODE_INFERENCE_MAX_OUTPUT_LENGTH = 2880
 # TODO: Should be optimized
 AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = {
     "sundial": 1036 * 1024**2,  # 1036 MiB
-    "timer": 856 * 1024**2,  # 856 MiB
+    "timer_xl": 856 * 1024**2,  # 856 MiB
 }  # the memory usage of each model in bytes
 
 AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.2  # the device space allocated for inference
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/batcher/basic_batcher.py b/iotdb-core/ainode/iotdb/ainode/core/inference/batcher/basic_batcher.py
index 591a0d7c1dd..80ff683d9d1 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/batcher/basic_batcher.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/batcher/basic_batcher.py
@@ -34,6 +34,7 @@ class BasicBatcher(AbstractBatcher):
         Args:
 
         """
+        super().__init__()
 
     def batch_request(self, reqs: List[InferenceRequest]) -> torch.Tensor:
         """
@@ -46,17 +47,29 @@ class BasicBatcher(AbstractBatcher):
 
         Returns:
             torch.Tensor: Concatenated input tensor of shape
-                          [sum(req.batch_size), length].
+                          [sum(req.batch_size), target_count, input_length].
         """
         if not reqs:
             raise ValueError("No requests provided to batch_request.")
 
-        # Ensure length consistency
-        length_set = {req.inputs.shape[1] for req in reqs}
-        if len(length_set) != 1:
-            raise ValueError(
-                f"All requests must have the same length, " f"but got {length_set}"
-            )
+        # Ensure shape consistency
+        first_target_count = reqs[0].target_count
+        first_input_length = reqs[0].input_length
+
+        for i, req in enumerate(reqs):
+            if req.target_count != first_target_count:
+                raise ValueError(
+                    f"All requests must have the same target_count, "
+                    f"but request 0 has {first_target_count} "
+                    f"and request {i} has {req.target_count}"
+                )
+
+            if req.input_length != first_input_length:
+                raise ValueError(
+                    f"All requests must have the same input_length, "
+                    f"but request 0 has {first_input_length} "
+                    f"and request {i} has {req.input_length}"
+                )
 
         batch_inputs = torch.cat([req.inputs for req in reqs], dim=0)
 
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
index 93887477aa5..43380aa1a08 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
@@ -49,21 +49,20 @@ class InferenceRequest:
         self.model_id = model_id
         self.inputs = inputs
         self.infer_kwargs = infer_kwargs
-        self.output_length = (
-            output_length  # Number of time series data points to generate
-        )
+        self.output_length = output_length
 
         self.batch_size = inputs.size(0)
-        self.variable_size = inputs.size(1)
+        self.target_count = inputs.size(1)
+        self.input_length = inputs.size(2)
         self.state = InferenceRequestState.WAITING
         self.cur_step_idx = 0  # Current write position in the output step index
         self.assigned_pool_id = -1  # The pool handling this request
         self.assigned_device_id = -1  # The device handling this request
 
-        # Preallocate output buffer [batch_size, max_new_tokens]
+        # Preallocate output buffer [batch_size, target_count, output_length]
         self.output_tensor = torch.zeros(
-            self.batch_size, self.variable_size, output_length, device="cpu"
-        )  # shape: [batch_size, target_count, predict_length]
+            self.batch_size, self.target_count, output_length, device="cpu"
+        )
 
     def mark_running(self):
         self.state = InferenceRequestState.RUNNING
@@ -81,7 +80,7 @@ class InferenceRequest:
         while step_output.ndim < 3:
             step_output = step_output.unsqueeze(0)
 
-        batch_size, variable_size, step_size = step_output.shape
+        batch_size, target_count, step_size = step_output.shape
         end_idx = self.cur_step_idx + step_size
 
         if end_idx > self.output_length:
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
index 516c1d07c2c..dcfa4528fce 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
@@ -115,7 +115,7 @@ class InferenceRequestPool(mp.Process):
 
         grouped_requests = defaultdict(list)
         for req in all_requests:
-            key = (req.inputs.shape[1], req.output_length)
+            key = (req.target_count, req.input_length, req.output_length)
             grouped_requests[key].append(req)
         grouped_requests = list(grouped_requests.values())
 
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
index 65aa7714393..591577785fd 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
@@ -20,6 +20,7 @@ from typing import Dict, List, Optional
 
 import torch
 
+from iotdb.ainode.core.exception import ModelNotExistException
 from iotdb.ainode.core.inference.pool_group import PoolGroup
 from iotdb.ainode.core.inference.pool_scheduler.abstract_pool_scheduler import (
     AbstractPoolScheduler,
@@ -51,6 +52,14 @@ def _estimate_shared_pool_size_by_total_mem(
     Returns:
         mapping {model_id: pool_num}
     """
+
+    # Check if the model supports concurrent forecasting
+    if new_model_info and new_model_info.model_id not in MODEL_MEM_USAGE_MAP:
+        logger.error(
+            f"[Inference] Cannot estimate inference pool size on device: {device}, because model: {new_model_info.model_id} does not support concurrent forecasting."
+        )
+        raise ModelNotExistException(new_model_info.model_id)
+
     # Extract unique model IDs
     all_models = existing_model_infos + (
         [new_model_info] if new_model_info is not None else []
@@ -60,7 +69,7 @@ def _estimate_shared_pool_size_by_total_mem(
     mem_usages: Dict[str, float] = {}
     for model_info in all_models:
         mem_usages[model_info.model_id] = (
-            MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO
+            MODEL_MEM_USAGE_MAP[model_info.model_id] * INFERENCE_EXTRA_MEMORY_RATIO
         )
 
     # Evaluate system resources and get TOTAL memory
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index addcfad6cfb..ebbb036a9dc 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -253,7 +253,7 @@ class InferenceManager:
 
             return resp_cls(
                 get_status(TSStatusCode.SUCCESS_STATUS),
-                output_list[0] if single_batch else output_list,
+                [output_list[0]] if single_batch else output_list,
             )
 
         except Exception as e:
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
index 45db66e0186..892d4650e15 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
@@ -67,9 +67,7 @@ def estimate_pool_size(device: torch.device, model_id: str) -> int:
     system_res = evaluate_system_resources(device)
     free_mem = system_res["free_mem"]
 
-    mem_usage = (
-        MODEL_MEM_USAGE_MAP[model_info.model_type] * INFERENCE_EXTRA_MEMORY_RATIO
-    )
+    mem_usage = MODEL_MEM_USAGE_MAP[model_info.model_id] * INFERENCE_EXTRA_MEMORY_RATIO
     size = int((free_mem * INFERENCE_MEMORY_USAGE_RATIO) // mem_usage)
     if size <= 0:
         logger.error(
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java
index 670e019a4b6..34a1a6b2239 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ClassifyTableFunction.java
@@ -377,7 +377,7 @@ public class ClassifyTableFunction implements TableFunction {
                "Error occurred while executing classify:[%s]", resp.getStatus().getMessage());
         throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
       }
-      return SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+      return SERDE.deserialize(resp.forecastResult.get(0));
     }
   }
 }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
index dcb27825e31..579802c542a 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
@@ -574,7 +574,7 @@ public class ForecastTableFunction implements TableFunction {
         throw new IoTDBRuntimeException(message, resp.getStatus().getCode());
       }
 
-      TsBlock res = SERDE.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+      TsBlock res = SERDE.deserialize(resp.forecastResult.get(0));
       if (res.getValueColumnCount() != inputData.getValueColumnCount()) {
         throw new IoTDBRuntimeException(
             String.format(
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
index ebecf79f5b7..a6794a5896f 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java
@@ -40,7 +40,6 @@ import org.apache.tsfile.read.common.block.TsBlockBuilder;
 import org.apache.tsfile.read.common.block.column.TsBlockSerde;
 
 import java.io.IOException;
-import java.nio.ByteBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashSet;
@@ -226,7 +225,7 @@ public class UDTFForecast implements UDTF {
               resp.getStatus().getCode(), resp.getStatus().getMessage()),
           resp.getStatus().getCode());
     }
-    return serde.deserialize(ByteBuffer.wrap(resp.getForecastResult()));
+    return serde.deserialize(resp.forecastResult.get(0));
   }
 
   @Override
diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
index 1cb585f0323..68347b89203 100644
--- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
+++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift
@@ -92,7 +92,7 @@ struct TForecastReq {
 
 struct TForecastResp {
   1: required common.TSStatus status
-  2: optional binary forecastResult
+  2: optional list<binary> forecastResult
 }
 
 struct TShowModelsReq {
