This is an automated email from the ASF dual-hosted git repository.
chengzhang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new fd0c05c632a fix:Support for the PREDICT function (#36268)
fd0c05c632a is described below
commit fd0c05c632a87b9bd3bcfc16a7121527ae2c47b2
Author: cxy <[email protected]>
AuthorDate: Wed Aug 13 09:22:22 2025 +0800
fix:Support for the PREDICT function (#36268)
---
.../src/main/antlr4/imports/sqlserver/BaseRule.g4 | 14 ++++++++-
.../main/antlr4/imports/sqlserver/DMLStatement.g4 | 2 +-
.../antlr4/imports/sqlserver/SQLServerKeyword.g4 | 12 ++++++++
.../statement/SQLServerStatementVisitor.java | 36 ++++++++++++++++++++++
.../parser/src/main/resources/case/dml/select.xml | 25 +++++++++++++++
.../main/resources/sql/supported/dml/select.xml | 1 +
6 files changed, 88 insertions(+), 2 deletions(-)
diff --git a/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/BaseRule.g4 b/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/BaseRule.g4
index 63a53705de6..bfff4f7ed5c 100644
--- a/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/BaseRule.g4
+++ b/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/BaseRule.g4
@@ -481,7 +481,19 @@ openQueryFunction
;
rowSetFunction
- : openRowSetFunction | openQueryFunction | openDatasourceFunction
+ : openRowSetFunction | openQueryFunction | openDatasourceFunction |
predictFunction
+ ;
+
+predictFunction
+ : PREDICT LP_ MODEL EQ_ (variableName | literals) COMMA_ DATA EQ_
tableName (AS alias)? (COMMA_ RUNTIME EQ_ ONNX)? RP_ WITH LP_
predictResultSetDefinition RP_
+ ;
+
+predictResultSetDefinition
+ : predictColumnDefinition (COMMA_ predictColumnDefinition)*
+ ;
+
+predictColumnDefinition
+ : columnName dataType (COLLATE collationName)? (NULL | NOT NULL)?
;
regularFunction
diff --git a/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/DMLStatement.g4 b/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/DMLStatement.g4
index afc7af9bc90..92792f7756c 100644
--- a/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/DMLStatement.g4
+++ b/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/DMLStatement.g4
@@ -185,7 +185,7 @@ tableReference
;
tableFactor
- : tableName (FOR PATH)? forSystemTimeClause? (AS? alias)?
tableSampleClause? withTableHint? | subquery AS? alias columnNames? | expr (AS?
alias)? columnNames? | xmlMethodCall (AS? alias)? columnNames? | LP_
tableReferences RP_ | pivotTable
+ : tableName (FOR PATH)? forSystemTimeClause? (AS? alias)?
tableSampleClause? withTableHint? | subquery AS? alias columnNames? |
rowSetFunction (AS? alias)? | expr (AS? alias)? columnNames? | xmlMethodCall
(AS? alias)? columnNames? | LP_ tableReferences RP_ | pivotTable
;
pivotTable
diff --git a/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/SQLServerKeyword.g4 b/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/SQLServerKeyword.g4
index 9028275ba3d..d7f81f99e72 100644
--- a/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/SQLServerKeyword.g4
+++ b/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/SQLServerKeyword.g4
@@ -2170,3 +2170,15 @@ SETS
DISTRIBUTED_AGG
: D I S T R I B U T E D UL_ A G G
;
+
+PREDICT
+ : P R E D I C T
+ ;
+
+RUNTIME
+ : R U N T I M E
+ ;
+
+ONNX
+ : O N N X
+ ;
diff --git a/parser/sql/dialect/sqlserver/src/main/java/org/apache/shardingsphere/sql/parser/sqlserver/visitor/statement/SQLServerStatementVisitor.java b/parser/sql/dialect/sqlserver/src/main/java/org/apache/shardingsphere/sql/parser/sqlserver/visitor/statement/SQLServerStatementVisitor.java
index bb318aa1b47..d7e1bdebef6 100644
--- a/parser/sql/dialect/sqlserver/src/main/java/org/apache/shardingsphere/sql/parser/sqlserver/visitor/statement/SQLServerStatementVisitor.java
+++ b/parser/sql/dialect/sqlserver/src/main/java/org/apache/shardingsphere/sql/parser/sqlserver/visitor/statement/SQLServerStatementVisitor.java
@@ -164,6 +164,7 @@ import org.apache.shardingsphere.sql.parser.statement.core.enums.ParameterMarker
import org.apache.shardingsphere.sql.parser.statement.core.enums.ScanUnit;
import
org.apache.shardingsphere.sql.parser.statement.core.enums.StatisticsDimension;
import org.apache.shardingsphere.sql.parser.statement.core.segment.SQLSegment;
+import
org.apache.shardingsphere.sql.parser.statement.core.segment.dal.VariableSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.ddl.constraint.ConstraintSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.ddl.index.IndexNameSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.ddl.index.IndexSegment;
@@ -260,6 +261,8 @@ import org.apache.shardingsphere.sql.parser.autogen.SQLServerStatementParser.Rol
import
org.apache.shardingsphere.sql.parser.autogen.SQLServerStatementParser.GroupingSetsClauseContext;
import
org.apache.shardingsphere.sql.parser.autogen.SQLServerStatementParser.GroupingExprListContext;
import
org.apache.shardingsphere.sql.parser.autogen.SQLServerStatementParser.ExpressionListContext;
+import
org.apache.shardingsphere.sql.parser.autogen.SQLServerStatementParser.VariableNameContext;
+import
org.apache.shardingsphere.sql.parser.autogen.SQLServerStatementParser.PredictFunctionContext;
import java.util.Collection;
import java.util.Collections;
@@ -347,6 +350,11 @@ public abstract class SQLServerStatementVisitor extends SQLServerStatementBaseVi
return null == ctx.regularIdentifier() ?
visit(ctx.delimitedIdentifier()) : visit(ctx.regularIdentifier());
}
+ @Override
+ public final ASTNode visitVariableName(final VariableNameContext ctx) {
+ return new VariableSegment(ctx.getStart().getStartIndex(),
ctx.getStop().getStopIndex(), ((IdentifierValue)
visit(ctx.identifier())).getValue());
+ }
+
@Override
public final ASTNode visitRegularIdentifier(final RegularIdentifierContext
ctx) {
UnreservedWordContext unreservedWord = ctx.unreservedWord();
@@ -1468,11 +1476,31 @@ public abstract class SQLServerStatementVisitor extends SQLServerStatementBaseVi
return visit(ctx.openRowSetFunction());
} else if (null != ctx.openQueryFunction()) {
return visit(ctx.openQueryFunction());
+ } else if (null != ctx.predictFunction()) {
+ return visit(ctx.predictFunction());
} else {
return visit(ctx.openDatasourceFunction());
}
}
+ @Override
+ public ASTNode visitPredictFunction(final PredictFunctionContext ctx) {
+ FunctionSegment result = new
FunctionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(),
ctx.PREDICT().getText(), getOriginalText(ctx));
+ if (null != ctx.variableName()) {
+ result.getParameters().add((ExpressionSegment)
visit(ctx.variableName()));
+ } else if (null != ctx.literals()) {
+ result.getParameters().add((ExpressionSegment)
visit(ctx.literals()));
+ }
+ result.getParameters().add(new
LiteralExpressionSegment(ctx.tableName().getStart().getStartIndex(),
ctx.tableName().getStop().getStopIndex(), ctx.tableName().getText()));
+ if (null != ctx.alias()) {
+ result.getParameters().add(new
LiteralExpressionSegment(ctx.alias().getStart().getStartIndex(),
ctx.alias().getStop().getStopIndex(), ctx.alias().getText()));
+ }
+ if (null != ctx.ONNX()) {
+ result.getParameters().add(new
LiteralExpressionSegment(ctx.ONNX().getSymbol().getStartIndex(),
ctx.ONNX().getSymbol().getStopIndex(), ctx.ONNX().getText()));
+ }
+ return result;
+ }
+
@Override
public ASTNode visitWithTableHint(final WithTableHintContext ctx) {
WithTableHintSegment result = new
WithTableHintSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex());
@@ -1939,6 +1967,14 @@ public abstract class SQLServerStatementVisitor extends SQLServerStatementBaseVi
if (null != ctx.pivotTable()) {
return visit(ctx.pivotTable());
}
+ if (null != ctx.rowSetFunction()) {
+ FunctionSegment functionSegment = (FunctionSegment)
visit(ctx.rowSetFunction());
+ FunctionTableSegment result = new
FunctionTableSegment(functionSegment.getStartIndex(),
functionSegment.getStopIndex(), functionSegment);
+ if (null != ctx.alias()) {
+ result.setAlias((AliasSegment) visit(ctx.alias()));
+ }
+ return result;
+ }
return visit(ctx.tableReferences());
}
diff --git a/test/it/parser/src/main/resources/case/dml/select.xml b/test/it/parser/src/main/resources/case/dml/select.xml
index e0a724ca4fe..d70e3c0624f 100644
--- a/test/it/parser/src/main/resources/case/dml/select.xml
+++ b/test/it/parser/src/main/resources/case/dml/select.xml
@@ -11652,4 +11652,29 @@
<column-item name="CustomerKey" start-index="51" stop-index="61" />
</group-by>
</select>
+
+ <select sql-case-id="select_with_predict_function">
+ <projections start-index="7" stop-index="18">
+ <shorthand-projection start-index="7" stop-index="9">
+ <owner name="d" start-index="7" stop-index="7" />
+ </shorthand-projection>
+ <column-projection name="Score" start-index="12" stop-index="18">
+ <owner name="p" start-index="12" stop-index="12" />
+ </column-projection>
+ </projections>
+ <from start-index="26" stop-index="96">
+ <function-table table-alias="p" start-index="26" stop-index="91">
+ <table-function function-name="PREDICT" text="PREDICT(MODEL =
@model, DATA = dbo.mytable AS d) WITH (Score FLOAT)">
+ <parameter>
+ <parameter-marker value="@model" start-index="40"
stop-index="45" />
+ </parameter>
+ <parameter>
+ <simple-table name="mytable" alias="d"
start-index="54" stop-index="70">
+ <owner name="dbo" start-index="54" stop-index="54"
/>
+ </simple-table>
+ </parameter>
+ </table-function>
+ </function-table>
+ </from>
+ </select>
</sql-parser-test-cases>
diff --git a/test/it/parser/src/main/resources/sql/supported/dml/select.xml b/test/it/parser/src/main/resources/sql/supported/dml/select.xml
index 3b460069f32..0c624a2d463 100644
--- a/test/it/parser/src/main/resources/sql/supported/dml/select.xml
+++ b/test/it/parser/src/main/resources/sql/supported/dml/select.xml
@@ -367,4 +367,5 @@
<sql-case id="select_grouping_sets" value="SELECT Country, Region,
SUM(Sales) AS TotalSales FROM Sales GROUP BY GROUPING SETS ( ROLLUP (Country,
Region), CUBE (Country, Region));" db-types="SQLServer"/>
<sql-case id="select_identity_function" value="SELECT IDENTITY(int,1,1) AS
ID_Num INTO NewTable FROM OldTable;" db-types="SQLServer"/>
<sql-case id="select_group_by_with_distributed_agg" value="SELECT
CustomerKey FROM FactInternetSales GROUP BY CustomerKey WITH (DISTRIBUTED_AGG)"
db-types="SQLServer"/>
+ <sql-case id="select_with_predict_function" value="SELECT d.*, p.Score
FROM PREDICT(MODEL = @model, DATA = dbo.mytable AS d) WITH (Score FLOAT) AS p;"
db-types="SQLServer"/>
</sql-cases>