This is an automated email from the ASF dual-hosted git repository.

chengzhang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git


The following commit(s) were added to refs/heads/master by this push:
     new b662d7d5d51 Support table and subquery extract for more expression segment (#33043)
b662d7d5d51 is described below

commit b662d7d5d516afd16fe6c0d4110232222ed7997d
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Sun Sep 29 10:36:33 2024 +0800

    Support table and subquery extract for more expression segment (#33043)
---
 .../statement/core/util/SubqueryExtractUtils.java  | 35 ++++++++++++--
 .../parser/statement/core/util/TableExtractor.java | 55 ++++++++++++++++------
 2 files changed, 73 insertions(+), 17 deletions(-)

diff --git a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java
index 10cb4033871..5dba65e27c5 100644
--- a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java
+++ b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java
@@ -21,14 +21,18 @@ import lombok.AccessLevel;
 import lombok.NoArgsConstructor;
 import org.apache.shardingsphere.sql.parser.statement.core.enums.SubqueryType;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.combine.CombineSegment;
+import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.datetime.DatetimeExpression;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BetweenExpression;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression;
+import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.CaseWhenExpression;
+import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.CollateExpression;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExistsSubqueryExpression;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.FunctionSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.InExpression;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ListExpression;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.NotExpression;
+import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.TypeCastExpression;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.complex.CommonTableExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubqueryExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubquerySegment;
@@ -36,6 +40,7 @@ import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.Expr
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionsSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.SubqueryProjectionSegment;
+import 
org.apache.shardingsphere.sql.parser.statement.core.segment.generic.match.MatchAgainstExpression;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.JoinTableSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SubqueryTableSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableSegment;
@@ -147,9 +152,7 @@ public final class SubqueryExtractUtils {
             extractSubquerySegments(result, subquery.getSelect());
         }
         if (expressionSegment instanceof ListExpression) {
-            for (ExpressionSegment each : ((ListExpression) 
expressionSegment).getItems()) {
-                extractSubquerySegmentsFromExpression(result, each, 
subqueryType);
-            }
+            ((ListExpression) expressionSegment).getItems().forEach(each -> 
extractSubquerySegmentsFromExpression(result, each, subqueryType));
         }
         if (expressionSegment instanceof BinaryOperationExpression) {
             extractSubquerySegmentsFromExpression(result, 
((BinaryOperationExpression) expressionSegment).getLeft(), subqueryType);
@@ -169,6 +172,32 @@ public final class SubqueryExtractUtils {
         if (expressionSegment instanceof FunctionSegment) {
             ((FunctionSegment) expressionSegment).getParameters().forEach(each 
-> extractSubquerySegmentsFromExpression(result, each, subqueryType));
         }
+        if (expressionSegment instanceof MatchAgainstExpression) {
+            extractSubquerySegmentsFromExpression(result, 
((MatchAgainstExpression) expressionSegment).getExpr(), subqueryType);
+        }
+        if (expressionSegment instanceof CaseWhenExpression) {
+            extractSubquerySegmentsFromCaseWhenExpression(result, 
(CaseWhenExpression) expressionSegment, subqueryType);
+        }
+        if (expressionSegment instanceof CollateExpression) {
+            extractSubquerySegmentsFromExpression(result, ((CollateExpression) 
expressionSegment).getCollateName(), subqueryType);
+        }
+        if (expressionSegment instanceof DatetimeExpression) {
+            extractSubquerySegmentsFromExpression(result, 
((DatetimeExpression) expressionSegment).getLeft(), subqueryType);
+            extractSubquerySegmentsFromExpression(result, 
((DatetimeExpression) expressionSegment).getRight(), subqueryType);
+        }
+        if (expressionSegment instanceof NotExpression) {
+            extractSubquerySegmentsFromExpression(result, ((NotExpression) 
expressionSegment).getExpression(), subqueryType);
+        }
+        if (expressionSegment instanceof TypeCastExpression) {
+            extractSubquerySegmentsFromExpression(result, 
((TypeCastExpression) expressionSegment).getExpression(), subqueryType);
+        }
+    }
+    
+    private static void extractSubquerySegmentsFromCaseWhenExpression(final 
List<SubquerySegment> result, final CaseWhenExpression expressionSegment, final 
SubqueryType subqueryType) {
+        extractSubquerySegmentsFromExpression(result, 
expressionSegment.getCaseExpr(), subqueryType);
+        expressionSegment.getWhenExprs().forEach(each -> 
extractSubquerySegmentsFromExpression(result, each, subqueryType));
+        expressionSegment.getThenExprs().forEach(each -> 
extractSubquerySegmentsFromExpression(result, each, subqueryType));
+        extractSubquerySegmentsFromExpression(result, 
expressionSegment.getElseExpr(), subqueryType);
     }
     
     private static void extractSubquerySegmentsFromCombine(final 
List<SubquerySegment> result, final CombineSegment combineSegment) {
diff --git a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/TableExtractor.java b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/TableExtractor.java
index be6bbd385dd..7cc6bfe08c2 100644
--- a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/TableExtractor.java
+++ b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/TableExtractor.java
@@ -23,15 +23,21 @@ import 
org.apache.shardingsphere.sql.parser.statement.core.segment.ddl.routine.V
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.assignment.ColumnAssignmentSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.combine.CombineSegment;
+import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.datetime.DatetimeExpression;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BetweenExpression;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression;
+import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.CaseWhenExpression;
+import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.CollateExpression;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExistsSubqueryExpression;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.FunctionSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.InExpression;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ListExpression;
+import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.NotExpression;
+import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.TypeCastExpression;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.complex.CommonTableExpressionSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubqueryExpressionSegment;
+import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubquerySegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ColumnProjectionSegment;
 import 
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ExpressionProjectionSegment;
@@ -59,6 +65,7 @@ import 
org.apache.shardingsphere.sql.parser.statement.core.statement.dml.SelectS
 import 
org.apache.shardingsphere.sql.parser.statement.core.statement.dml.UpdateStatement;
 
 import java.util.Collection;
+import java.util.Collections;
 import java.util.LinkedList;
 import java.util.Optional;
 
@@ -135,14 +142,11 @@ public final class TableExtractor {
     }
     
     private void extractTablesFromExpression(final ExpressionSegment 
expressionSegment) {
-        if (expressionSegment instanceof ColumnSegment && ((ColumnSegment) 
expressionSegment).getOwner().isPresent() && needRewrite(((ColumnSegment) 
expressionSegment).getOwner().get())) {
-            OwnerSegment ownerSegment = ((ColumnSegment) 
expressionSegment).getOwner().get();
-            rewriteTables.add(new SimpleTableSegment(new 
TableNameSegment(ownerSegment.getStartIndex(), ownerSegment.getStopIndex(), 
ownerSegment.getIdentifier())));
+        if (expressionSegment instanceof ColumnSegment) {
+            
extractTablesFromColumnSegments(Collections.singleton((ColumnSegment) 
expressionSegment));
         }
         if (expressionSegment instanceof ListExpression) {
-            for (ExpressionSegment each : ((ListExpression) 
expressionSegment).getItems()) {
-                extractTablesFromExpression(each);
-            }
+            ((ListExpression) 
expressionSegment).getItems().forEach(this::extractTablesFromExpression);
         }
         if (expressionSegment instanceof ExistsSubqueryExpression) {
             extractTablesFromSelect(((ExistsSubqueryExpression) 
expressionSegment).getSubquery().getSelect());
@@ -159,22 +163,45 @@ public final class TableExtractor {
         if (expressionSegment instanceof SubqueryExpressionSegment) {
             extractTablesFromSelect(((SubqueryExpressionSegment) 
expressionSegment).getSubquery().getSelect());
         }
+        if (expressionSegment instanceof SubquerySegment) {
+            extractTablesFromSelect(((SubquerySegment) 
expressionSegment).getSelect());
+        }
         if (expressionSegment instanceof BinaryOperationExpression) {
             extractTablesFromExpression(((BinaryOperationExpression) 
expressionSegment).getLeft());
             extractTablesFromExpression(((BinaryOperationExpression) 
expressionSegment).getRight());
         }
         if (expressionSegment instanceof MatchAgainstExpression) {
-            for (ColumnSegment each : ((MatchAgainstExpression) 
expressionSegment).getColumns()) {
-                extractTablesFromExpression(each);
-            }
+            ((MatchAgainstExpression) 
expressionSegment).getColumns().forEach(this::extractTablesFromExpression);
+            extractTablesFromExpression(((MatchAgainstExpression) 
expressionSegment).getExpr());
         }
         if (expressionSegment instanceof FunctionSegment) {
-            for (ExpressionSegment each : ((FunctionSegment) 
expressionSegment).getParameters()) {
-                extractTablesFromExpression(each);
-            }
+            ((FunctionSegment) 
expressionSegment).getParameters().forEach(this::extractTablesFromExpression);
+        }
+        if (expressionSegment instanceof CaseWhenExpression) {
+            extractTablesFromCaseWhenExpression((CaseWhenExpression) 
expressionSegment);
+        }
+        if (expressionSegment instanceof CollateExpression) {
+            extractTablesFromExpression(((CollateExpression) 
expressionSegment).getCollateName());
+        }
+        if (expressionSegment instanceof DatetimeExpression) {
+            extractTablesFromExpression(((DatetimeExpression) 
expressionSegment).getLeft());
+            extractTablesFromExpression(((DatetimeExpression) 
expressionSegment).getRight());
+        }
+        if (expressionSegment instanceof NotExpression) {
+            extractTablesFromExpression(((NotExpression) 
expressionSegment).getExpression());
+        }
+        if (expressionSegment instanceof TypeCastExpression) {
+            extractTablesFromExpression(((TypeCastExpression) 
expressionSegment).getExpression());
         }
     }
     
+    private void extractTablesFromCaseWhenExpression(final CaseWhenExpression 
expressionSegment) {
+        extractTablesFromExpression(expressionSegment.getCaseExpr());
+        
expressionSegment.getWhenExprs().forEach(this::extractTablesFromExpression);
+        
expressionSegment.getThenExprs().forEach(this::extractTablesFromExpression);
+        extractTablesFromExpression(expressionSegment.getElseExpr());
+    }
+    
     private void extractTablesFromProjections(final ProjectionsSegment 
projections) {
         for (ProjectionSegment each : projections.getProjections()) {
             if (each instanceof SubqueryProjectionSegment) {
@@ -253,12 +280,12 @@ public final class TableExtractor {
     }
     
     private void extractTablesFromColumnSegments(final 
Collection<ColumnSegment> columnSegments) {
-        columnSegments.forEach(each -> {
+        for (ColumnSegment each : columnSegments) {
             if (each.getOwner().isPresent() && 
needRewrite(each.getOwner().get())) {
                 OwnerSegment ownerSegment = each.getOwner().get();
                 rewriteTables.add(new SimpleTableSegment(new 
TableNameSegment(ownerSegment.getStartIndex(), ownerSegment.getStopIndex(), 
ownerSegment.getIdentifier())));
             }
-        });
+        }
     }
     
     /**

Reply via email to