This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new 5c4c297e63d [AINode] Modify dataset module for table-model training (#15816)
5c4c297e63d is described below

commit 5c4c297e63d794b50a946b9fc9a4e07acc6a49f6
Author: YangCaiyin <[email protected]>
AuthorDate: Thu Jun 26 20:26:46 2025 +0800

    [AINode] Modify dataset module for table-model training (#15816)
---
 iotdb-core/ainode/ainode/core/ingress/iotdb.py     | 152 ++++++++++-----------
 .../iotdb/confignode/manager/ConfigManager.java    |  35 +----
 .../execution/config/TableConfigTaskVisitor.java   |  17 +--
 .../execution/config/TreeConfigTaskVisitor.java    |   1 -
 .../config/executor/ClusterConfigTaskExecutor.java |   8 +-
 .../config/executor/IConfigTaskExecutor.java       |   4 +-
 .../config/metadata/ai/CreateTrainingTask.java     |  26 ++--
 .../plan/relational/sql/ast/CreateTraining.java    |  58 ++------
 .../plan/relational/sql/parser/AstBuilder.java     |  48 +------
 .../db/relational/grammar/sql/RelationalSql.g4     |  24 +---
 .../src/main/thrift/confignode.thrift              |  10 +-
 11 files changed, 112 insertions(+), 271 deletions(-)

diff --git a/iotdb-core/ainode/ainode/core/ingress/iotdb.py b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
index 175e2eb2f65..399036bca1d 100644
--- a/iotdb-core/ainode/ainode/core/ingress/iotdb.py
+++ b/iotdb-core/ainode/ainode/core/ingress/iotdb.py
@@ -216,32 +216,24 @@ class IoTDBTreeModelDataset(BasicDatabaseForecastDataset):
 
 class IoTDBTableModelDataset(BasicDatabaseForecastDataset):
 
+    DEFAULT_TAG = "__DEFAULT_TAG__"
+
     def __init__(
         self,
-        input_len: int,
-        out_len: int,
+        model_id: str,
+        seq_len: int,
+        input_token_len: int,
+        output_token_len: int,
         data_schema_list: list,
         ip: str = "127.0.0.1",
         port: int = 6667,
         username: str = "root",
         password: str = "root",
         time_zone: str = "UTC+8",
-        start_split: float = 0,
-        end_split: float = 1,
+        use_rate: float = 1.0,
+        offset_rate: float = 0.0,
     ):
-        super().__init__(ip, port, input_len, out_len)
-        if end_split < start_split:
-            raise ValueError("end_split must be greater than start_split")
-
-        # database , table
-        self.SELECT_SERIES_FORMAT_SQL = "select distinct item_id from %s"
-        self.COUNT_SERIES_LENGTH_SQL = (
-            "select count(value) from %s where item_id = '%s'"
-        )
-        self.FETCH_SERIES_SQL = (
-            "select value from %s where item_id = '%s' offset %s limit %s"
-        )
-        self.SERIES_NAME = "%s.%s"
+        super().__init__(ip, port, seq_len, input_token_len, output_token_len)
 
         table_session_config = TableSessionConfig(
             node_urls=[f"{ip}:{port}"],
@@ -249,87 +241,95 @@ class IoTDBTableModelDataset(BasicDatabaseForecastDataset):
             password=password,
             time_zone=time_zone,
         )
-
         self.session = TableSession(table_session_config)
-        self.context_length = self.input_len + self.output_len
-        self.token_num = self.context_length // self.input_len
-        self._fetch_schema(data_schema_list)
+        self.use_rate = use_rate
+        self.offset_rate = offset_rate
 
-        self.start_index = int(self.total_count * start_split)
-        self.end_index = self.total_count * end_split
+        # used for caching data
+        self._fetch_schema(data_schema_list)
 
     def _fetch_schema(self, data_schema_list: list):
-        series_to_length = {}
-        for data_schema in data_schema_list:
-            series_list = []
-            with self.session.execute_query_statement(
-                self.SELECT_SERIES_FORMAT_SQL % data_schema
-            ) as show_devices_result:
-                while show_devices_result.has_next():
+        series_map = {}
+        for target_sql in data_schema_list:
+            target_sql = target_sql.schemaName
+            with self.session.execute_query_statement(target_sql) as 
target_data:
+                while target_data.has_next():
+                    cur_data = target_data.next()
+                    # TODO: currently, we only support the following simple 
table form
+                    time_col, value_col, tag_col = -1, -1, -1
+                    for i, field in enumerate(cur_data.get_fields()):
+                        if field.get_data_type() == TSDataType.TIMESTAMP:
+                            time_col = i
+                        elif field.get_data_type() in (
+                            TSDataType.INT32,
+                            TSDataType.INT64,
+                            TSDataType.FLOAT,
+                            TSDataType.DOUBLE,
+                        ):
+                            value_col = i
+                        elif field.get_data_type() == TSDataType.TEXT:
+                            tag_col = i
+                    if time_col == -1 or value_col == -1:
+                        raise ValueError(
+                            "The training cannot start due to invalid data 
schema"
+                        )
+                    if tag_col == -1:
+                        tag = self.DEFAULT_TAG
+                    else:
+                        tag = cur_data.get_fields()[tag_col].get_string_value()
+                    if tag not in series_map:
+                        series_map[tag] = []
+                    series_list = series_map[tag]
                     series_list.append(
-                        
get_field_value(show_devices_result.next().get_fields()[0])
+                        get_field_value(cur_data.get_fields()[value_col])
                     )
 
-            for series in series_list:
-                with self.session.execute_query_statement(
-                    self.COUNT_SERIES_LENGTH_SQL % (data_schema.schemaName, 
series)
-                ) as count_series_result:
-                    length = 
get_field_value(count_series_result.next().get_fields()[0])
-                    series_to_length[
-                        self.SERIES_NAME % (data_schema.schemaName, series)
-                    ] = length
-
-        sorted_series = sorted(series_to_length.items(), key=lambda x: x[1])
-        sorted_series_with_prefix_sum = []
+        # TODO: Unify the following implementation
+        # structure: [(series_name, the number of windows of this series, 
prefix sum of window number, window start offset, series_data), ...]
+        series_with_prefix_sum = []
         window_sum = 0
