This is an automated email from the ASF dual-hosted git repository.
hui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 0750fa6f49 [IOTDB-5728] Implement config parser & model/dataset
factory on MLNode (#9458)
0750fa6f49 is described below
commit 0750fa6f492b052fea5c01e88a21032a12631488
Author: lichenyu <[email protected]>
AuthorDate: Fri Mar 31 15:35:26 2023 +0800
[IOTDB-5728] Implement config parser & model/dataset factory on MLNode
(#9458)
Co-authored-by: Wenwei <[email protected]>
---
.../{models/forecast/__init__.py => enums.py} | 12 ++
mlnode/iotdb/mlnode/algorithm/factory.py | 128 ++++++++++++++
.../mlnode/algorithm/models/forecast/__init__.py | 3 +
.../mlnode/algorithm/models/forecast/dlinear.py | 41 ++++-
.../mlnode/algorithm/models/forecast/nbeats.py | 47 ++++-
mlnode/iotdb/mlnode/client.py | 9 +-
mlnode/iotdb/mlnode/constant.py | 1 +
.../{datats/utils => data_access}/__init__.py | 0
.../forecast/__init__.py => data_access/enums.py} | 12 ++
mlnode/iotdb/mlnode/data_access/factory.py | 105 +++++++++++
.../forecast => data_access/offline}/__init__.py | 0
.../{datats => data_access}/offline/dataset.py | 30 +---
.../offline/source.py} | 9 +-
.../forecast => data_access/utils}/__init__.py | 0
.../{datats => data_access}/utils/timefeatures.py | 2 -
mlnode/iotdb/mlnode/exception.py | 16 +-
mlnode/iotdb/mlnode/handler.py | 27 ++-
mlnode/iotdb/mlnode/parser.py | 194 +++++++++++++++++++++
mlnode/iotdb/mlnode/serde.py | 30 +++-
mlnode/iotdb/mlnode/util.py | 4 +-
mlnode/test/test_parse_training_request.py | 136 +++++++++++++++
21 files changed, 747 insertions(+), 59 deletions(-)
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
b/mlnode/iotdb/mlnode/algorithm/enums.py
similarity index 76%
copy from mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
copy to mlnode/iotdb/mlnode/algorithm/enums.py
index 2a1e720805..4b05aa4bf8 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
+++ b/mlnode/iotdb/mlnode/algorithm/enums.py
@@ -15,3 +15,15 @@
# specific language governing permissions and limitations
# under the License.
#
+from enum import Enum
+
+
+class ForecastTaskType(Enum):
+ ENDOGENOUS = "endogenous"
+ EXOGENOUS = "exogenous"
+
+ def __str__(self):
+ return self.value
+
+ def __eq__(self, other: str) -> bool:
+ return self.value == other
diff --git a/mlnode/iotdb/mlnode/algorithm/factory.py
b/mlnode/iotdb/mlnode/algorithm/factory.py
new file mode 100644
index 0000000000..92cb01a883
--- /dev/null
+++ b/mlnode/iotdb/mlnode/algorithm/factory.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import torch.nn as nn
+
+from iotdb.mlnode.algorithm.enums import ForecastTaskType
+from iotdb.mlnode.algorithm.models.forecast import support_forecasting_models
+from iotdb.mlnode.exception import BadConfigValueError
+
+
+# Common configs for all forecasting model with default values
+def _common_config(**kwargs):
+ return {
+ 'input_len': 96,
+ 'pred_len': 96,
+ 'input_vars': 1,
+ 'output_vars': 1,
+ **kwargs
+ }
+
+
+# Common forecasting task configs
+support_common_configs = {
+    # multivariate forecasting, currently the only supported type
+ ForecastTaskType.ENDOGENOUS: _common_config(
+ input_vars=1,
+ output_vars=1),
+
+ # univariate forecasting with observable exogenous variables
+ ForecastTaskType.EXOGENOUS: _common_config(
+ output_vars=1),
+}
+
+
+def is_model(model_name: str) -> bool:
+ """
+ Check if a model name exists
+ """
+ return model_name in support_forecasting_models
+
+
+def list_model() -> list[str]:
+ """
+ List support forecasting model
+ """
+ return support_forecasting_models
+
+
+def create_forecast_model(
+ model_name,
+ forecast_task_type=ForecastTaskType.ENDOGENOUS,
+ input_len=96,
+ pred_len=96,
+ input_vars=1,
+ output_vars=1,
+ **kwargs,
+) -> [nn.Module, dict]:
+ """
+ Factory method for all support forecasting models
+ the given arguments is common configs shared by all forecasting models
+ for specific model configs, see __model_config in
`algorithm/models/MODELNAME.py`
+
+ Args:
+ model_name: see available models by `list_model`
+        forecast_task_type: 'endogenous' for forecasting without exogenous
variables (default),
+            'exogenous' for forecasting with observable exogenous variables
+ input_len: time length of model input
+ pred_len: time length of model output
+ input_vars: number of input series
+ output_vars: number of output series
+ kwargs: for specific model configs, see returned `model_config` with
kwargs=None
+
+ Returns:
+ model: torch.nn.Module
+ model_config: dict of model configurations
+ """
+ if not is_model(model_name):
+ raise BadConfigValueError('model_name', model_name, f'It should be one
of {list_model()}')
+ if forecast_task_type not in support_common_configs.keys():
+ raise BadConfigValueError('forecast_task_type', forecast_task_type,
+ f'It should be one of
{list(support_common_configs.keys())}')
+
+ common_config = support_common_configs[forecast_task_type]
+ common_config['input_len'] = input_len
+ common_config['pred_len'] = pred_len
+ common_config['input_vars'] = input_vars
+ common_config['output_vars'] = output_vars
+ common_config['forecast_task_type'] = str(forecast_task_type)
+
+ if not input_len > 0:
+ raise BadConfigValueError('input_len', input_len,
+ 'Length of input series should be positive')
+ if not pred_len > 0:
+ raise BadConfigValueError('pred_len', pred_len,
+ 'Length of predicted series should be
positive')
+ if not input_vars > 0:
+ raise BadConfigValueError('input_vars', input_vars,
+ 'Number of input variates should be
positive')
+ if not output_vars > 0:
+ raise BadConfigValueError('output_vars', output_vars,
+ 'Number of output variates should be
positive')
+ if forecast_task_type == ForecastTaskType.ENDOGENOUS:
+ if input_vars != output_vars:
+ raise BadConfigValueError('forecast_task_type', forecast_task_type,
+ 'Number of input/output variates should
be '
+ 'the same in multivariate forecast')
+ create_fn = eval(model_name)
+ model, model_config = create_fn(
+ common_config=common_config,
+ **kwargs
+ )
+ model_config['model_name'] = model_name
+
+ return model, model_config
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
b/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
index 2a1e720805..2abb5faf37 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
+++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
@@ -15,3 +15,6 @@
# specific language governing permissions and limitations
# under the License.
#
+
+
+support_forecasting_models = ['dlinear', 'dlinear_individual', 'nbeats']
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
index 58fb12bf29..fa9ee04e56 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
+++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/dlinear.py
@@ -15,12 +15,14 @@
# specific language governing permissions and limitations
# under the License.
#
-import argparse
+
import math
import torch
import torch.nn as nn
+from iotdb.mlnode.exception import BadConfigValueError
+
class MovingAverageBlock(nn.Module):
""" Moving average block to highlight the trend of time series """
@@ -61,7 +63,9 @@ class DLinear(nn.Module):
kernel_size=25,
input_len=96,
pred_len=96,
- input_vars=1
+ input_vars=1,
+ output_vars=1,
+ forecast_type='m', # TODO, support others
):
super(DLinear, self).__init__()
self.input_len = input_len
@@ -94,7 +98,9 @@ class DLinearIndividual(nn.Module):
kernel_size=25,
input_len=96,
pred_len=96,
- input_vars=1
+ input_vars=1,
+ output_vars=1,
+ forecast_type='m', # TODO, support others
):
super(DLinearIndividual, self).__init__()
self.input_len = input_len
@@ -128,11 +134,28 @@ class DLinearIndividual(nn.Module):
return x.permute(0, 2, 1) # to [Batch, Output length, Channel]
-def dlinear(model_config: argparse.Namespace) -> DLinear:
- # TODO (@lcy)
- pass
+def _model_config(**kwargs):
+ return {
+ 'kernel_size': 25,
+ **kwargs
+ }
+
+
+def dlinear(common_config: dict, kernel_size=25, **kwargs) -> [DLinear, dict]:
+ config = _model_config()
+ config.update(**common_config)
+ if not kernel_size > 0:
+ raise BadConfigValueError('kernel_size', kernel_size,
+ 'Kernel size of dlinear should larger than
0')
+ config['kernel_size'] = kernel_size
+ return DLinear(**config), config
-def dlinear_individual(model_config: argparse.Namespace) -> DLinearIndividual:
- # TODO (@lcy)
- pass
+def dlinear_individual(common_config: dict, kernel_size=25, **kwargs) ->
[DLinearIndividual, dict]:
+ config = _model_config()
+ config.update(**common_config)
+ if not kernel_size > 0:
+ raise BadConfigValueError('kernel_size', kernel_size,
+ 'Kernel size of dlinear_individual should
larger than 0')
+ config['kernel_size'] = kernel_size
+ return DLinearIndividual(**config), config
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/nbeats.py
b/mlnode/iotdb/mlnode/algorithm/models/forecast/nbeats.py
index 0744cd4460..e3c3ca6a0a 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/nbeats.py
+++ b/mlnode/iotdb/mlnode/algorithm/models/forecast/nbeats.py
@@ -16,12 +16,13 @@
# under the License.
#
-import argparse
from typing import Tuple
import torch
import torch.nn as nn
+from iotdb.mlnode.exception import BadConfigValueError
+
class GenericBasis(nn.Module):
""" Generic basis function """
@@ -37,10 +38,6 @@ class GenericBasis(nn.Module):
block_dict = {
'generic': GenericBasis,
-
- # TODO(@lcy) support more block type
- # 'trend': TrendBasis,
- # 'seasonality': SeasonalityBasis,
}
@@ -109,6 +106,8 @@ class NBeats(nn.Module):
input_len=96,
pred_len=96,
input_vars=1,
+ output_vars=1,
+ forecast_type='m', # TODO, support others
):
super(NBeats, self).__init__()
self.enc_in = input_vars
@@ -133,6 +132,38 @@ class NBeats(nn.Module):
return torch.stack(res, dim=-1) # to [Batch, Output length, Channel]
-def nbeats(model_config: argparse.Namespace) -> NBeats:
- # TODO (@lcy)
- pass
+def _model_config(**kwargs):
+ return {
+ 'block_type': 'generic',
+ 'd_model': 128,
+ 'inner_layers': 4,
+ 'outer_layers': 4,
+ **kwargs
+ }
+
+
+"""
+Specific configs for NBeats variants
+"""
+support_model_configs = {
+ 'nbeats': _model_config(
+ block_type='generic'),
+}
+
+
+def nbeats(common_config: dict, d_model=128, inner_layers=4, outer_layers=4,
**kwargs) -> [NBeats, dict]:
+ config = _model_config()
+ config.update(**common_config)
+ if not d_model > 0:
+ raise BadConfigValueError('d_model', d_model,
+ 'Model dimension (d_model) of nbeats should
larger than 0')
+ if not inner_layers > 0:
+ raise BadConfigValueError('inner_layers', inner_layers,
+ 'Number of inner layers of nbeats should
larger than 0')
+ if not outer_layers > 0:
+ raise BadConfigValueError('outer_layers', outer_layers,
+ 'Number of outer layers of nbeats should
larger than 0')
+ config['d_model'] = d_model
+ config['inner_layers'] = inner_layers
+ config['outer_layers'] = outer_layers
+ return NBeats(**config), config
diff --git a/mlnode/iotdb/mlnode/client.py b/mlnode/iotdb/mlnode/client.py
index aa1536e130..76eb754596 100644
--- a/mlnode/iotdb/mlnode/client.py
+++ b/mlnode/iotdb/mlnode/client.py
@@ -70,7 +70,7 @@ class MLNodeClient(object):
model_id: str,
is_auto: bool,
model_configs: dict,
- query_expressions: list[str],
+ query_expressions: list = [],
query_filter: str = None) -> None:
req = TCreateTrainingTaskReq(
modelId=model_id,
@@ -116,6 +116,7 @@ class DataNodeClient(object):
transport.open()
except TTransport.TTransportException as e:
logger.exception("TTransportException!", exc_info=e)
+ raise e
protocol = TBinaryProtocol.TBinaryProtocol(transport)
self.__client = IDataNodeRPCService.Client(protocol)
@@ -123,7 +124,7 @@ class DataNodeClient(object):
def fetch_timeseries(self,
session_id: int,
statement_id: int,
- query_expressions: list[str],
+ query_expressions: list = [],
query_filter: str = None,
fetch_size: int = DEFAULT_FETCH_SIZE,
timeout: int = DEFAULT_TIMEOUT) ->
TFetchTimeseriesResp:
@@ -145,8 +146,8 @@ class DataNodeClient(object):
def record_model_metrics(self,
model_id: str,
trial_id: str,
- metrics: list[str],
- values: list[float]) -> None:
+ metrics: list = [],
+ values: list = []) -> None:
req = TRecordModelMetricsReq(
modelId=model_id,
trialId=trial_id,
diff --git a/mlnode/iotdb/mlnode/constant.py b/mlnode/iotdb/mlnode/constant.py
index 810d7c261e..3bffa06526 100644
--- a/mlnode/iotdb/mlnode/constant.py
+++ b/mlnode/iotdb/mlnode/constant.py
@@ -27,6 +27,7 @@ MLNODE_MODEL_STORAGE_DIRECTORY_NAME = "models"
class TSStatusCode(Enum):
SUCCESS_STATUS = 200
REDIRECTION_RECOMMEND = 400
+ FAIL_STATUS = 404
def get_status_code(self) -> int:
return self.value
diff --git a/mlnode/iotdb/mlnode/datats/utils/__init__.py
b/mlnode/iotdb/mlnode/data_access/__init__.py
similarity index 100%
rename from mlnode/iotdb/mlnode/datats/utils/__init__.py
rename to mlnode/iotdb/mlnode/data_access/__init__.py
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
b/mlnode/iotdb/mlnode/data_access/enums.py
similarity index 77%
copy from mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
copy to mlnode/iotdb/mlnode/data_access/enums.py
index 2a1e720805..d21a9f69c4 100644
--- a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
+++ b/mlnode/iotdb/mlnode/data_access/enums.py
@@ -15,3 +15,15 @@
# specific language governing permissions and limitations
# under the License.
#
+from enum import Enum
+
+
+class DatasetType(Enum):
+ TIMESERIES = "timeseries"
+ WINDOW = "window"
+
+ def __str__(self):
+ return self.value
+
+ def __eq__(self, other: str) -> bool:
+ return self.value == other
diff --git a/mlnode/iotdb/mlnode/data_access/factory.py
b/mlnode/iotdb/mlnode/data_access/factory.py
new file mode 100644
index 0000000000..d0041388a6
--- /dev/null
+++ b/mlnode/iotdb/mlnode/data_access/factory.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from torch.utils.data import Dataset
+
+from iotdb.mlnode.data_access.enums import DatasetType
+from iotdb.mlnode.data_access.offline.dataset import (TimeSeriesDataset,
+ WindowDataset)
+from iotdb.mlnode.data_access.offline.source import (FileDataSource,
+ ThriftDataSource)
+from iotdb.mlnode.exception import BadConfigValueError, MissingConfigError
+
+support_forecasting_dataset = {
+ DatasetType.TIMESERIES: TimeSeriesDataset,
+ DatasetType.WINDOW: WindowDataset
+}
+
+
+def _dataset_config(**kwargs):
+ return {
+ 'time_embed': 'h',
+ **kwargs
+ }
+
+
+support_dataset_configs = {
+ DatasetType.TIMESERIES: _dataset_config(),
+ DatasetType.WINDOW: _dataset_config(
+ input_len=96,
+ pred_len=96,
+ )
+}
+
+
+def create_forecast_dataset(
+ source_type,
+ dataset_type,
+ **kwargs,
+) -> [Dataset, dict]:
+ """
+ Factory method for all support dataset
+ currently implement WindowDataset, TimeSeriesDataset
+ for specific dataset configs, see _dataset_config in
`algorithm/models/MODELNAME.py`
+
+ Args:
+ dataset_type: available choice in support_forecasting_dataset
+ source_type: available choice in ['file', 'thrift']
+ kwargs: for specific dataset configs, see returned `dataset_config`
with kwargs=None
+
+ Returns:
+ dataset: torch.nn.Module
+ dataset_config: dict of dataset configurations
+ """
+ if dataset_type not in support_forecasting_dataset.keys():
+ raise BadConfigValueError('dataset_type', dataset_type,
+ f'It should be one of
{list(support_forecasting_dataset.keys())}')
+
+ if source_type == 'file':
+ if 'filename' not in kwargs.keys():
+ raise MissingConfigError('filename')
+ datasource = FileDataSource(kwargs['filename'])
+ elif source_type == 'thrift':
+ if 'query_expressions' not in kwargs.keys():
+ raise MissingConfigError('query_expressions')
+ if 'query_filter' not in kwargs.keys():
+ raise MissingConfigError('query_filter')
+ datasource = ThriftDataSource(kwargs['query_expressions'],
kwargs['query_filter'])
+ else:
+        raise BadConfigValueError('source_type', source_type, "It should be
one of ['file', 'thrift']")
+
+ dataset_fn = support_forecasting_dataset[dataset_type]
+ dataset_config = support_dataset_configs[dataset_type]
+
+ for k, v in kwargs.items():
+ if k in dataset_config.keys():
+ dataset_config[k] = v
+
+ dataset = dataset_fn(datasource, **dataset_config)
+
+ if 'input_vars' in kwargs.keys() and dataset.get_variable_num() !=
kwargs['input_vars']:
+ raise BadConfigValueError('input_vars', kwargs['input_vars'],
+ f'Variable number of fetched data:
({dataset.get_variable_num()})'
+ f' should be consistent with input_vars')
+
+ data_config = dataset_config.copy()
+ data_config['input_vars'] = dataset.get_variable_num()
+ data_config['output_vars'] = dataset.get_variable_num()
+ data_config['source_type'] = source_type
+ data_config['dataset_type'] = dataset_type
+
+ return dataset, data_config
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
b/mlnode/iotdb/mlnode/data_access/offline/__init__.py
similarity index 100%
copy from mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
copy to mlnode/iotdb/mlnode/data_access/offline/__init__.py
diff --git a/mlnode/iotdb/mlnode/datats/offline/dataset.py
b/mlnode/iotdb/mlnode/data_access/offline/dataset.py
similarity index 85%
rename from mlnode/iotdb/mlnode/datats/offline/dataset.py
rename to mlnode/iotdb/mlnode/data_access/offline/dataset.py
index c71aaf87c5..1a96e81a4a 100644
--- a/mlnode/iotdb/mlnode/datats/offline/dataset.py
+++ b/mlnode/iotdb/mlnode/data_access/offline/dataset.py
@@ -15,16 +15,10 @@
# specific language governing permissions and limitations
# under the License.
#
-
-
-import argparse
-
from torch.utils.data import Dataset
-from iotdb.mlnode.datats.offline.data_source import DataSource
-from iotdb.mlnode.datats.utils.timefeatures import time_features
-
-# currently support for multivariate forecasting only
+from iotdb.mlnode.data_access.offline.source import DataSource
+from iotdb.mlnode.data_access.utils.timefeatures import time_features
class TimeSeriesDataset(Dataset):
@@ -81,11 +75,11 @@ class WindowDataset(TimeSeriesDataset):
time_embed: str = 'h'):
self.input_len = input_len
self.pred_len = pred_len
- if input_len <= self.data.shape[0]:
+ super(WindowDataset, self).__init__(data_source, time_embed)
+ if input_len > self.data.shape[0]:
raise RuntimeError('input_len should not be larger than the number
of time series points')
- if pred_len <= self.data.shape[0]:
+ if pred_len > self.data.shape[0]:
raise RuntimeError('pred_len should not be larger than the number
of time series points')
- super(WindowDataset, self).__init__(data_source, time_embed)
def __getitem__(self, index):
s_begin = index
@@ -100,17 +94,3 @@ class WindowDataset(TimeSeriesDataset):
def __len__(self):
return len(self.data) - self.input_len - self.pred_len + 1
-
-
-def get_timeseries_dataset(data_config: argparse.Namespace) ->
TimeSeriesDataset:
- # TODO (@lcy)
- # init datasource
- # init dataset
- pass
-
-
-def get_window_dataset(data_config: argparse.Namespace) -> WindowDataset:
- # TODO (@lcy)
- # init datasource
- # init dataset
- pass
diff --git a/mlnode/iotdb/mlnode/datats/offline/data_source.py
b/mlnode/iotdb/mlnode/data_access/offline/source.py
similarity index 96%
rename from mlnode/iotdb/mlnode/datats/offline/data_source.py
rename to mlnode/iotdb/mlnode/data_access/offline/source.py
index cd8e9a891c..a63371ec7a 100644
--- a/mlnode/iotdb/mlnode/datats/offline/data_source.py
+++ b/mlnode/iotdb/mlnode/data_access/offline/source.py
@@ -33,6 +33,7 @@ class DataSource(object):
def __init__(self):
self.data = None
self.timestamp = None
+ self._read_data()
def _read_data(self):
raise NotImplementedError
@@ -46,9 +47,8 @@ class DataSource(object):
class FileDataSource(DataSource):
def __init__(self, filename: str = None):
- super(FileDataSource, self).__init__()
self.filename = filename
- self._read_data()
+ super(FileDataSource, self).__init__()
def _read_data(self):
try:
@@ -62,15 +62,14 @@ class FileDataSource(DataSource):
class ThriftDataSource(DataSource):
def __init__(self, query_expressions: list = None, query_filter: str =
None):
- super(DataSource, self).__init__()
self.query_expressions = query_expressions
self.query_filter = query_filter
- self._read_data()
+ super(ThriftDataSource, self).__init__()
def _read_data(self):
try:
data_client = client_manager.borrow_data_node_client()
- except Exception: # is this exception catch needed???
+ except Exception:
raise RuntimeError('Fail to establish connection with DataNode')
try:
diff --git a/mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
b/mlnode/iotdb/mlnode/data_access/utils/__init__.py
similarity index 100%
copy from mlnode/iotdb/mlnode/algorithm/models/forecast/__init__.py
copy to mlnode/iotdb/mlnode/data_access/utils/__init__.py
diff --git a/mlnode/iotdb/mlnode/datats/utils/timefeatures.py
b/mlnode/iotdb/mlnode/data_access/utils/timefeatures.py
similarity index 99%
rename from mlnode/iotdb/mlnode/datats/utils/timefeatures.py
rename to mlnode/iotdb/mlnode/data_access/utils/timefeatures.py
index bd1681cfbf..ecd6784ca4 100644
--- a/mlnode/iotdb/mlnode/datats/utils/timefeatures.py
+++ b/mlnode/iotdb/mlnode/data_access/utils/timefeatures.py
@@ -15,8 +15,6 @@
# specific language governing permissions and limitations
# under the License.
#
-
-
from typing import List
import numpy as np
diff --git a/mlnode/iotdb/mlnode/exception.py b/mlnode/iotdb/mlnode/exception.py
index 3907a67d58..a7b211dbc2 100644
--- a/mlnode/iotdb/mlnode/exception.py
+++ b/mlnode/iotdb/mlnode/exception.py
@@ -16,7 +16,6 @@
# under the License.
#
-
class _BaseError(Exception):
"""Base class for exceptions in this module."""
pass
@@ -30,3 +29,18 @@ class BadNodeUrlError(_BaseError):
class ModelNotExistError(_BaseError):
def __init__(self, file_path: str):
self.message = "Model path: ({}) not exists".format(file_path)
+
+
+class BadConfigValueError(_BaseError):
+ def __init__(self, config_name: str, config_value, hint: str = ''):
+ self.message = "Bad value ({0}) for config: ({1}).
{2}".format(config_value, config_name, hint)
+
+
+class MissingConfigError(_BaseError):
+ def __init__(self, config_name: str):
+ self.message = "Missing config: ({})".format(config_name)
+
+
+class WrongTypeConfigError(_BaseError):
+ def __init__(self, config_name: str, expected_type: str):
+ self.message = "Wrong type for config: ({0}), expected:
({1})".format(config_name, expected_type)
diff --git a/mlnode/iotdb/mlnode/handler.py b/mlnode/iotdb/mlnode/handler.py
index d1f21ff517..e7ff76cbe0 100644
--- a/mlnode/iotdb/mlnode/handler.py
+++ b/mlnode/iotdb/mlnode/handler.py
@@ -16,7 +16,11 @@
# under the License.
#
+from iotdb.mlnode.algorithm.factory import create_forecast_model
from iotdb.mlnode.constant import TSStatusCode
+from iotdb.mlnode.data_access.factory import create_forecast_dataset
+from iotdb.mlnode.log import logger
+from iotdb.mlnode.parser import parse_training_request
from iotdb.mlnode.util import get_status
from iotdb.thrift.mlnode import IMLNodeRPCService
from iotdb.thrift.mlnode.ttypes import (TCreateTrainingTaskReq,
@@ -32,7 +36,28 @@ class MLNodeRPCServiceHandler(IMLNodeRPCService.Iface):
return get_status(TSStatusCode.SUCCESS_STATUS, "")
def createTrainingTask(self, req: TCreateTrainingTaskReq):
- return get_status(TSStatusCode.SUCCESS_STATUS, "")
+ # parse request stage (check required config and config type)
+ data_config, model_config, task_config = parse_training_request(req)
+
+ # create model stage (check model config legitimacy)
+ try:
+ model, model_config = create_forecast_model(**model_config)
+ except Exception as e: # Create model failed
+ return get_status(TSStatusCode.FAIL_STATUS, str(e))
+ logger.info('model config: ' + str(model_config))
+
+ # create data stage (check data config legitimacy)
+ try:
+ dataset, data_config = create_forecast_dataset(**data_config)
+ except Exception as e: # Create data failed
+ return get_status(TSStatusCode.FAIL_STATUS, str(e))
+ logger.info('data config: ' + str(data_config))
+
+ # create task stage (check task config legitimacy)
+
+ # submit task stage (check resource and decide pending/start)
+
+ return get_status(TSStatusCode.SUCCESS_STATUS, 'Successfully create
training task')
def forecast(self, req: TForecastReq):
status = get_status(TSStatusCode.SUCCESS_STATUS, "")
diff --git a/mlnode/iotdb/mlnode/parser.py b/mlnode/iotdb/mlnode/parser.py
new file mode 100644
index 0000000000..236032b9a0
--- /dev/null
+++ b/mlnode/iotdb/mlnode/parser.py
@@ -0,0 +1,194 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+
+import argparse
+import re
+
+from iotdb.mlnode.algorithm.enums import ForecastTaskType
+from iotdb.mlnode.data_access.enums import DatasetType
+from iotdb.mlnode.exception import MissingConfigError, WrongTypeConfigError
+from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq
+
+
+class _ConfigParser(argparse.ArgumentParser):
+ """
+ A parser for parsing configs from configs: dict
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def parse_configs(self, configs):
+ """
+ Parse configs from a dict
+ Args:configs: a dict of all configs which contains all required
arguments
+ Returns: a dict of parsed configs
+ """
+ args = self.parse_dict(configs)
+ return vars(self.parse_known_args(args)[0])
+
+ @staticmethod
+ def parse_dict(config_dict):
+ """
+ Parse a dict of configs to a list of arguments
+ Args:config_dict: a dict of configs
+ Returns: a list of arguments which can be parsed by argparse
+ """
+ args = []
+ for k, v in config_dict.items():
+ args.append("--{}".format(k))
+ if isinstance(v, str) and re.match(r'^\[(.*)]$', v):
+ v = eval(v)
+ v = [str(i) for i in v]
+ args.extend(v)
+ elif isinstance(v, list):
+ args.extend([str(i) for i in v])
+ else:
+ args.append(v)
+ return args
+
+ def error(self, message: str):
+ """
+ Override the error method to raise exceptions instead of exiting
+ """
+ if message.startswith('the following arguments are required:'):
+ missing_arg = re.findall(r': --(\w+)', message)[0]
+ raise MissingConfigError(missing_arg)
+ elif re.match(r'argument --\w+: invalid \w+ value:', message):
+ argument = re.findall(r'argument --(\w+):', message)[0]
+ expected_type = re.findall(r'invalid (\w+) value:', message)[0]
+ raise WrongTypeConfigError(argument, expected_type)
+ else:
+ raise Exception(message)
+
+
+""" Argument description:
+ - query_expressions: query expressions
+ - query_filter: query filter
+ - source_type: source type
+ - filename: filename
+ - dataset_type: dataset type
+ - time_embed: freq for time features encoding
+ - input_len: input sequence length
+ - pred_len: prediction sequence length
+ - input_vars: number of input variables
+ - output_vars: number of output variables
+"""
+_data_config_parser = _ConfigParser()
+_data_config_parser.add_argument('--source_type', type=str, required=True)
+_data_config_parser.add_argument('--dataset_type', type=DatasetType,
required=True)
+_data_config_parser.add_argument('--filename', type=str, default='')
+_data_config_parser.add_argument('--query_expressions', type=str, nargs='*',
default=[])
+_data_config_parser.add_argument('--query_filter', type=str, default='')
+_data_config_parser.add_argument('--time_embed', type=str, default='h')
+_data_config_parser.add_argument('--input_len', type=int, default=96)
+_data_config_parser.add_argument('--pred_len', type=int, default=96)
+_data_config_parser.add_argument('--input_vars', type=int, default=1)
+_data_config_parser.add_argument('--output_vars', type=int, default=1)
+
+""" Argument description:
+ - model_name: model name
+ - input_len: input sequence length
+ - pred_len: prediction sequence length
+ - input_vars: number of input variables
+ - output_vars: number of output variables
+ - task_type: task type, options:[M, S, MS];
+ M:multivariate predict multivariate,
+ S:univariate predict univariate,
+ MS:multivariate predict univariate'
+ - kernel_size: kernel size
+ - block_type: block type
+ - d_model: dimension of feature in model
+ - inner_layers: number of inner layers
+ - outer_layers: number of outer layers
+"""
+_model_config_parser = _ConfigParser()
+_model_config_parser.add_argument('--model_name', type=str, required=True)
+_model_config_parser.add_argument('--input_len', type=int, default=96)
+_model_config_parser.add_argument('--pred_len', type=int, default=96)
+_model_config_parser.add_argument('--input_vars', type=int, default=1)
+_model_config_parser.add_argument('--output_vars', type=int, default=1)
+_model_config_parser.add_argument('--forecast_task_type',
type=ForecastTaskType, default=ForecastTaskType.ENDOGENOUS,
+ choices=list(ForecastTaskType))
+_model_config_parser.add_argument('--kernel_size', type=int, default=25)
+_model_config_parser.add_argument('--block_type', type=str, default='generic')
+_model_config_parser.add_argument('--d_model', type=int, default=128)
+_model_config_parser.add_argument('--inner_layers', type=int, default=4)
+_model_config_parser.add_argument('--outer_layers', type=int, default=4)
+
+""" Argument description:
+ - model_id: model id
+ - tuning: whether to tune hyperparameters
+ - task_type: task type, options:[M, S, MS]; M:multivariate predict
multivariate, S:univariate predict univariate,
+ MS:multivariate predict univariate'
+ - task_class: task class
+ - input_len: input sequence length
+ - pred_len: prediction sequence length
+ - input_vars: number of input variables
+ - output_vars: number of output variables
+ - learning_rate: learning rate
+ - batch_size: batch size
+ - num_workers: number of workers
+ - epochs: number of epochs
+ - use_gpu: whether to use gpu
+ - use_multi_gpu: whether to use multi-gpu
+ - devices: devices to use
+ - metric_names: metric to use
+"""
+_task_config_parser = _ConfigParser()
+_task_config_parser.add_argument('--task_class', type=str, required=True)
+_task_config_parser.add_argument('--model_id', type=str, required=True)
+_task_config_parser.add_argument('--tuning', type=bool, default=False)
+_task_config_parser.add_argument('--forecast_task_type',
type=ForecastTaskType, default=ForecastTaskType.ENDOGENOUS,
+ choices=list(ForecastTaskType))
+_task_config_parser.add_argument('--input_len', type=int, default=96)
+_task_config_parser.add_argument('--pred_len', type=int, default=96)
+_task_config_parser.add_argument('--input_vars', type=int, default=1)
+_task_config_parser.add_argument('--output_vars', type=int, default=1)
+_task_config_parser.add_argument('--learning_rate', type=float, default=0.0001)
+_task_config_parser.add_argument('--batch_size', type=int, default=32)
+_task_config_parser.add_argument('--num_workers', type=int, default=0)
+_task_config_parser.add_argument('--epochs', type=int, default=10)
+_task_config_parser.add_argument('--use_gpu', type=bool, default=False)
+_task_config_parser.add_argument('--gpu', type=int, default=0)
+_task_config_parser.add_argument('--use_multi_gpu', type=bool, default=False)
+_task_config_parser.add_argument('--devices', type=int, nargs='+', default=[0])
+_task_config_parser.add_argument('--metric_names', type=str, nargs='+',
default=['MSE', 'MAE'])
+
+
+def parse_training_request(req: TCreateTrainingTaskReq):
+ """
+ Parse TCreateTrainingTaskReq with given yaml template
+ Args:
+ req: TCreateTrainingTaskReq
+ Returns:
+ data_config: configurations related to data
+ model_config: configurations related to model
+ task_config: configurations related to task
+ """
+ config = req.modelConfigs
+ config.update(model_id=req.modelId)
+ config.update(tuning=req.isAuto)
+ config.update(query_expressions=req.queryExpressions)
+ config.update(query_filter=req.queryFilter)
+
+ data_config = _data_config_parser.parse_configs(config)
+ model_config = _model_config_parser.parse_configs(config)
+ task_config = _task_config_parser.parse_configs(config)
+ return data_config, model_config, task_config
diff --git a/mlnode/iotdb/mlnode/serde.py b/mlnode/iotdb/mlnode/serde.py
index 26860faf38..5e98636e2e 100644
--- a/mlnode/iotdb/mlnode/serde.py
+++ b/mlnode/iotdb/mlnode/serde.py
@@ -15,10 +15,38 @@
# specific language governing permissions and limitations
# under the License.
#
+from enum import Enum
+
import numpy as np
import pandas as pd
-from iotdb.utils.IoTDBConstants import TSDataType
+
+class TSDataType(Enum):
+ BOOLEAN = 0
+ INT32 = 1
+ INT64 = 2
+ FLOAT = 3
+ DOUBLE = 4
+ TEXT = 5
+
+ # this method is implemented to avoid the issue reported by:
+ # https://bugs.python.org/issue30545
+ def __eq__(self, other) -> bool:
+ return self.value == other.value
+
+ def __hash__(self):
+ return self.value
+
+ def np_dtype(self):
+ return {
+ TSDataType.BOOLEAN: np.dtype(">?"),
+ TSDataType.FLOAT: np.dtype(">f4"),
+ TSDataType.DOUBLE: np.dtype(">f8"),
+ TSDataType.INT32: np.dtype(">i4"),
+ TSDataType.INT64: np.dtype(">i8"),
+ TSDataType.TEXT: np.dtype("str"),
+ }[self]
+
TIMESTAMP_STR = "Time"
START_INDEX = 2
diff --git a/mlnode/iotdb/mlnode/util.py b/mlnode/iotdb/mlnode/util.py
index c15e84da11..d67ba1290d 100644
--- a/mlnode/iotdb/mlnode/util.py
+++ b/mlnode/iotdb/mlnode/util.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
#
+
from iotdb.mlnode.constant import TSStatusCode
from iotdb.mlnode.exception import BadNodeUrlError
from iotdb.mlnode.log import logger
@@ -23,13 +24,10 @@ from iotdb.thrift.common.ttypes import TEndPoint, TSStatus
def parse_endpoint_url(endpoint_url: str) -> TEndPoint:
""" Parse TEndPoint from a given endpoint url.
-
Args:
endpoint_url: an endpoint url, format: ip:port
-
Returns:
TEndPoint
-
Raises:
BadNodeUrlError
"""
diff --git a/mlnode/test/test_parse_training_request.py b/mlnode/test/test_parse_training_request.py
new file mode 100644
index 0000000000..ec318ae60d
--- /dev/null
+++ b/mlnode/test/test_parse_training_request.py
@@ -0,0 +1,136 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from iotdb.mlnode.parser import parse_training_request
+from iotdb.thrift.mlnode.ttypes import TCreateTrainingTaskReq
+
+
+def test_parse_training_request():
+ model_id = 'mid_etth1_dlinear_default'
+ is_auto = False
+ model_configs = {
+ 'task_class': 'forecast_training_task',
+ 'source_type': 'thrift',
+ 'dataset_type': 'window',
+ 'filename': 'ETTh1.csv',
+ 'time_embed': 'h',
+ 'input_len': 96,
+ 'pred_len': 96,
+ 'model_name': 'dlinear',
+ 'input_vars': 7,
+ 'output_vars': 7,
+ 'forecast_type': 'm',
+ 'kernel_size': 25,
+ 'learning_rate': 1e-3,
+ 'batch_size': 32,
+ 'num_workers': 0,
+ 'epochs': 10,
+ 'metric_names': ['MSE', 'MAE']
+ }
+    query_expressions = ['root.eg.etth1.**', 'root.eg.etth1.**', 'root.eg.etth1.**']
+ query_filter = '0,1501516800000'
+ req = TCreateTrainingTaskReq(
+ modelId=str(model_id),
+ isAuto=is_auto,
+ modelConfigs={k: str(v) for k, v in model_configs.items()},
+ queryExpressions=[str(query) for query in query_expressions],
+ queryFilter=str(query_filter),
+ )
+ data_config, model_config, task_config = parse_training_request(req)
+ for config in model_configs:
+ if config in data_config:
+ assert data_config[config] == model_configs[config]
+ if config in model_config:
+ assert model_config[config] == model_configs[config]
+ if config in task_config:
+ assert task_config[config] == model_configs[config]
+
+
+def test_missing_argument():
+ # missing model_name
+ model_id = 'mid_etth1_dlinear_default'
+ is_auto = False
+ model_configs = {
+ 'task_class': 'forecast_training_task',
+ 'source_type': 'thrift',
+ 'dataset_type': 'window',
+ 'filename': 'ETTh1.csv',
+ 'time_embed': 'h',
+ 'input_len': 96,
+ 'pred_len': 96,
+ 'input_vars': 7,
+ 'output_vars': 7,
+ 'forecast_type': 'm',
+ 'kernel_size': 25,
+ 'learning_rate': 1e-3,
+ 'batch_size': 32,
+ 'num_workers': 0,
+ 'epochs': 10,
+ 'metric_names': ['MSE', 'MAE']
+ }
+    query_expressions = ['root.eg.etth1.**', 'root.eg.etth1.**', 'root.eg.etth1.**']
+ query_filter = '0,1501516800000'
+ req = TCreateTrainingTaskReq(
+ modelId=str(model_id),
+ isAuto=is_auto,
+ modelConfigs={k: str(v) for k, v in model_configs.items()},
+ queryExpressions=[str(query) for query in query_expressions],
+ queryFilter=str(query_filter),
+ )
+ try:
+ data_config, model_config, task_config = parse_training_request(req)
+ except Exception as e:
+ assert e.message == 'Missing config: (model_name)'
+
+
+def test_wrong_argument_type():
+ model_id = 'mid_etth1_dlinear_default'
+ is_auto = False
+ model_configs = {
+ 'task_class': 'forecast_training_task',
+ 'source_type': 'thrift',
+ 'dataset_type': 'window',
+ 'filename': 'ETTh1.csv',
+ 'time_embed': 'h',
+ 'input_len': 96.7,
+ 'pred_len': 96,
+ 'model_name': 'dlinear',
+ 'input_vars': 7,
+ 'output_vars': 7,
+ 'forecast_type': 'm',
+ 'kernel_size': 25,
+ 'learning_rate': 1e-3,
+ 'batch_size': 32,
+ 'num_workers': 0,
+ 'epochs': 10,
+ 'metric_names': ['MSE', 'MAE']
+ }
+    query_expressions = ['root.eg.etth1.**', 'root.eg.etth1.**', 'root.eg.etth1.**']
+ query_filter = '0,1501516800000'
+ req = TCreateTrainingTaskReq(
+ modelId=str(model_id),
+ isAuto=is_auto,
+ modelConfigs={k: str(v) for k, v in model_configs.items()},
+ queryExpressions=[str(query) for query in query_expressions],
+ queryFilter=str(query_filter),
+ )
+ try:
+ data_config, model_config, task_config = parse_training_request(req)
+ except Exception as e:
+ message = "Wrong type for config: ({})".format('input_len')
+ message += ", expected: ({})".format('int')
+ assert e.message == message