This is an automated email from the ASF dual-hosted git repository.

ycycse pushed a commit to branch timer_xl_inference
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/timer_xl_inference by this push:
     new 16ecc769fa8 fix interface
16ecc769fa8 is described below

commit 16ecc769fa8ac97a461d6890020b05fee4770963
Author: YangCaiyin <[email protected]>
AuthorDate: Fri May 9 22:37:42 2025 +0800

    fix interface
---
 .../ainode/core/manager/inference_manager.py       | 50 ++++++----------------
 1 file changed, 14 insertions(+), 36 deletions(-)

diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py 
b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index 476b8d68b80..3ffc608847b 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -63,7 +63,7 @@ class InferenceManager:
         model_id = req.modelId
         logger.info(f"start to forcast by model {model_id}")
         try:
-            data = deserialize(req.inputData)
+            full_data = deserialize(req.inputData)
             if model_id.startswith('_'):
                 # built-in models
                 logger.info(f"start to forecast built-in model {model_id}")
@@ -71,17 +71,18 @@ class InferenceManager:
                 options = req.options
                 options['predict_length'] = req.outputLength
                 model = _get_built_in_model(model_id, model_manager, options)
+
                 if model_id == '_timerxl':
                     inference_result = _inference_with_timerxl(
-                        model, data, options.get("predict_length", 96))
+                        model, full_data, options.get("predict_length", 96))
                 else:
                     inference_result =_inference_with_built_in_model(
-                        model, data)
+                        model, full_data)
                 inference_result = convert_to_binary(inference_result)
             else:
                 # user-registered models
                 model = _get_model(model_id, model_manager, req.options)
-                _, dataset, _, dataset_length = data
+                _, dataset, _, dataset_length = full_data
                 dataset = torch.tensor(dataset, dtype=torch.float).unsqueeze(2)
                 inference_results = _inference_with_registered_model(
                     model, dataset, dataset_length, dataset_length, 
float('inf'))
@@ -108,7 +109,7 @@ class InferenceManager:
                 model = _get_built_in_model(model_id, model_manager, 
inference_attributes)
                 if model_id == '_timerxl':
                     inference_results = [_inference_with_timerxl(
-                        model, full_data, 
inference_attributes.get("predict_length", 96))]
+                        model, full_data, 
int(inference_attributes.get("predict_length", 96)))]
                 else:
                     inference_results = [_inference_with_built_in_model(
                         model, full_data)]
@@ -203,34 +204,17 @@ def _inference_with_built_in_model(model, full_data):
         is the output window interval, C' is the number of variables in the 
output DataFrame. Then the inference module
         will concatenate all the output DataFrames into a list.
     """
-
-    _, data, _, _ = full_data
-    output = model.inference(data[0])
+    data = full_data[1][0]
+    if data.dtype.byteorder not in ('=', '|'):
+        data = data.byteswap().newbyteorder()
+    output = model.inference(data)
     # output: DataFrame, shape: (H', C')
     output = pd.DataFrame(output)
     return output
 
 def _inference_with_timerxl(model, full_data, pred_len):
-    """
-    Args:
-        model: the built-in model
-        full_data: a tuple of (data, time_stamp, type_list, column_name_list), 
where the data is a DataFrame with shape
-            (L, C), time_stamp is a DataFrame with shape(L, 1), type_list is a 
list of data types with length C,
-            column_name_list is a list of column names with length C, where L 
is the number of data points, C is the
-            number of variables, the data and time_stamp are aligned by index
-    Returns:
-        outputs: a list of output DataFrames, where each DataFrame has shape 
(H', C'), where H' is the output window
-            interval, C' is the number of variables in the output DataFrame
-    Description:
-        the inference_with_built_in_model function will inference with 
built-in model, which does not
-        require user registration. This module will parse the inference 
attributes and create the built-in model, then
-        feed the input data into the model to get the output, the output is a 
DataFrame with shape (H', C'), where H'
-        is the output window interval, C' is the number of variables in the 
output DataFrame. Then the inference module
-        will concatenate all the output DataFrames into a list.
-    """
 
-    _, data, _, _ = full_data
-    data = data[0]
+    data = full_data[1][0]
     if data.dtype.byteorder not in ('=', '|'):
         data = data.byteswap().newbyteorder()
     output = model.inference(data, pred_len)
@@ -254,13 +238,7 @@ def _get_built_in_model(model_id: str, model_manager: 
ModelManager, inference_at
 
 
 def _parse_inference_request(req: TInferenceReq):
-    binary_dataset = req.dataset
-    type_list = req.typeList
-    column_name_list = req.columnNameList
-    column_name_index = req.columnNameIndexMap
-    data = convert_to_df(column_name_list, type_list, column_name_index, 
[binary_dataset])
-    time_stamp, data = data[data.columns[0:1]], data[data.columns[1:]]
-    full_data = (data, time_stamp, type_list, column_name_list)
+    data = deserialize(req.dataset)
     inference_attributes = req.inferenceAttributes
     if inference_attributes is None:
         inference_attributes = {}
@@ -269,8 +247,8 @@ def _parse_inference_request(req: TInferenceReq):
     if window_params is None:
         # set default window_step to infinity and window_interval to dataset 
length
         window_step = float('inf')
-        window_interval = data.shape[0]
+        window_interval = len(data[1][0])
     else:
         window_step = window_params.windowStep
         window_interval = window_params.windowInterval
-    return req.modelId, full_data, window_interval, window_step, 
inference_attributes
+    return req.modelId, data, window_interval, window_step, 
inference_attributes

Reply via email to