-        for seq_name, seq_length in sorted_series:
-            window_count = seq_length - self.context_length + 1
-            if window_count < 0:
+        for seq_name, seq_values in series_map.items():
+            # calculate and sum the number of training data windows for each 
time series
+            window_count = len(seq_values) - self.seq_len - 
self.output_token_len + 1
+            if window_count <= 1:
                 continue
-            window_sum += window_count
-            sorted_series_with_prefix_sum.append((seq_name, window_count, 
window_sum))
+            use_window_count = int(window_count * self.use_rate)
+            window_sum += use_window_count
+            series_with_prefix_sum.append(
+                (
+                    seq_name,
+                    use_window_count,
+                    window_sum,
+                    int(window_count * self.offset_rate),
+                    seq_values,
+                )
+            )
 
-        self.total_count = window_sum
-        self.sorted_series = sorted_series_with_prefix_sum
+        self.total_window_count = window_sum
+        self.series_with_prefix_sum = series_with_prefix_sum
 
     def __getitem__(self, index):
         window_index = index
-
+        # locate the series to be queried
         series_index = 0
-
-        while self.sorted_series[series_index][2] < window_index:
+        while self.series_with_prefix_sum[series_index][1] < window_index:
             series_index += 1
-
+        # locate the window of this series to be queried
         if series_index != 0:
