This is an automated email from the ASF dual-hosted git repository.
chengzhang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new b662d7d5d51 Support table and subquery extract for more expression segment (#33043)
b662d7d5d51 is described below
commit b662d7d5d516afd16fe6c0d4110232222ed7997d
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Sun Sep 29 10:36:33 2024 +0800
Support table and subquery extract for more expression segment (#33043)
---
.../statement/core/util/SubqueryExtractUtils.java | 35 ++++++++++++--
.../parser/statement/core/util/TableExtractor.java | 55 ++++++++++++++++------
2 files changed, 73 insertions(+), 17 deletions(-)
diff --git a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java
index 10cb4033871..5dba65e27c5 100644
--- a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java
+++ b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/SubqueryExtractUtils.java
@@ -21,14 +21,18 @@ import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.apache.shardingsphere.sql.parser.statement.core.enums.SubqueryType;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.combine.CombineSegment;
+import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.datetime.DatetimeExpression;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BetweenExpression;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression;
+import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.CaseWhenExpression;
+import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.CollateExpression;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExistsSubqueryExpression;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.FunctionSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.InExpression;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ListExpression;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.NotExpression;
+import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.TypeCastExpression;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.complex.CommonTableExpressionSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubqueryExpressionSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubquerySegment;
@@ -36,6 +40,7 @@ import org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.Expr
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ProjectionsSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.SubqueryProjectionSegment;
+import
org.apache.shardingsphere.sql.parser.statement.core.segment.generic.match.MatchAgainstExpression;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.JoinTableSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SubqueryTableSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableSegment;
@@ -147,9 +152,7 @@ public final class SubqueryExtractUtils {
extractSubquerySegments(result, subquery.getSelect());
}
if (expressionSegment instanceof ListExpression) {
- for (ExpressionSegment each : ((ListExpression)
expressionSegment).getItems()) {
- extractSubquerySegmentsFromExpression(result, each,
subqueryType);
- }
+ ((ListExpression) expressionSegment).getItems().forEach(each ->
extractSubquerySegmentsFromExpression(result, each, subqueryType));
}
if (expressionSegment instanceof BinaryOperationExpression) {
extractSubquerySegmentsFromExpression(result,
((BinaryOperationExpression) expressionSegment).getLeft(), subqueryType);
@@ -169,6 +172,32 @@ public final class SubqueryExtractUtils {
if (expressionSegment instanceof FunctionSegment) {
((FunctionSegment) expressionSegment).getParameters().forEach(each
-> extractSubquerySegmentsFromExpression(result, each, subqueryType));
}
+ if (expressionSegment instanceof MatchAgainstExpression) {
+ extractSubquerySegmentsFromExpression(result,
((MatchAgainstExpression) expressionSegment).getExpr(), subqueryType);
+ }
+ if (expressionSegment instanceof CaseWhenExpression) {
+ extractSubquerySegmentsFromCaseWhenExpression(result,
(CaseWhenExpression) expressionSegment, subqueryType);
+ }
+ if (expressionSegment instanceof CollateExpression) {
+ extractSubquerySegmentsFromExpression(result, ((CollateExpression)
expressionSegment).getCollateName(), subqueryType);
+ }
+ if (expressionSegment instanceof DatetimeExpression) {
+ extractSubquerySegmentsFromExpression(result,
((DatetimeExpression) expressionSegment).getLeft(), subqueryType);
+ extractSubquerySegmentsFromExpression(result,
((DatetimeExpression) expressionSegment).getRight(), subqueryType);
+ }
+ if (expressionSegment instanceof NotExpression) {
+ extractSubquerySegmentsFromExpression(result, ((NotExpression)
expressionSegment).getExpression(), subqueryType);
+ }
+ if (expressionSegment instanceof TypeCastExpression) {
+ extractSubquerySegmentsFromExpression(result,
((TypeCastExpression) expressionSegment).getExpression(), subqueryType);
+ }
+ }
+
+ private static void extractSubquerySegmentsFromCaseWhenExpression(final
List<SubquerySegment> result, final CaseWhenExpression expressionSegment, final
SubqueryType subqueryType) {
+ extractSubquerySegmentsFromExpression(result,
expressionSegment.getCaseExpr(), subqueryType);
+ expressionSegment.getWhenExprs().forEach(each ->
extractSubquerySegmentsFromExpression(result, each, subqueryType));
+ expressionSegment.getThenExprs().forEach(each ->
extractSubquerySegmentsFromExpression(result, each, subqueryType));
+ extractSubquerySegmentsFromExpression(result,
expressionSegment.getElseExpr(), subqueryType);
}
private static void extractSubquerySegmentsFromCombine(final
List<SubquerySegment> result, final CombineSegment combineSegment) {
diff --git a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/TableExtractor.java b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/TableExtractor.java
index be6bbd385dd..7cc6bfe08c2 100644
--- a/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/TableExtractor.java
+++ b/parser/sql/statement/core/src/main/java/org/apache/shardingsphere/sql/parser/statement/core/util/TableExtractor.java
@@ -23,15 +23,21 @@ import org.apache.shardingsphere.sql.parser.statement.core.segment.ddl.routine.V
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.assignment.ColumnAssignmentSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.column.ColumnSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.combine.CombineSegment;
+import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.datetime.DatetimeExpression;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BetweenExpression;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.BinaryOperationExpression;
+import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.CaseWhenExpression;
+import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.CollateExpression;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExistsSubqueryExpression;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ExpressionSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.FunctionSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.InExpression;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.ListExpression;
+import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.NotExpression;
+import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.TypeCastExpression;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.complex.CommonTableExpressionSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubqueryExpressionSegment;
+import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.expr.subquery.SubquerySegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.AggregationProjectionSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ColumnProjectionSegment;
import
org.apache.shardingsphere.sql.parser.statement.core.segment.dml.item.ExpressionProjectionSegment;
@@ -59,6 +65,7 @@ import org.apache.shardingsphere.sql.parser.statement.core.statement.dml.SelectS
import
org.apache.shardingsphere.sql.parser.statement.core.statement.dml.UpdateStatement;
import java.util.Collection;
+import java.util.Collections;
import java.util.LinkedList;
import java.util.Optional;
@@ -135,14 +142,11 @@ public final class TableExtractor {
}
private void extractTablesFromExpression(final ExpressionSegment
expressionSegment) {
- if (expressionSegment instanceof ColumnSegment && ((ColumnSegment)
expressionSegment).getOwner().isPresent() && needRewrite(((ColumnSegment)
expressionSegment).getOwner().get())) {
- OwnerSegment ownerSegment = ((ColumnSegment)
expressionSegment).getOwner().get();
- rewriteTables.add(new SimpleTableSegment(new
TableNameSegment(ownerSegment.getStartIndex(), ownerSegment.getStopIndex(),
ownerSegment.getIdentifier())));
+ if (expressionSegment instanceof ColumnSegment) {
+
extractTablesFromColumnSegments(Collections.singleton((ColumnSegment)
expressionSegment));
}
if (expressionSegment instanceof ListExpression) {
- for (ExpressionSegment each : ((ListExpression)
expressionSegment).getItems()) {
- extractTablesFromExpression(each);
- }
+ ((ListExpression)
expressionSegment).getItems().forEach(this::extractTablesFromExpression);
}
if (expressionSegment instanceof ExistsSubqueryExpression) {
extractTablesFromSelect(((ExistsSubqueryExpression)
expressionSegment).getSubquery().getSelect());
@@ -159,22 +163,45 @@ public final class TableExtractor {
if (expressionSegment instanceof SubqueryExpressionSegment) {
extractTablesFromSelect(((SubqueryExpressionSegment)
expressionSegment).getSubquery().getSelect());
}
+ if (expressionSegment instanceof SubquerySegment) {
+ extractTablesFromSelect(((SubquerySegment)
expressionSegment).getSelect());
+ }
if (expressionSegment instanceof BinaryOperationExpression) {
extractTablesFromExpression(((BinaryOperationExpression)
expressionSegment).getLeft());
extractTablesFromExpression(((BinaryOperationExpression)
expressionSegment).getRight());
}
if (expressionSegment instanceof MatchAgainstExpression) {
- for (ColumnSegment each : ((MatchAgainstExpression)
expressionSegment).getColumns()) {
- extractTablesFromExpression(each);
- }
+ ((MatchAgainstExpression)
expressionSegment).getColumns().forEach(this::extractTablesFromExpression);
+ extractTablesFromExpression(((MatchAgainstExpression)
expressionSegment).getExpr());
}
if (expressionSegment instanceof FunctionSegment) {
- for (ExpressionSegment each : ((FunctionSegment)
expressionSegment).getParameters()) {
- extractTablesFromExpression(each);
- }
+ ((FunctionSegment)
expressionSegment).getParameters().forEach(this::extractTablesFromExpression);
+ }
+ if (expressionSegment instanceof CaseWhenExpression) {
+ extractTablesFromCaseWhenExpression((CaseWhenExpression)
expressionSegment);
+ }
+ if (expressionSegment instanceof CollateExpression) {
+ extractTablesFromExpression(((CollateExpression)
expressionSegment).getCollateName());
+ }
+ if (expressionSegment instanceof DatetimeExpression) {
+ extractTablesFromExpression(((DatetimeExpression)
expressionSegment).getLeft());
+ extractTablesFromExpression(((DatetimeExpression)
expressionSegment).getRight());
+ }
+ if (expressionSegment instanceof NotExpression) {
+ extractTablesFromExpression(((NotExpression)
expressionSegment).getExpression());
+ }
+ if (expressionSegment instanceof TypeCastExpression) {
+ extractTablesFromExpression(((TypeCastExpression)
expressionSegment).getExpression());
}
}
+ private void extractTablesFromCaseWhenExpression(final CaseWhenExpression
expressionSegment) {
+ extractTablesFromExpression(expressionSegment.getCaseExpr());
+
expressionSegment.getWhenExprs().forEach(this::extractTablesFromExpression);
+
expressionSegment.getThenExprs().forEach(this::extractTablesFromExpression);
+ extractTablesFromExpression(expressionSegment.getElseExpr());
+ }
+
private void extractTablesFromProjections(final ProjectionsSegment
projections) {
for (ProjectionSegment each : projections.getProjections()) {
if (each instanceof SubqueryProjectionSegment) {
@@ -253,12 +280,12 @@ public final class TableExtractor {
}
private void extractTablesFromColumnSegments(final
Collection<ColumnSegment> columnSegments) {
- columnSegments.forEach(each -> {
+ for (ColumnSegment each : columnSegments) {
if (each.getOwner().isPresent() &&
needRewrite(each.getOwner().get())) {
OwnerSegment ownerSegment = each.getOwner().get();
rewriteTables.add(new SimpleTableSegment(new
TableNameSegment(ownerSegment.getStartIndex(), ownerSegment.getStopIndex(),
ownerSegment.getIdentifier())));
}
- });
+ }
}
/**