This is an automated email from the ASF dual-hosted git repository.
ycycse pushed a commit to branch timer_xl_inference
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/timer_xl_inference by this push:
new 16ecc769fa8 fix interface
16ecc769fa8 is described below
commit 16ecc769fa8ac97a461d6890020b05fee4770963
Author: YangCaiyin <[email protected]>
AuthorDate: Fri May 9 22:37:42 2025 +0800
fix interface
---
.../ainode/core/manager/inference_manager.py | 50 ++++++----------------
1 file changed, 14 insertions(+), 36 deletions(-)
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index 476b8d68b80..3ffc608847b 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -63,7 +63,7 @@ class InferenceManager:
model_id = req.modelId
logger.info(f"start to forcast by model {model_id}")
try:
- data = deserialize(req.inputData)
+ full_data = deserialize(req.inputData)
if model_id.startswith('_'):
# built-in models
logger.info(f"start to forecast built-in model {model_id}")
@@ -71,17 +71,18 @@ class InferenceManager:
options = req.options
options['predict_length'] = req.outputLength
model = _get_built_in_model(model_id, model_manager, options)
+
if model_id == '_timerxl':
inference_result = _inference_with_timerxl(
- model, data, options.get("predict_length", 96))
+ model, full_data, options.get("predict_length", 96))
else:
inference_result =_inference_with_built_in_model(
- model, data)
+ model, full_data)
inference_result = convert_to_binary(inference_result)
else:
# user-registered models
model = _get_model(model_id, model_manager, req.options)
- _, dataset, _, dataset_length = data
+ _, dataset, _, dataset_length = full_data
dataset = torch.tensor(dataset, dtype=torch.float).unsqueeze(2)
inference_results = _inference_with_registered_model(
model, dataset, dataset_length, dataset_length,
float('inf'))
@@ -108,7 +109,7 @@ class InferenceManager:
model = _get_built_in_model(model_id, model_manager,
inference_attributes)
if model_id == '_timerxl':
inference_results = [_inference_with_timerxl(
- model, full_data,
inference_attributes.get("predict_length", 96))]
+ model, full_data,
int(inference_attributes.get("predict_length", 96)))]
else:
inference_results = [_inference_with_built_in_model(
model, full_data)]
@@ -203,34 +204,17 @@ def _inference_with_built_in_model(model, full_data):
is the output window interval, C' is the number of variables in the
output DataFrame. Then the inference module
will concatenate all the output DataFrames into a list.
"""
-
- _, data, _, _ = full_data
- output = model.inference(data[0])
+ data = full_data[1][0]
+ if data.dtype.byteorder not in ('=', '|'):
+ data = data.byteswap().newbyteorder()
+ output = model.inference(data)
# output: DataFrame, shape: (H', C')
output = pd.DataFrame(output)
return output
def _inference_with_timerxl(model, full_data, pred_len):
- """
- Args:
- model: the built-in model
- full_data: a tuple of (data, time_stamp, type_list, column_name_list),
where the data is a DataFrame with shape
- (L, C), time_stamp is a DataFrame with shape(L, 1), type_list is a
list of data types with length C,
- column_name_list is a list of column names with length C, where L
is the number of data points, C is the
- number of variables, the data and time_stamp are aligned by index
- Returns:
- outputs: a list of output DataFrames, where each DataFrame has shape
(H', C'), where H' is the output window
- interval, C' is the number of variables in the output DataFrame
- Description:
- the inference_with_built_in_model function will inference with
built-in model, which does not
- require user registration. This module will parse the inference
attributes and create the built-in model, then
- feed the input data into the model to get the output, the output is a
DataFrame with shape (H', C'), where H'
- is the output window interval, C' is the number of variables in the
output DataFrame. Then the inference module
- will concatenate all the output DataFrames into a list.
- """
- _, data, _, _ = full_data
- data = data[0]
+ data = full_data[1][0]
if data.dtype.byteorder not in ('=', '|'):
data = data.byteswap().newbyteorder()
output = model.inference(data, pred_len)
@@ -254,13 +238,7 @@ def _get_built_in_model(model_id: str, model_manager: ModelManager, inference_at
def _parse_inference_request(req: TInferenceReq):
- binary_dataset = req.dataset
- type_list = req.typeList
- column_name_list = req.columnNameList
- column_name_index = req.columnNameIndexMap
- data = convert_to_df(column_name_list, type_list, column_name_index,
[binary_dataset])
- time_stamp, data = data[data.columns[0:1]], data[data.columns[1:]]
- full_data = (data, time_stamp, type_list, column_name_list)
+ data = deserialize(req.dataset)
inference_attributes = req.inferenceAttributes
if inference_attributes is None:
inference_attributes = {}
@@ -269,8 +247,8 @@ def _parse_inference_request(req: TInferenceReq):
if window_params is None:
# set default window_step to infinity and window_interval to dataset
length
window_step = float('inf')
- window_interval = data.shape[0]
+ window_interval = len(data[1][0])
else:
window_step = window_params.windowStep
window_interval = window_params.windowInterval
- return req.modelId, full_data, window_interval, window_step,
inference_attributes
+ return req.modelId, data, window_interval, window_step,
inference_attributes