-            window_index -= self.sorted_series[series_index - 1][2]
-
-        if window_index != 0:
-            window_index -= 1
-        series = self.sorted_series[series_index][0]
-        schema = series.split(".")
-
-        result = []
-        sql = self.FETCH_SERIES_SQL % (
-            schema[0:1],
-            schema[2],
-            window_index,
-            self.context_length,
-        )
-        try:
-            with self.session.execute_query_statement(sql) as query_result:
-                while query_result.has_next():
-                    
result.append(get_field_value(query_result.next().get_fields()[0]))
-        except Exception as e:
-            logger.error("Executing sql: {} with exception: {}".format(sql, e))
+            window_index -= self.series_with_prefix_sum[series_index - 1][2]
+        window_index += self.series_with_prefix_sum[series_index][3]
+        result = self.series_with_prefix_sum[series_index][4][
+            window_index : window_index + self.seq_len + self.output_token_len
+        ]
         result = torch.tensor(result)
         return (
-            result[0 : self.input_len],
-            result[-self.output_len :],
+            result[0 : self.seq_len],
+            result[self.input_token_len : self.seq_len + 
self.output_token_len],
             np.ones(self.token_num, dtype=np.int32),
         )
 
     def __len__(self):
-        return self.end_index - self.start_index
+        return self.total_window_count
 
 
 def register_dataset(key: str, dataset: Dataset):
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
index 0c8f87c12b5..a6462a72a2a 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
@@ -167,7 +167,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRegisterReq;
 import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRestartReq;
 import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRestartResp;
 import org.apache.iotdb.confignode.rpc.thrift.TDataPartitionTableResp;
-import org.apache.iotdb.confignode.rpc.thrift.TDataSchemaForTable;
 import org.apache.iotdb.confignode.rpc.thrift.TDatabaseSchema;
 import org.apache.iotdb.confignode.rpc.thrift.TDeactivateSchemaTemplateReq;
 import org.apache.iotdb.confignode.rpc.thrift.TDeleteDatabasesReq;
@@ -248,7 +247,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TSpaceQuotaResp;
 import org.apache.iotdb.confignode.rpc.thrift.TStartPipeReq;
 import org.apache.iotdb.confignode.rpc.thrift.TStopPipeReq;
 import org.apache.iotdb.confignode.rpc.thrift.TSubscribeReq;
-import org.apache.iotdb.confignode.rpc.thrift.TTableInfo;
 import org.apache.iotdb.confignode.rpc.thrift.TThrottleQuotaResp;
 import org.apache.iotdb.confignode.rpc.thrift.TTimeSlotList;
 import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq;
