This is an automated email from the ASF dual-hosted git repository.

chengzhang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new fd0c05c632a fix:Support for the PREDICT function (#36268)
fd0c05c632a is described below

commit fd0c05c632a87b9bd3bcfc16a7121527ae2c47b2
Author: cxy <[email protected]>
AuthorDate: Wed Aug 13 09:22:22 2025 +0800

    fix:Support for the PREDICT function (#36268)
---
 .../src/main/antlr4/imports/sqlserver/BaseRule.g4  | 14 ++++++++-
 .../main/antlr4/imports/sqlserver/DMLStatement.g4  |  2 +-
 .../antlr4/imports/sqlserver/SQLServerKeyword.g4   | 12 ++++++++
 .../statement/SQLServerStatementVisitor.java       | 36 ++++++++++++++++++++++
 .../parser/src/main/resources/case/dml/select.xml  | 25 +++++++++++++++
 .../main/resources/sql/supported/dml/select.xml    |  1 +
 6 files changed, 88 insertions(+), 2 deletions(-)

diff --git a/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/BaseRule.g4 b/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/BaseRule.g4
index 63a53705de6..bfff4f7ed5c 100644
--- a/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/BaseRule.g4
+++ b/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/BaseRule.g4
@@ -481,7 +481,19 @@ openQueryFunction
     ;
 
 rowSetFunction
-    : openRowSetFunction | openQueryFunction | openDatasourceFunction
+    : openRowSetFunction | openQueryFunction | openDatasourceFunction | 
predictFunction
+    ;
+
+predictFunction
+    : PREDICT LP_ MODEL EQ_ (variableName | literals) COMMA_ DATA EQ_ 
tableName (AS alias)? (COMMA_ RUNTIME EQ_ ONNX)? RP_ WITH LP_ 
predictResultSetDefinition RP_
+    ;
+
+predictResultSetDefinition
+    : predictColumnDefinition (COMMA_ predictColumnDefinition)*
+    ;
+
+predictColumnDefinition
+    : columnName dataType (COLLATE collationName)? (NULL | NOT NULL)?
     ;
 
 regularFunction
diff --git a/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/DMLStatement.g4 b/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/DMLStatement.g4
index afc7af9bc90..92792f7756c 100644
--- a/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/DMLStatement.g4
+++ b/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/DMLStatement.g4
@@ -185,7 +185,7 @@ tableReference
     ;
 
 tableFactor
-    : tableName (FOR PATH)? forSystemTimeClause? (AS? alias)? 
tableSampleClause? withTableHint? | subquery AS? alias columnNames? | expr (AS? 
alias)? columnNames? | xmlMethodCall (AS? alias)? columnNames? | LP_ 
tableReferences RP_ | pivotTable
+    : tableName (FOR PATH)? forSystemTimeClause? (AS? alias)? 
tableSampleClause? withTableHint? | subquery AS? alias columnNames? | 
rowSetFunction (AS? alias)? | expr (AS? alias)? columnNames? | xmlMethodCall 
(AS? alias)? columnNames? | LP_ tableReferences RP_ | pivotTable
     ;
 
 pivotTable
diff --git a/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/SQLServerKeyword.g4 b/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/SQLServerKeyword.g4
index 9028275ba3d..d7f81f99e72 100644
--- a/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/SQLServerKeyword.g4
+++ b/parser/sql/dialect/sqlserver/src/main/antlr4/imports/sqlserver/SQLServerKeyword.g4
@@ -2170,3 +2170,15 @@ SETS
 DISTRIBUTED_AGG
     : D I S T R I B U T E D UL_ A G G
     ;
+
+PREDICT
+    : P R E D I C T
+    ;
+
+RUNTIME
+    : R U N T I M E
+    ;
+
+ONNX
+    : O N N X
+    ;
diff --git a/parser/sql/dialect/sqlserver/src/main/java/org/apache/shardingsphere/sql/parser/sqlserver/visitor/statement/SQLServerStatementVisitor.java b/parser/sql/dialect/sqlserver/src/main/java/org/apache/shardingsphere/sql/parser/sqlserver/visitor/statement/SQLServerStatementVisitor.java
index bb318aa1b47..d7e1bdebef6 100644
--- a/parser/sql/dialect/sqlserver/src/main/java/org/apache/shardingsphere/sql/parser/sqlserver/visitor/statement/SQLServerStatementVisitor.java
+++ b/parser/sql/dialect/sqlserver/src/main/java/org/apache/shardingsphere/sql/parser/sqlserver/visitor/statement/SQLServerStatementVisitor.java
@@ -164,6 +164,7 @@ import org.apache.shardingsphere.sql.parser.statement.core.enums.ParameterMarker
 import org.apache.shardingsphere.sql.parser.statement.core.enums.ScanUnit;
 import 
org.apache.shardingsphere.sql.parser.statement.core.enums.StatisticsDimension;
 import org.apache.shardingsphere.sql.parser.statement.core.segment.SQLSegment;
+import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dal.VariableSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.ddl.constraint.ConstraintSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.ddl.index.IndexNameSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.ddl.index.IndexSegment;
@@ -260,6 +261,8 @@ import org.apache.shardingsphere.sql.parser.autogen.SQLServerStatementParser.Rol
 import 
org.apache.shardingsphere.sql.parser.autogen.SQLServerStatementParser.GroupingSetsClauseContext;
 import 
org.apache.shardingsphere.sql.parser.autogen.SQLServerStatementParser.GroupingExprListContext;
 import 
org.apache.shardingsphere.sql.parser.autogen.SQLServerStatementParser.ExpressionListContext;
+import 
org.apache.shardingsphere.sql.parser.autogen.SQLServerStatementParser.VariableNameContext;
+import 
org.apache.shardingsphere.sql.parser.autogen.SQLServerStatementParser.PredictFunctionContext;
 
 import java.util.Collection;
 import java.util.Collections;
@@ -347,6 +350,11 @@ public abstract class SQLServerStatementVisitor extends SQLServerStatementBaseVi
         return null == ctx.regularIdentifier() ? 
visit(ctx.delimitedIdentifier()) : visit(ctx.regularIdentifier());
     }
     
+    @Override
+    public final ASTNode visitVariableName(final VariableNameContext ctx) {
+        return new VariableSegment(ctx.getStart().getStartIndex(), 
ctx.getStop().getStopIndex(), ((IdentifierValue) 
visit(ctx.identifier())).getValue());
+    }
+    
     @Override
     public final ASTNode visitRegularIdentifier(final RegularIdentifierContext 
ctx) {
         UnreservedWordContext unreservedWord = ctx.unreservedWord();
@@ -1468,11 +1476,31 @@ public abstract class SQLServerStatementVisitor extends SQLServerStatementBaseVi
             return visit(ctx.openRowSetFunction());
         } else if (null != ctx.openQueryFunction()) {
             return visit(ctx.openQueryFunction());
+        } else if (null != ctx.predictFunction()) {
+            return visit(ctx.predictFunction());
         } else {
             return visit(ctx.openDatasourceFunction());
         }
     }
     
