This is an automated email from the ASF dual-hosted git repository.
jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 74a75fd0568 Make forecast tvf's table relation default order by
timecol desc
74a75fd0568 is described below
commit 74a75fd056858e42544259e45385585d701f48fc
Author: Jackie Tien <[email protected]>
AuthorDate: Fri May 16 14:01:03 2025 +0800
Make forecast tvf's table relation default order by timecol desc
---
.../relational/analyzer/StatementAnalyzer.java | 100 ++++++++++++++++++++-
.../function/tvf/ForecastTableFunction.java | 77 +++++++++++++---
.../plan/relational/metadata/Metadata.java | 6 ++
.../relational/metadata/TableMetadataImpl.java | 9 ++
.../distribute/TableDistributedPlanGenerator.java | 41 +++++++--
.../plan/relational/planner/node/GroupNode.java | 2 +-
.../sql/ast/TableFunctionTableArgument.java | 6 +-
.../plan/relational/analyzer/TSBSMetadata.java | 6 ++
.../relational/analyzer/TableFunctionTest.java | 72 +++++++++++++++
.../plan/relational/analyzer/TestMetadata.java | 47 ++++++++--
.../{SortMatcher.java => GroupMatcher.java} | 42 +++++----
.../planner/assertions/PlanMatchPattern.java | 14 ++-
.../relational/planner/assertions/SortMatcher.java | 4 +-
.../planner/assertions/TableScanMatcher.java | 1 -
.../{SortMatcher.java => TopKMatcher.java} | 33 ++++---
15 files changed, 386 insertions(+), 74 deletions(-)
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java
index 3633c588ab2..b70135e51c0 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java
@@ -38,6 +38,8 @@ import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.tablefunction.Ar
import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.tablefunction.ArgumentsAnalysis;
import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.tablefunction.TableArgumentAnalysis;
import
org.apache.iotdb.db.queryengine.plan.relational.analyzer.tablefunction.TableFunctionInvocationAnalysis;
+import
org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction;
+import
org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.ColumnSchema;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata;
import
org.apache.iotdb.db.queryengine.plan.relational.metadata.QualifiedObjectName;
@@ -157,6 +159,7 @@ import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SortItem;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.StartPipe;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.StopPipe;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.StringLiteral;
import
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SubqueryExpression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Table;
@@ -4059,7 +4062,14 @@ public class StatementAnalyzer {
@Override
public Scope visitTableFunctionInvocation(TableFunctionInvocation node,
Optional<Scope> scope) {
- TableFunction function =
metadata.getTableFunction(node.getName().toString());
+ String functionName = node.getName().toString();
+ TableFunction function = metadata.getTableFunction(functionName);
+
+ // set model fetcher for ForecastTableFunction
+ if (function instanceof ForecastTableFunction) {
+ ((ForecastTableFunction)
function).setModelFetcher(metadata.getModelFetcher());
+ }
+
Node errorLocation = node;
if (!node.getArguments().isEmpty()) {
errorLocation = node.getArguments().get(0);
@@ -4067,7 +4077,11 @@ public class StatementAnalyzer {
ArgumentsAnalysis argumentsAnalysis =
analyzeArguments(
- function.getArgumentsSpecifications(), node.getArguments(),
scope, errorLocation);
+ function.getArgumentsSpecifications(),
+ node.getArguments(),
+ scope,
+ errorLocation,
+ functionName);
TableFunctionAnalysis functionAnalysis;
try {
@@ -4213,7 +4227,8 @@ public class StatementAnalyzer {
List<ParameterSpecification> parameterSpecifications,
List<TableFunctionArgument> arguments,
Optional<Scope> scope,
- Node errorLocation) {
+ Node errorLocation,
+ String functionName) {
if (parameterSpecifications.size() < arguments.size()) {
throw new SemanticException(
String.format(
@@ -4249,6 +4264,11 @@ public class StatementAnalyzer {
"Duplicate argument specification for name: " +
parameterSpecification.getName());
}
}
+
+ // append order by time asc for built-in forecast tvf if user doesn't
specify order by
+ // clause
+ tryUpdateOrderByForForecastByName(functionName, arguments,
argumentSpecificationsByName);
+
Set<String> uniqueArgumentNames = new HashSet<>();
Set<String> specifiedArgumentNames =
ImmutableSet.copyOf(argumentSpecificationsByName.keySet());
@@ -4280,6 +4300,9 @@ public class StatementAnalyzer {
analyzeDefault(parameterSpecification, errorLocation));
}
} else {
+ // append order by time asc for built-in forecast tvf if user doesn't
specify order by
+ // clause
+ tryUpdateOrderByForForecastByPosition(functionName, arguments,
parameterSpecifications);
for (int i = 0; i < arguments.size(); i++) {
TableFunctionArgument argument = arguments.get(i);
ParameterSpecification parameterSpecification =
parameterSpecifications.get(i);
@@ -4299,6 +4322,77 @@ public class StatementAnalyzer {
return new ArgumentsAnalysis(passedArguments.buildOrThrow(),
tableArgumentAnalyses.build());
}
+ // append order by time asc for built-in forecast tvf if user doesn't
specify order by clause
+ private void tryUpdateOrderByForForecastByName(
+ String functionName,
+ List<TableFunctionArgument> arguments,
+ Map<String, ParameterSpecification> argumentSpecificationsByName) {
+ if
(TableBuiltinTableFunction.FORECAST.getFunctionName().equalsIgnoreCase(functionName))
{
+ String timeColumn =
+ (String)
+ argumentSpecificationsByName
+ .get(ForecastTableFunction.TIMECOL_PARAMETER_NAME)
+ .getDefaultValue()
+ .get();
+ for (TableFunctionArgument argument : arguments) {
+ if (ForecastTableFunction.TIMECOL_PARAMETER_NAME.equalsIgnoreCase(
+ argument.getName().get().getValue())) {
+ if (argument.getValue() instanceof StringLiteral) {
+ timeColumn = ((StringLiteral) argument.getValue()).getValue();
+ }
+ }
+ }
+ tryUpdateOrderByForForecast(arguments, timeColumn);
+ }
+ }
+
+ // append order by time asc for built-in forecast tvf if user doesn't
specify order by clause
+ private void tryUpdateOrderByForForecastByPosition(
+ String functionName,
+ List<TableFunctionArgument> arguments,
+ List<ParameterSpecification> parameterSpecifications) {
+ if
(TableBuiltinTableFunction.FORECAST.getFunctionName().equalsIgnoreCase(functionName))
{
+ int position = -1;
+ String timeColumn = null;
+ for (int i = 0, size = parameterSpecifications.size(); i < size; i++) {
+ if (ForecastTableFunction.TIMECOL_PARAMETER_NAME.equalsIgnoreCase(
+ parameterSpecifications.get(i).getName())) {
+ position = i;
+ timeColumn = (String)
parameterSpecifications.get(i).getDefaultValue().get();
+ break;
+ }
+ }
+ if (position == -1) {
+ throw new IllegalStateException(
+ "ForecastTableFunction must contain
ForecastTableFunction.TIMECOL_PARAMETER_NAME");
+ }
+ if (position < arguments.size()
+ && arguments.get(position).getValue() instanceof StringLiteral) {
+ timeColumn = ((StringLiteral)
arguments.get(position).getValue()).getValue();
+ }
+ tryUpdateOrderByForForecast(arguments, timeColumn);
+ }
+ }
+
+ // append order by time asc for built-in forecast tvf if user doesn't
specify order by clause
+ private void tryUpdateOrderByForForecast(
+ List<TableFunctionArgument> arguments, String timeColumn) {
+ for (TableFunctionArgument argument : arguments) {
+ if (argument.getValue() instanceof TableFunctionTableArgument) {
+ TableFunctionTableArgument input = (TableFunctionTableArgument)
argument.getValue();
+ if (!input.getOrderBy().isPresent()) {
+ input.updateOrderBy(
+ new OrderBy(
+ Collections.singletonList(
+ new SortItem(
+ new Identifier(timeColumn),
+ SortItem.Ordering.ASCENDING,
+ SortItem.NullOrdering.FIRST))));
+ }
+ }
+ }
+ }
+
private ArgumentAnalysis analyzeArgument(
ParameterSpecification parameterSpecification,
TableFunctionArgument argument,
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
index 35dbed2ac9f..dbe91526bd8 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java
@@ -27,7 +27,6 @@ import
org.apache.iotdb.commons.client.ainode.AINodeClientManager;
import org.apache.iotdb.commons.exception.IoTDBRuntimeException;
import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher;
-import org.apache.iotdb.db.queryengine.plan.analyze.ModelFetcher;
import
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
import org.apache.iotdb.rpc.TSStatusCode;
import org.apache.iotdb.udf.api.relational.TableFunction;
@@ -65,6 +64,7 @@ import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.Optional;
import java.util.Set;
@@ -73,7 +73,7 @@ import static
org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE;
public class ForecastTableFunction implements TableFunction {
- private static class ForecastTableFunctionHandle implements
TableFunctionHandle {
+ public static class ForecastTableFunctionHandle implements
TableFunctionHandle {
TEndPoint targetAINode;
String modelId;
int maxInputLength;
@@ -152,9 +152,41 @@ public class ForecastTableFunction implements
TableFunction {
types.add(Type.valueOf(ReadWriteIOUtils.readByte(buffer)));
}
}
- }
- private static final IModelFetcher MODEL_FETCHER =
ModelFetcher.getInstance();
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ ForecastTableFunctionHandle that = (ForecastTableFunctionHandle) o;
+ return maxInputLength == that.maxInputLength
+ && outputLength == that.outputLength
+ && outputStartTime == that.outputStartTime
+ && outputInterval == that.outputInterval
+ && keepInput == that.keepInput
+ && Objects.equals(targetAINode, that.targetAINode)
+ && Objects.equals(modelId, that.modelId)
+ && Objects.equals(options, that.options)
+ && Objects.equals(types, that.types);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(
+ targetAINode,
+ modelId,
+ maxInputLength,
+ outputLength,
+ outputStartTime,
+ outputInterval,
+ keepInput,
+ options,
+ types);
+ }
+ }
private static final String INPUT_PARAMETER_NAME = "INPUT";
private static final String MODEL_ID_PARAMETER_NAME = "MODEL_ID";
@@ -163,15 +195,15 @@ public class ForecastTableFunction implements
TableFunction {
private static final String PREDICATED_COLUMNS_PARAMETER_NAME =
"PREDICATED_COLUMNS";
private static final String DEFAULT_PREDICATED_COLUMNS = "";
private static final String OUTPUT_START_TIME = "OUTPUT_START_TIME";
- private static final long DEFAULT_OUTPUT_START_TIME = Long.MIN_VALUE;
+ public static final long DEFAULT_OUTPUT_START_TIME = Long.MIN_VALUE;
private static final String OUTPUT_INTERVAL = "OUTPUT_INTERVAL";
- private static final long DEFAULT_OUTPUT_INTERVAL = 0L;
- private static final String TIMECOL_PARAMETER_NAME = "TIMECOL";
+ public static final long DEFAULT_OUTPUT_INTERVAL = 0L;
+ public static final String TIMECOL_PARAMETER_NAME = "TIMECOL";
private static final String DEFAULT_TIME_COL = "time";
private static final String KEEP_INPUT_PARAMETER_NAME = "KEEP_INPUT";
private static final Boolean DEFAULT_KEEP_INPUT = Boolean.FALSE;
private static final String IS_INPUT_COLUMN_NAME = "is_input";
- private static final String OPTIONS_PARAMETER_NAME = "OPTIONS";
+ private static final String OPTIONS_PARAMETER_NAME = "MODEL_OPTIONS";
private static final String DEFAULT_OPTIONS = "";
private static final String INVALID_OPTIONS_FORMAT = "Invalid options: %s";
@@ -185,6 +217,16 @@ public class ForecastTableFunction implements
TableFunction {
ALLOWED_INPUT_TYPES.add(Type.DOUBLE);
}
+ // need to set before analyze method is called
+ // should only be used in fe scope, never be used in
TableFunctionProcessorProvider
+ // The reason we don't directly set modelFetcher=ModelFetcher.getInstance()
is that we need to
+ // mock IModelFetcher in UT
+ private IModelFetcher modelFetcher = null;
+
+ public void setModelFetcher(IModelFetcher modelFetcher) {
+ this.modelFetcher = modelFetcher;
+ }
+
@Override
public List<ParameterSpecification> getArgumentsSpecifications() {
return Arrays.asList(
@@ -372,7 +414,7 @@ public class ForecastTableFunction implements TableFunction
{
}
private ModelInferenceDescriptor getModelInfo(String modelId) {
- return MODEL_FETCHER.fetchModel(modelId);
+ return modelFetcher.fetchModel(modelId);
}
// only allow for INT32, INT64, FLOAT, DOUBLE
@@ -501,10 +543,19 @@ public class ForecastTableFunction implements
TableFunction {
// time column
long inputStartTime = inputRecords.getFirst().getLong(0);
long inputEndTime = inputRecords.getLast().getLong(0);
- long interval =
- outputInterval <= 0
- ? (inputEndTime - inputStartTime) / (inputRecords.size() - 1)
- : outputInterval;
+ if (inputEndTime < inputStartTime) {
+ throw new SemanticException(
+ String.format(
+ "input end time should never less than start time, start time
is %s, end time is %s",
+ inputStartTime, inputEndTime));
+ }
+ long interval = outputInterval;
+ if (outputInterval <= 0) {
+ interval =
+ inputRecords.size() == 1
+ ? 0
+ : (inputEndTime - inputStartTime) / (inputRecords.size() - 1);
+ }
long outputTime =
(outputStartTime == Long.MIN_VALUE) ? (inputEndTime + interval) :
outputStartTime;
for (int i = 0; i < outputLength; i++) {
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java
index 443fbd6e3f0..4d4c160cd50 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java
@@ -26,6 +26,7 @@ import
org.apache.iotdb.db.exception.load.LoadAnalyzeTableColumnDisorderExceptio
import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.queryengine.common.MPPQueryContext;
import org.apache.iotdb.db.queryengine.common.SessionInfo;
+import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher;
import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher;
import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType;
import org.apache.iotdb.db.queryengine.plan.relational.security.AccessControl;
@@ -194,4 +195,9 @@ public interface Metadata {
final String database, final List<DataPartitionQueryParam>
sgNameToQueryParamsMap);
TableFunction getTableFunction(final String functionName);
+
+ /**
+ * @return ModelFetcher
+ */
+ IModelFetcher getModelFetcher();
}
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
index 742f980e2e8..73703c8d4e2 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java
@@ -34,7 +34,9 @@ import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.queryengine.common.MPPQueryContext;
import org.apache.iotdb.db.queryengine.common.SessionInfo;
import org.apache.iotdb.db.queryengine.plan.analyze.ClusterPartitionFetcher;
+import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher;
import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher;
+import org.apache.iotdb.db.queryengine.plan.analyze.ModelFetcher;
import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType;
import
org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction;
import
org.apache.iotdb.db.queryengine.plan.relational.function.arithmetic.AdditionResolver;
@@ -95,6 +97,8 @@ public class TableMetadataImpl implements Metadata {
private final DataNodeTableCache tableCache =
DataNodeTableCache.getInstance();
+ private final IModelFetcher modelFetcher = ModelFetcher.getInstance();
+
@Override
public boolean tableExists(final QualifiedObjectName name) {
return tableCache.getTable(name.getDatabaseName(), name.getObjectName())
!= null;
@@ -841,6 +845,11 @@ public class TableMetadataImpl implements Metadata {
}
}
+ @Override
+ public IModelFetcher getModelFetcher() {
+ return modelFetcher;
+ }
+
public static boolean isTwoNumericType(List<? extends Type> argumentTypes) {
return argumentTypes.size() == 2
&& isNumericType(argumentTypes.get(0))
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java
index 8a0606ccfe7..d99fa3da1e3 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java
@@ -273,27 +273,50 @@ public class TableDistributedPlanGenerator
node.getChildren().size() == 1, "Size of TopKNode can only be 1 in
logical plan.");
List<PlanNode> childrenNodes = node.getChildren().get(0).accept(this,
context);
if (childrenNodes.size() == 1) {
+ if (canTopKEliminated(node.getOrderingScheme(), node.getCount(),
childrenNodes.get(0))) {
+ return childrenNodes;
+ }
node.setChildren(Collections.singletonList(childrenNodes.get(0)));
return Collections.singletonList(node);
}
TopKNode newTopKNode = (TopKNode) node.clone();
for (PlanNode child : childrenNodes) {
- TopKNode subTopKNode =
- new TopKNode(
- queryId.genPlanNodeId(),
- Collections.singletonList(child),
- node.getOrderingScheme(),
- node.getCount(),
- node.getOutputSymbols(),
- node.isChildrenDataInOrder());
- newTopKNode.addChild(subTopKNode);
+ PlanNode newChild;
+ if (canTopKEliminated(node.getOrderingScheme(), node.getCount(), child))
{
+ newChild = child;
+ } else {
+ newChild =
+ new TopKNode(
+ queryId.genPlanNodeId(),
+ Collections.singletonList(child),
+ node.getOrderingScheme(),
+ node.getCount(),
+ node.getOutputSymbols(),
+ node.isChildrenDataInOrder());
+ }
+ newTopKNode.addChild(newChild);
}
nodeOrderingMap.put(newTopKNode.getPlanNodeId(),
newTopKNode.getOrderingScheme());
return Collections.singletonList(newTopKNode);
}
+  // TopK over a DeviceTableScanNode is redundant only when the scan already emits at most K rows
+  // in an order that satisfies TopK's ordering scheme.
+  // NOTE(review): pushDownLimit <= 0 is treated as "no limit pushed down"; keeping the TopK in that
+  // case is always correct, while eliminating it could return more than K rows.
+  private boolean canTopKEliminated(OrderingScheme orderingScheme, long k, PlanNode child) {
+    if (child instanceof DeviceTableScanNode) {
+      DeviceTableScanNode tableScanNode = (DeviceTableScanNode) child;
+      long pushDownLimit = tableScanNode.getPushDownLimit();
+      // a per-device limit only bounds the total output when the scan covers a single device
+      boolean limitBoundsTotalOutput =
+          !tableScanNode.isPushLimitToEachDevice()
+              || tableScanNode.getDeviceEntries().size() == 1;
+      return pushDownLimit > 0
+          && pushDownLimit <= k
+          && limitBoundsTotalOutput
+          && canSortEliminated(orderingScheme, nodeOrderingMap.get(child.getPlanNodeId()));
+    }
+    return false;
+  }
+
@Override
public List<PlanNode> visitGroup(GroupNode node, PlanContext context) {
context.setExpectedOrderingScheme(node.getOrderingScheme());
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/GroupNode.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/GroupNode.java
index 80ea7657fff..83b034ac7bf 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/GroupNode.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/GroupNode.java
@@ -49,7 +49,7 @@ public class GroupNode extends SortNode {
* orderingScheme may include two parts: PartitionKey and OrderKey. It marks
the number of
* PartitionKey.
*/
- private int partitionKeyCount;
+ private final int partitionKeyCount;
public GroupNode(PlanNodeId id, PlanNode child, OrderingScheme scheme, int
partitionKeyCount) {
super(id, child, scheme, false, false);
diff --git
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/TableFunctionTableArgument.java
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/TableFunctionTableArgument.java
index 8c0addef1d8..c6f1ee88d0c 100644
---
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/TableFunctionTableArgument.java
+++
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/TableFunctionTableArgument.java
@@ -32,7 +32,7 @@ import static
org.apache.iotdb.db.queryengine.plan.relational.sql.util.Expressio
public class TableFunctionTableArgument extends Node {
private final Relation table;
private final Optional<List<Expression>> partitionBy; // it is allowed to
partition by empty list
- private final Optional<OrderBy> orderBy;
+ private Optional<OrderBy> orderBy;
public TableFunctionTableArgument(
NodeLocation location,
@@ -57,6 +57,10 @@ public class TableFunctionTableArgument extends Node {
return orderBy;
}
+ public void updateOrderBy(OrderBy orderBy) {
+ this.orderBy = Optional.of(orderBy);
+ }
+
@Override
public <R, C> R accept(AstVisitor<R, C> visitor, C context) {
return visitor.visitTableArgument(this, context);
diff --git
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java
index 189f86854c1..159ab2486a9 100644
---
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java
+++
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java
@@ -28,6 +28,7 @@ import
org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory;
import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
import org.apache.iotdb.db.queryengine.common.MPPQueryContext;
import org.apache.iotdb.db.queryengine.common.SessionInfo;
+import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher;
import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher;
import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType;
import
org.apache.iotdb.db.queryengine.plan.relational.metadata.AlignedDeviceEntry;
@@ -388,6 +389,11 @@ public class TSBSMetadata implements Metadata {
return null;
}
+ @Override
+ public IModelFetcher getModelFetcher() {
+ return null;
+ }
+
private static final DataPartition DATA_PARTITION =
MockTSBSDataPartition.constructDataPartition();
diff --git
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
index 3aa3d85fa80..9dd16653a89 100644
---
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
+++
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java
@@ -19,7 +19,9 @@
package org.apache.iotdb.db.queryengine.plan.relational.analyzer;
+import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan;
+import
org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction;
import org.apache.iotdb.db.queryengine.plan.relational.planner.PlanTester;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.TableFunctionProcessorMatcher;
@@ -32,8 +34,11 @@ import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.junit.Test;
+import java.util.Collections;
import java.util.function.Consumer;
+import static
org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction.DEFAULT_OUTPUT_INTERVAL;
+import static
org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction.DEFAULT_OUTPUT_START_TIME;
import static
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanAssert.assertPlan;
import static
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.aggregation;
import static
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.aggregationFunction;
@@ -48,6 +53,12 @@ import static
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions
import static
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.sort;
import static
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.tableFunctionProcessor;
import static
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.tableScan;
+import static
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.topK;
+import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SortItem.NullOrdering.FIRST;
+import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SortItem.NullOrdering.LAST;
+import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SortItem.Ordering.ASCENDING;
+import static
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SortItem.Ordering.DESCENDING;
+import static org.apache.iotdb.udf.api.type.Type.DOUBLE;
public class TableFunctionTest {
@@ -332,4 +343,65 @@ public class TableFunctionTest {
deserialized.deserialize(serialized);
assert mapTableFunctionHandle.equals(deserialized);
}
+
+ @Test
+ public void testForecastFunction() {
+ // default order by time asc
+ PlanTester planTester = new PlanTester();
+
+ String sql =
+ "SELECT * FROM FORECAST("
+ + "input => (SELECT time,s3 FROM table1 WHERE tag1='shanghai' AND
tag2='A3' AND tag3='YY' ORDER BY time DESC LIMIT 1440), "
+ + "model_id => 'timer_xl')";
+ LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+
+ PlanMatchPattern tableScan =
+ tableScan("testdb.table1", ImmutableMap.of("time_0", "time", "s3_1",
"s3"));
+ Consumer<TableFunctionProcessorMatcher.Builder> tableFunctionMatcher =
+ builder ->
+ builder
+ .name("forecast")
+ .properOutputs("time", "s3")
+ .requiredSymbols("time_0", "s3_1")
+ .handle(
+ new ForecastTableFunction.ForecastTableFunctionHandle(
+ false,
+ 1440,
+ "timer_xl",
+ Collections.emptyMap(),
+ 96,
+ DEFAULT_OUTPUT_START_TIME,
+ DEFAULT_OUTPUT_INTERVAL,
+ new TEndPoint("127.0.0.1", 10810),
+ Collections.singletonList(DOUBLE)));
+ // Verify full LogicalPlan
+ // Output - TableFunctionProcessor - TableScan
+ assertPlan(
+ logicalQueryPlan,
+ anyTree(
+ tableFunctionProcessor(
+ tableFunctionMatcher,
+ group(
+ ImmutableList.of(sort("time_0", ASCENDING, FIRST)),
+ 0,
+ topK(
+ 1440,
+ ImmutableList.of(sort("time_0", DESCENDING, LAST)),
+ false,
+ tableScan)))));
+ // Verify DistributionPlan
+
+ /*
+ * └──OutputNode
+ * └──TableFunctionProcessor
+ * └──GroupNode
+ * └──TableScan
+ */
+ assertPlan(
+ planTester.getFragmentPlan(0),
+ output(
+ tableFunctionProcessor(
+ tableFunctionMatcher,
+ group(ImmutableList.of(sort("time_0", ASCENDING, FIRST)), 0,
tableScan))));
+ }
}
diff --git
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java
index 3292a5077cd..2fc2e7989b8 100644
---
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java
+++
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java
@@ -19,6 +19,8 @@
package org.apache.iotdb.db.queryengine.plan.relational.analyzer;
+import org.apache.iotdb.common.rpc.thrift.TEndPoint;
+import org.apache.iotdb.commons.model.ModelInformation;
import org.apache.iotdb.commons.partition.DataPartition;
import org.apache.iotdb.commons.partition.DataPartitionQueryParam;
import org.apache.iotdb.commons.partition.SchemaNodeManagementPartition;
@@ -27,14 +29,17 @@ import org.apache.iotdb.commons.path.PathPatternTree;
import org.apache.iotdb.commons.schema.table.TsTable;
import org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory;
import org.apache.iotdb.commons.udf.builtin.BuiltinAggregationFunction;
-import org.apache.iotdb.commons.udf.builtin.relational.tvf.HOPTableFunction;
+import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.queryengine.common.MPPQueryContext;
import org.apache.iotdb.db.queryengine.common.SessionInfo;
+import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher;
import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher;
import org.apache.iotdb.db.queryengine.plan.function.Exclude;
import org.apache.iotdb.db.queryengine.plan.function.Repeat;
import org.apache.iotdb.db.queryengine.plan.function.Split;
+import
org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor;
import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType;
+import
org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction;
import
org.apache.iotdb.db.queryengine.plan.relational.metadata.AlignedDeviceEntry;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.ColumnMetadata;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.ColumnSchema;
@@ -55,6 +60,7 @@ import
org.apache.iotdb.db.queryengine.plan.relational.type.InternalTypeManager;
import org.apache.iotdb.db.queryengine.plan.relational.type.TypeManager;
import
org.apache.iotdb.db.queryengine.plan.relational.type.TypeNotFoundException;
import org.apache.iotdb.db.queryengine.plan.relational.type.TypeSignature;
+import org.apache.iotdb.db.queryengine.plan.udf.TableUDFUtils;
import org.apache.iotdb.db.schemaengine.table.InformationSchemaUtils;
import org.apache.iotdb.mpp.rpc.thrift.TRegionRouteReq;
import org.apache.iotdb.udf.api.relational.TableFunction;
@@ -311,7 +317,17 @@ public class TestMetadata implements Metadata {
IDeviceID.Factory.DEFAULT_FACTORY.create(DEVICE_6), new
Binary[0])));
}
- if (expressionList.size() == 2) {
+ if (expressionList.size() == 3) {
+ if (compareEqualsMatch(expressionList.get(0), "tag1", "shanghai")
+ && compareEqualsMatch(expressionList.get(1), "tag2", "A3")
+ && compareEqualsMatch(expressionList.get(2), "tag3", "YY")) {
+ return Collections.singletonMap(
+ DB1,
+ Collections.singletonList(
+ new AlignedDeviceEntry(
+ new StringArrayDeviceID(DEVICE_3.split("\\.")),
DEVICE_1_ATTRIBUTES)));
+ }
+ } else if (expressionList.size() == 2) {
if (compareEqualsMatch(expressionList.get(0), "tag1", "beijing")
&& compareEqualsMatch(expressionList.get(1), "tag2", "A1")
|| compareEqualsMatch(expressionList.get(1), "tag1", "beijing")
@@ -478,19 +494,38 @@ public class TestMetadata implements Metadata {
@Override
public TableFunction getTableFunction(String functionName) {
- if ("HOP".equalsIgnoreCase(functionName)) {
- return new HOPTableFunction();
- } else if ("EXCLUDE".equalsIgnoreCase(functionName)) {
+ // Test-only functions (EXCLUDE, REPEAT, SPLIT) are matched case-insensitively first.
+ // NOTE(review): HOP is no longer special-cased; presumably it is now resolved through
+ // TableBuiltinTableFunction below — confirm.
+ if ("EXCLUDE".equalsIgnoreCase(functionName)) {
return new Exclude();
} else if ("REPEAT".equalsIgnoreCase(functionName)) {
return new Repeat();
} else if ("SPLIT".equalsIgnoreCase(functionName)) {
return new Split();
} else {
- return null;
+ // Fall back to builtin table functions, then to table UDFs known to TableUDFUtils;
+ // an unknown name raises SemanticException instead of returning null.
+ if (TableBuiltinTableFunction.isBuiltInTableFunction(functionName)) {
+ return TableBuiltinTableFunction.getBuiltinTableFunction(functionName);
+ } else if (TableUDFUtils.isTableFunction(functionName)) {
+ return TableUDFUtils.getTableFunction(functionName);
+ } else {
+ throw new SemanticException("Unknown function: " + functionName);
+ }
}
}
+ /**
+ * Returns a Mockito-backed fetcher that only answers for the "timer_xl" model: an available
+ * model with input shape {1440, 96} whose target AINode is 127.0.0.1:10810. A fresh set of
+ * mocks is built on every call.
+ */
+ @Override
+ public IModelFetcher getModelFetcher() {
+ final String modelName = "timer_xl";
+
+ ModelInformation information = Mockito.mock(ModelInformation.class);
+ Mockito.when(information.available()).thenReturn(true);
+ Mockito.when(information.getInputShape()).thenReturn(new int[] {1440, 96});
+
+ ModelInferenceDescriptor inference = Mockito.mock(ModelInferenceDescriptor.class);
+ TEndPoint aiNode = new TEndPoint("127.0.0.1", 10810);
+ Mockito.when(inference.getTargetAINode()).thenReturn(aiNode);
+ Mockito.when(inference.getModelInformation()).thenReturn(information);
+ Mockito.when(inference.getModelName()).thenReturn(modelName);
+
+ IModelFetcher modelFetcher = Mockito.mock(IModelFetcher.class);
+ Mockito.when(modelFetcher.fetchModel(modelName)).thenReturn(inference);
+ return modelFetcher;
+ }
+
private static final DataPartition TABLE_DATA_PARTITION =
MockTableModelDataPartition.constructDataPartition(DB1);
diff --git
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/SortMatcher.java
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/GroupMatcher.java
similarity index 60%
copy from
iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/SortMatcher.java
copy to
iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/GroupMatcher.java
index fe81ae95d65..ff17bdfec3c 100644
---
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/SortMatcher.java
+++
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/GroupMatcher.java
@@ -7,7 +7,7 @@
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
@@ -22,47 +22,45 @@ package
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions;
import org.apache.iotdb.db.queryengine.common.SessionInfo;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata;
-import
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.Ordering;
-import org.apache.iotdb.db.queryengine.plan.relational.planner.node.SortNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.GroupNode;
import java.util.List;
import static com.google.common.base.MoreObjects.toStringHelper;
-import static com.google.common.base.Preconditions.checkState;
-import static java.util.Objects.requireNonNull;
import static
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.MatchResult.NO_MATCH;
-import static
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.Util.orderingSchemeMatches;
-final class SortMatcher implements Matcher {
- private final List<Ordering> orderBy;
+/**
+ * Matches a {@link GroupNode} whose ordering scheme matches {@code orderBy} (checked by {@link
+ * SortMatcher}) and whose partition key count equals {@code partitionKeyCount}.
+ */
+public class GroupMatcher extends SortMatcher {
+ private final int partitionKeyCount;
- public SortMatcher(List<Ordering> orderBy) {
- this.orderBy = requireNonNull(orderBy, "orderBy is null");
+ public GroupMatcher(List<PlanMatchPattern.Ordering> orderBy, int partitionKeyCount) {
+ super(orderBy);
+ this.partitionKeyCount = partitionKeyCount;
}
@Override
public boolean shapeMatches(PlanNode node) {
- return node instanceof SortNode;
+ return node instanceof GroupNode;
}
@Override
public MatchResult detailMatches(
PlanNode node, SessionInfo sessionInfo, Metadata metadata, SymbolAliases
symbolAliases) {
- checkState(
- shapeMatches(node),
- "Plan testing framework error: shapeMatches returned false in
detailMatches in %s",
- this.getClass().getName());
- SortNode sortNode = (SortNode) node;
-
- if (!orderingSchemeMatches(orderBy, sortNode.getOrderingScheme(),
symbolAliases)) {
- return NO_MATCH;
+ // SortMatcher checks the shape and the ordering scheme; its cast to SortNode relies on
+ // GroupNode being a SortNode subtype.
+ MatchResult result = super.detailMatches(node, sessionInfo, metadata, symbolAliases);
+ if (result != NO_MATCH) {
+ GroupNode groupNode = (GroupNode) node;
+ if (partitionKeyCount != groupNode.getPartitionKeyCount()) {
+ return NO_MATCH;
+ }
+ return MatchResult.match();
}
-
- return MatchResult.match();
+ return NO_MATCH;
}
@Override
public String toString() {
- return toStringHelper(this).add("orderBy", orderBy).toString();
+ return toStringHelper(this)
+ .add("orderBy", orderBy)
+ .add("partitionKeyCount", partitionKeyCount)
+ .toString();
}
}
diff --git
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/PlanMatchPattern.java
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/PlanMatchPattern.java
index ef051873c35..93662f6082b 100644
---
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/PlanMatchPattern.java
+++
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/PlanMatchPattern.java
@@ -49,6 +49,7 @@ import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.SemiJoinNode
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.SortNode;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.StreamSortNode;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.TableFunctionProcessorNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKNode;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.TreeAlignedDeviceViewScanNode;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.TreeDeviceViewScanNode;
import
org.apache.iotdb.db.queryengine.plan.relational.planner.node.TreeNonAlignedDeviceViewScanNode;
@@ -545,6 +546,11 @@ public final class PlanMatchPattern {
return node(GroupNode.class, source);
}
+ /**
+ * Matches a {@link GroupNode} over {@code source} whose ordering is {@code orderBy} and whose
+ * partition key count is {@code partitionKeyCount}.
+ */
+ public static PlanMatchPattern group(
+ List<Ordering> orderBy, int partitionKeyCount, PlanMatchPattern source) {
+ GroupMatcher matcher = new GroupMatcher(orderBy, partitionKeyCount);
+ return node(GroupNode.class, source).with(matcher);
+ }
+
+ /** Matches a {@link SortNode} over {@code source} without constraining its ordering. */
public static PlanMatchPattern sort(PlanMatchPattern source) {
return node(SortNode.class, source);
}
@@ -557,12 +563,12 @@ public final class PlanMatchPattern {
return node(StreamSortNode.class, source).with(new SortMatcher(orderBy));
}
- /*public static PlanMatchPattern topN(long count, List<Ordering> orderBy,
PlanMatchPattern source)
- {
- return topN(count, orderBy, TopNNode.Step.SINGLE, source);
+ /**
+ * Matches a {@link TopKNode} over {@code source} with the given ordering, row count and
+ * children-data-in-order flag.
+ */
+ public static PlanMatchPattern topK(
+ long count, List<Ordering> orderBy, boolean childrenDataInOrder, PlanMatchPattern source) {
+ TopKMatcher matcher = new TopKMatcher(orderBy, count, childrenDataInOrder);
+ return node(TopKNode.class, source).with(matcher);
}
- public static PlanMatchPattern topN(long count, List<Ordering> orderBy,
TopNNode.Step step, PlanMatchPattern source)
+ /*public static PlanMatchPattern topN(long count, List<Ordering> orderBy,
TopNNode.Step step, PlanMatchPattern source)
{
return node(TopNNode.class, source).with(new TopNMatcher(count, orderBy,
step));
}*/
diff --git
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/SortMatcher.java
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/SortMatcher.java
index fe81ae95d65..3a95f579b31 100644
---
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/SortMatcher.java
+++
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/SortMatcher.java
@@ -33,8 +33,8 @@ import static java.util.Objects.requireNonNull;
import static
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.MatchResult.NO_MATCH;
import static
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.Util.orderingSchemeMatches;
-final class SortMatcher implements Matcher {
- private final List<Ordering> orderBy;
+class SortMatcher implements Matcher {
+ protected final List<Ordering> orderBy;
public SortMatcher(List<Ordering> orderBy) {
this.orderBy = requireNonNull(orderBy, "orderBy is null");
diff --git
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/TableScanMatcher.java
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/TableScanMatcher.java
index 01ca624a423..b860a18c16a 100644
---
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/TableScanMatcher.java
+++
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/TableScanMatcher.java
@@ -64,7 +64,6 @@ public abstract class TableScanMatcher implements Matcher {
TableScanNode tableScanNode = (TableScanNode) node;
String actualTableName = tableScanNode.getQualifiedObjectName().toString();
- // TODO (https://github.com/trinodb/trino/issues/17) change to equals()
if (!expectedTableName.equalsIgnoreCase(actualTableName)) {
return NO_MATCH;
}
diff --git
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/SortMatcher.java
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/TopKMatcher.java
similarity index 68%
copy from
iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/SortMatcher.java
copy to
iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/TopKMatcher.java
index fe81ae95d65..018a693e55c 100644
---
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/SortMatcher.java
+++
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/assertions/TopKMatcher.java
@@ -7,7 +7,7 @@
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
- * http://www.apache.org/licenses/LICENSE-2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
@@ -22,27 +22,30 @@ package
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions;
import org.apache.iotdb.db.queryengine.common.SessionInfo;
import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata;
-import
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.Ordering;
-import org.apache.iotdb.db.queryengine.plan.relational.planner.node.SortNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKNode;
import java.util.List;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkState;
-import static java.util.Objects.requireNonNull;
import static
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.MatchResult.NO_MATCH;
import static
org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.Util.orderingSchemeMatches;
-final class SortMatcher implements Matcher {
- private final List<Ordering> orderBy;
+/**
+ * Matches a {@link TopKNode} by its ordering scheme, its row count and its {@code
+ * childrenDataInOrder} flag.
+ */
+public class TopKMatcher implements Matcher {
+ private final List<PlanMatchPattern.Ordering> orderBy;
+ private final long count;
+ private final boolean childrenDataInOrder;
- public SortMatcher(List<Ordering> orderBy) {
- this.orderBy = requireNonNull(orderBy, "orderBy is null");
+ public TopKMatcher(
+ List<PlanMatchPattern.Ordering> orderBy, long count, boolean childrenDataInOrder) {
+ // Reject a null ordering up front, as SortMatcher does.
+ this.orderBy = java.util.Objects.requireNonNull(orderBy, "orderBy is null");
+ this.count = count;
+ this.childrenDataInOrder = childrenDataInOrder;
}
@Override
public boolean shapeMatches(PlanNode node) {
- return node instanceof SortNode;
+ return node instanceof TopKNode;
}
@Override
@@ -52,9 +55,11 @@ final class SortMatcher implements Matcher {
shapeMatches(node),
"Plan testing framework error: shapeMatches returned false in
detailMatches in %s",
this.getClass().getName());
- SortNode sortNode = (SortNode) node;
+ TopKNode topKNode = (TopKNode) node;
- if (!orderingSchemeMatches(orderBy, sortNode.getOrderingScheme(),
symbolAliases)) {
+ if (!orderingSchemeMatches(orderBy, topKNode.getOrderingScheme(),
symbolAliases)
+ || count != topKNode.getCount()
+ || childrenDataInOrder != topKNode.isChildrenDataInOrder()) {
return NO_MATCH;
}
@@ -63,6 +68,10 @@ final class SortMatcher implements Matcher {
@Override
public String toString() {
- return toStringHelper(this).add("orderBy", orderBy).toString();
+ // Lists the same three properties that detailMatches compares.
+ return toStringHelper(this)
+ .add("orderBy", orderBy)
+ .add("count", count)
+ .add("childrenDataInOrder", childrenDataInOrder)
+ .toString();
}
}