@@ -2641,10 +2639,6 @@ public class ConfigManager implements IManager {
 
   private List<IDataSchema> fetchSchemaForTreeModel(TCreateTrainingReq req) {
     List<IDataSchema> dataSchemaList = new ArrayList<>();
-    if (req.useAllData) {
-      dataSchemaList.add(new IDataSchema("root.**"));
-      return dataSchemaList;
-    }
     for (int i = 0; i < req.getDataSchemaForTree().getPathSize(); i++) {
       IDataSchema dataSchema = new 
IDataSchema(req.getDataSchemaForTree().getPath().get(i));
       dataSchema.setTimeRange(req.getTimeRanges().get(i));
@@ -2654,28 +2648,7 @@ public class ConfigManager implements IManager {
   }
 
   private List<IDataSchema> fetchSchemaForTableModel(TCreateTrainingReq req) {
-    List<IDataSchema> dataSchemaList = new ArrayList<>();
-    TDataSchemaForTable dataSchemaForTable = req.getDataSchemaForTable();
-    if (req.useAllData || !dataSchemaForTable.getDatabaseList().isEmpty()) {
-      List<String> databaseNameList = new ArrayList<>();
-      if (req.useAllData) {
-        TShowDatabaseResp resp = showDatabase(new TGetDatabaseReq());
-        databaseNameList.addAll(resp.getDatabaseInfoMap().keySet());
-      } else {
-        databaseNameList.addAll(dataSchemaForTable.getDatabaseList());
-      }
-
-      for (String database : databaseNameList) {
-        TShowTableResp resp = showTables(database, false);
-        for (TTableInfo tableInfo : resp.getTableInfoList()) {
-          dataSchemaList.add(new IDataSchema(database + DOT + 
tableInfo.tableName));
-        }
-      }
-    }
-    for (String tableName : dataSchemaForTable.getTableList()) {
-      dataSchemaList.add(new IDataSchema(tableName));
-    }
-    return dataSchemaList;
+    return Collections.singletonList(new 
IDataSchema(req.getDataSchemaForTable().getTargetSql()));
   }
 
   public TSStatus createTraining(TCreateTrainingReq req) {
@@ -2687,11 +2660,11 @@ public class ConfigManager implements IManager {
 
     TTrainingReq trainingReq = new TTrainingReq();
     trainingReq.setModelId(req.getModelId());
-    trainingReq.setModelType("sundial");
-    if (req.existingModelId != null) {
+    trainingReq.setModelType(req.getModelType());
+    if (req.isSetExistingModelId()) {
       trainingReq.setExistingModelId(req.getExistingModelId());
     }
-    if (!req.parameters.isEmpty()) {
+    if (req.isSetParameters() && !req.getParameters().isEmpty()) {
       trainingReq.setParameters(req.getParameters());
     }
 
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
index 95d9ecaa278..10b75c424cc 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java
@@ -211,7 +211,6 @@ import org.apache.tsfile.enums.TSDataType;
 import org.apache.tsfile.utils.Binary;
 import org.apache.tsfile.utils.Pair;
 
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
@@ -1359,26 +1358,12 @@ public class TableConfigTaskVisitor extends AstVisitor<IConfigTask, MPPQueryCont
   protected IConfigTask visitCreateTraining(CreateTraining node, 
MPPQueryContext context) {
     context.setQueryType(QueryType.WRITE);
 
-    String curDatabase = clientSession.getDatabaseName();
-    List<String> tableList = new ArrayList<>();
-    for (QualifiedName tableName : node.getTargetTables()) {
-      List<String> parts = tableName.getParts();
-      if (parts.size() == 1) {
-        tableList.add(curDatabase + "." + parts.get(0));
-      } else {
-        tableList.add(parts.get(1) + "." + parts.get(0));
-      }
-    }
-
     return new CreateTrainingTask(
         node.getModelId(),
         node.getModelType(),
         node.getParameters(),
-        node.isUseAllData(),
-        node.getTargetTimeRanges(),
         node.getExistingModelId(),
-        tableList,
-        node.getTargetDbs());
+        node.getTargetSql());
   }
 
   @Override
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
index 282839ec560..5b23baa41b3 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java
@@ -810,7 +810,6 @@ public class TreeConfigTaskVisitor extends StatementVisitor<IConfigTask, MPPQuer
         createTrainingStatement.getModelId(),
         createTrainingStatement.getModelType(),
         createTrainingStatement.getParameters(),
-        false,
         createTrainingStatement.getTargetTimeRanges(),
         createTrainingStatement.getExistingModelId(),
         targetPathPatterns);
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
index 39c2bda7298..f45c3ce187b 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java
@@ -3334,11 +3334,9 @@ public class ClusterConfigTaskExecutor implements IConfigTaskExecutor {
       String modelType,
       boolean isTableModel,
       Map<String, String> parameters,
-      boolean useAllData,
       List<List<Long>> timeRanges,
       String existingModelId,
-      @Nullable List<String> tableList,
-      @Nullable List<String> databaseList,
+      @Nullable String targetSql,
       @Nullable List<String> pathList) {
     final SettableFuture<ConfigTaskResult> future = SettableFuture.create();
     try (final ConfigNodeClient client =
@@ -3347,8 +3345,7 @@ public class ClusterConfigTaskExecutor implements IConfigTaskExecutor {
 
       if (isTableModel) {
         TDataSchemaForTable dataSchemaForTable = new TDataSchemaForTable();
-        dataSchemaForTable.setTableList(tableList);
-        dataSchemaForTable.setDatabaseList(databaseList);
+        dataSchemaForTable.setTargetSql(targetSql);
         req.setDataSchemaForTable(dataSchemaForTable);
       } else {
         TDataSchemaForTree dataSchemaForTree = new TDataSchemaForTree();
@@ -3356,7 +3353,6 @@ public class ClusterConfigTaskExecutor implements IConfigTaskExecutor {
         req.setDataSchemaForTree(dataSchemaForTree);
       }
       req.setParameters(parameters);
-      req.setUseAllData(useAllData);
       req.setTimeRanges(timeRanges);
       req.setExistingModelId(existingModelId);
       final TSStatus executionStatus = client.createTraining(req);
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
index 8120dca2b92..cb49b444a52 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java
@@ -427,10 +427,8 @@ public interface IConfigTaskExecutor {
       String modelType,
       boolean isTableModel,
       Map<String, String> parameters,
-      boolean useAllData,
       List<List<Long>> timeRanges,
       String existingModelId,
-      @Nullable List<String> tableList,
-      @Nullable List<String> databaseList,
+      @Nullable String targetSql,
       @Nullable List<String> pathList);
 }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
index 91d3258dba1..821c01e27a4 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java
@@ -34,49 +34,41 @@ public class CreateTrainingTask implements IConfigTask {
   private final String modelType;
   private final boolean isTableModel;
   private final Map<String, String> parameters;
-  private final boolean useAllData;
-  private final List<List<Long>> timeRanges;
+
   private final String existingModelId;
 
   // Data schema for table model
-  private List<String> targetTables;
-  private List<String> targetDbs;
+  private String targetSql = null;
   // Data schema for tree model
   private List<String> targetPaths;
+  private List<List<Long>> timeRanges;
 
+  // For table model
   public CreateTrainingTask(
       String modelId,
       String modelType,
       Map<String, String> parameters,
-      boolean useAllData,
-      List<List<Long>> timeRanges,
       String existingModelId,
-      List<String> targetTables,
-      List<String> targetDbs) {
+      String targetSql) {
     this.modelId = modelId;
     this.modelType = modelType;
     this.parameters = parameters;
-    this.useAllData = useAllData;
-    this.timeRanges = timeRanges;
     this.existingModelId = existingModelId;
-
+    this.targetSql = targetSql;
     this.isTableModel = true;
-    this.targetTables = targetTables;
-    this.targetDbs = targetDbs;
   }
 
+  // For tree model
   public CreateTrainingTask(
       String modelId,
       String modelType,
       Map<String, String> parameters,
-      boolean useAllData,
       List<List<Long>> timeRanges,
       String existingModelId,
       List<String> targetPaths) {
     this.modelId = modelId;
     this.modelType = modelType;
     this.parameters = parameters;
-    this.useAllData = useAllData;
     this.timeRanges = timeRanges;
     this.existingModelId = existingModelId;
 
@@ -92,11 +84,9 @@ public class CreateTrainingTask implements IConfigTask {
         modelType,
         isTableModel,
         parameters,
-        useAllData,
         timeRanges,
         existingModelId,
-        targetTables,
-        targetDbs,
+        targetSql,
         targetPaths);
   }
 }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java
index 3c978ccb5c6..1e621b7352e 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/CreateTraining.java
@@ -27,20 +27,16 @@ public class CreateTraining extends Statement {
 
   private final String modelId;
   private final String modelType;
+  private final String targetSql;
 
   private Map<String, String> parameters;
   private String existingModelId = null;
 
-  private List<QualifiedName> targetTables;
-  private List<String> targetDbs;
-
-  private List<List<Long>> targetTimeRanges;
-  private boolean useAllData = false;
-
-  public CreateTraining(String modelId, String modelType) {
+  public CreateTraining(String modelId, String modelType, String targetSql) {
     super(null);
     this.modelId = modelId;
     this.modelType = modelType;
+    this.targetSql = targetSql;
   }
 
   @Override
@@ -56,26 +52,6 @@ public class CreateTraining extends Statement {
     this.existingModelId = existingModelId;
   }
 
-  public void setTargetDbs(List<String> targetDbs) {
-    this.targetDbs = targetDbs;
-  }
-
-  public void setTargetTables(List<QualifiedName> targetTables) {
-    this.targetTables = targetTables;
-  }
-
-  public void setUseAllData(boolean useAllData) {
-    this.useAllData = useAllData;
-  }
-
-  public List<String> getTargetDbs() {
-    return targetDbs;
-  }
-
-  public List<QualifiedName> getTargetTables() {
-    return targetTables;
-  }
-
   public String getModelId() {
     return modelId;
   }
@@ -92,16 +68,8 @@ public class CreateTraining extends Statement {
     return existingModelId;
   }
 
-  public boolean isUseAllData() {
-    return useAllData;
-  }
-
-  public void setTargetTimeRanges(List<List<Long>> targetTimeRanges) {
-    this.targetTimeRanges = targetTimeRanges;
-  }
-
-  public List<List<Long>> getTargetTimeRanges() {
-    return targetTimeRanges;
+  public String getTargetSql() {
+    return targetSql;
   }
 
   @Override
@@ -111,8 +79,7 @@ public class CreateTraining extends Statement {
 
   @Override
   public int hashCode() {
-    return Objects.hash(
-        modelId, modelType, existingModelId, parameters, targetTimeRanges, 
useAllData);
+    return Objects.hash(modelId, modelType, targetSql, existingModelId, 
parameters);
   }
 
   @Override
@@ -125,8 +92,7 @@ public class CreateTraining extends Statement {
         && modelType.equals(createTraining.modelType)
         && Objects.equals(existingModelId, createTraining.existingModelId)
         && Objects.equals(parameters, createTraining.parameters)
-        && Objects.equals(targetTimeRanges, createTraining.targetTimeRanges)
-        && useAllData == createTraining.useAllData;
+        && Objects.equals(targetSql, createTraining.targetSql);
   }
 
   @Override
@@ -143,14 +109,8 @@ public class CreateTraining extends Statement {
         + ", existingModelId='"
         + existingModelId
         + '\''
-        + ", targetTables="
-        + targetTables
-        + ", targetDbs="
-        + targetDbs
-        + ", targetTimeRanges="
-        + targetTimeRanges
-        + ", useAllData="
-        + useAllData
+        + ", targetSql='"
+        + targetSql
         + '}';
   }
 }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java
index 268379e6b12..85085baea5e 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java
@@ -3554,20 +3554,17 @@ public class AstBuilder extends RelationalSqlBaseVisitor<Node> {
     }
   }
 
-  private List<Long> parseTimePair(RelationalSqlParser.TimeRangeContext 
timeRangeContext) {
-    long currentTime = CommonDateTimeUtils.currentTime();
-    List<Long> timeRange = new ArrayList<>();
-    timeRange.add(parseTimeValue(timeRangeContext.timeValue(0), currentTime));
-    timeRange.add(parseTimeValue(timeRangeContext.timeValue(1), currentTime));
-    return timeRange;
-  }
-
   @Override
   public Node 
visitCreateModelStatement(RelationalSqlParser.CreateModelStatementContext ctx) {
     String modelId = ctx.modelId.getText();
     validateModelName(modelId);
     String modelType = ctx.modelType.getText();
-    CreateTraining createTraining = new CreateTraining(modelId, modelType);
+
+    if (ctx.targetData == null) {
+      throw new SemanticException("Target data in sql should be set in CREATE 
MODEL");
+    }
+    String targetData = ((StringLiteral) visit(ctx.targetData)).getValue();
+    CreateTraining createTraining = new CreateTraining(modelId, modelType, 
targetData);
     if (ctx.HYPERPARAMETERS() != null) {
       Map<String, String> parameters = new HashMap<>();
       for (RelationalSqlParser.HparamPairContext hparamPairContext : 
ctx.hparamPair()) {
@@ -3581,39 +3578,6 @@ public class AstBuilder extends RelationalSqlBaseVisitor<Node> {
       createTraining.setExistingModelId(ctx.existingModelId.getText());
     }
 
-    List<List<Long>> dbTimeRange = new ArrayList<>();
-    List<List<Long>> tableTimeRange = new ArrayList<>();
-    if (ctx.trainingData().ALL() != null) {
-      createTraining.setUseAllData(true);
-    } else {
-      List<QualifiedName> targetTables = new ArrayList<>();
-      List<String> targetDbs = new ArrayList<>();
-      for (RelationalSqlParser.DataElementContext dataElementContext :
-          ctx.trainingData().dataElement()) {
-        if (dataElementContext.databaseElement() != null) {
-          targetDbs.add(
-              ((Identifier) 
visit(dataElementContext.databaseElement().database)).getValue());
-          if (dataElementContext.databaseElement().timeRange() != null) {
-            
dbTimeRange.add(parseTimePair(dataElementContext.databaseElement().timeRange()));
-          }
-        } else {
-          
targetTables.add(getQualifiedName(dataElementContext.tableElement().qualifiedName()));
-          if (dataElementContext.tableElement().timeRange() != null) {
-            
tableTimeRange.add(parseTimePair(dataElementContext.tableElement().timeRange()));
-          }
-        }
-      }
-
-      if (targetDbs.isEmpty() && targetTables.isEmpty()) {
-        throw new IllegalArgumentException(
-            "No training data is supported for model, please indicate database 
or table");
-      }
-      createTraining.setTargetDbs(targetDbs);
-      createTraining.setTargetTables(targetTables);
-
-      dbTimeRange.addAll(tableTimeRange);
-      createTraining.setTargetTimeRanges(dbTimeRange);
-    }
     return createTraining;
   }
 
diff --git a/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4 b/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4
index 56b2687b064..591592cfdbb 100644
--- a/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4
+++ b/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4
@@ -782,29 +782,7 @@ revokeGrantOpt
 // ------------------------------------------- AI 
---------------------------------------------------------
 
 createModelStatement
-    : CREATE MODEL modelType=identifier modelId=identifier (WITH 
HYPERPARAMETERS '(' hparamPair (',' hparamPair)* ')')? (FROM MODEL 
existingModelId=identifier)? ON DATASET '(' trainingData ')'
-    ;
-
-trainingData
-    : ALL
-    | dataElement(',' dataElement)*
-    ;
-
-dataElement
-    : databaseElement
-    | tableElement
-    ;
-
-databaseElement
-    : DATABASE database=identifier ('(' timeRange ')')?
-    ;
-
-tableElement
-    : TABLE tableName=qualifiedName ('(' timeRange ')')?
-    ;
-
-timeRange
-    : '[' startTime=timeValue ',' endTime=timeValue ']'
+    : CREATE MODEL modelType=identifier modelId=identifier (WITH 
HYPERPARAMETERS '(' hparamPair (',' hparamPair)* ')')? (FROM MODEL 
existingModelId=identifier)? ON DATASET '(' targetData=string ')'
     ;
 
 hparamPair
diff --git a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
index f767f35d67c..e5d599ab632 100644
--- a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
+++ b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift
@@ -1086,8 +1086,7 @@ struct TUpdateModelInfoReq {
 }
 
 struct TDataSchemaForTable{
-    1: required list<string> databaseList
-    2: required list<string> tableList
+    1: required string targetSql
 }
 
 struct TDataSchemaForTree{
@@ -1100,10 +1099,9 @@ struct TCreateTrainingReq {
     3: required bool isTableModel
     4: optional TDataSchemaForTable dataSchemaForTable
     5: optional TDataSchemaForTree dataSchemaForTree
-    6: optional bool useAllData
-    7: optional map<string, string> parameters
-    8: optional string existingModelId
-    9: optional list<list<i64>> timeRanges
+    6: optional map<string, string> parameters
+    7: optional string existingModelId
+    8: optional list<list<i64>> timeRanges
 }
 
 // ====================================================

Reply via email to