+    @Override
+    public ASTNode visitPredictFunction(final PredictFunctionContext ctx) {
+        FunctionSegment result = new 
FunctionSegment(ctx.getStart().getStartIndex(), ctx.getStop().getStopIndex(), 
ctx.PREDICT().getText(), getOriginalText(ctx));
+        if (null != ctx.variableName()) {
+            result.getParameters().add((ExpressionSegment) 
visit(ctx.variableName()));
+        } else if (null != ctx.literals()) {
+            result.getParameters().add((ExpressionSegment) 
visit(ctx.literals()));
+        }
+        result.getParameters().add(new 
LiteralExpressionSegment(ctx.tableName().getStart().getStartIndex(), 
ctx.tableName().getStop().getStopIndex(), ctx.tableName().getText()));
+        if (null != ctx.alias()) {
+            result.getParameters().add(new 
LiteralExpressionSegment(ctx.alias().getStart().getStartIndex(), 
ctx.alias().getStop().getStopIndex(), ctx.alias().getText()));
+        }
+        if (null != ctx.ONNX()) {
+            result.getParameters().add(new 
LiteralExpressionSegment(ctx.ONNX().getSymbol().getStartIndex(), 
ctx.ONNX().getSymbol().getStopIndex(), ctx.ONNX().getText()));
+        }
+        return result;
+    }
+    
     @Override
     public ASTNode visitWithTableHint(final WithTableHintContext ctx) {
         WithTableHintSegment result = new 
WithTableHintSegment(ctx.start.getStartIndex(), ctx.stop.getStopIndex());
@@ -1939,6 +1967,14 @@ public abstract class SQLServerStatementVisitor extends SQLServerStatementBaseVi
         if (null != ctx.pivotTable()) {
             return visit(ctx.pivotTable());
         }
+        if (null != ctx.rowSetFunction()) {
+            FunctionSegment functionSegment = (FunctionSegment) 
visit(ctx.rowSetFunction());
+            FunctionTableSegment result = new 
FunctionTableSegment(functionSegment.getStartIndex(), 
functionSegment.getStopIndex(), functionSegment);
+            if (null != ctx.alias()) {
+                result.setAlias((AliasSegment) visit(ctx.alias()));
+            }
+            return result;
+        }
         return visit(ctx.tableReferences());
     }
     
diff --git a/test/it/parser/src/main/resources/case/dml/select.xml b/test/it/parser/src/main/resources/case/dml/select.xml
index e0a724ca4fe..d70e3c0624f 100644
--- a/test/it/parser/src/main/resources/case/dml/select.xml
+++ b/test/it/parser/src/main/resources/case/dml/select.xml
@@ -11652,4 +11652,29 @@
             <column-item name="CustomerKey" start-index="51" stop-index="61" />
         </group-by>
     </select>
+
+    <select sql-case-id="select_with_predict_function">
+        <projections start-index="7" stop-index="18">
+            <shorthand-projection start-index="7" stop-index="9">
+                <owner name="d" start-index="7" stop-index="7" />
+            </shorthand-projection>
+            <column-projection name="Score" start-index="12" stop-index="18">
+                <owner name="p" start-index="12" stop-index="12" />
+            </column-projection>
+        </projections>
+        <from start-index="26" stop-index="96">
+            <function-table table-alias="p" start-index="26" stop-index="91">
+                <table-function function-name="PREDICT" text="PREDICT(MODEL = 
@model, DATA = dbo.mytable AS d) WITH (Score FLOAT)">
+                    <parameter>
+                        <parameter-marker value="@model" start-index="40" 
stop-index="45" />
+                    </parameter>
+                    <parameter>
+                        <simple-table name="mytable" alias="d" 
start-index="54" stop-index="70">
+                            <owner name="dbo" start-index="54" stop-index="54" 
/>
+                        </simple-table>
+                    </parameter>
+                </table-function>
+            </function-table>
+        </from>
+    </select>
 </sql-parser-test-cases>
diff --git a/test/it/parser/src/main/resources/sql/supported/dml/select.xml b/test/it/parser/src/main/resources/sql/supported/dml/select.xml
index 3b460069f32..0c624a2d463 100644
--- a/test/it/parser/src/main/resources/sql/supported/dml/select.xml
+++ b/test/it/parser/src/main/resources/sql/supported/dml/select.xml
@@ -367,4 +367,5 @@
     <sql-case id="select_grouping_sets" value="SELECT Country, Region, 
SUM(Sales) AS TotalSales FROM Sales GROUP BY GROUPING SETS ( ROLLUP (Country, 
Region), CUBE (Country, Region));" db-types="SQLServer"/>
     <sql-case id="select_identity_function" value="SELECT IDENTITY(int,1,1) AS 
ID_Num INTO NewTable FROM OldTable;" db-types="SQLServer"/>
     <sql-case id="select_group_by_with_distributed_agg" value="SELECT 
CustomerKey FROM FactInternetSales GROUP BY CustomerKey WITH (DISTRIBUTED_AGG)" 
db-types="SQLServer"/>
+    <sql-case id="select_with_predict_function" value="SELECT d.*, p.Score 
FROM PREDICT(MODEL = @model, DATA = dbo.mytable AS d) WITH (Score FLOAT) AS p;" 
db-types="SQLServer"/>
 </sql-cases>

Reply via email to