This is an automated email from the ASF dual-hosted git repository.
chengzhang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new c28ac8162fd Add ParameterMarkerSegmentBinder logic for Oracle
MergeStatementBinder (#29561)
c28ac8162fd is described below
commit c28ac8162fd867e6b64621e899949e957fd5162a
Author: Zhengqiang Duan <[email protected]>
AuthorDate: Wed Dec 27 14:47:46 2023 +0800
Add ParameterMarkerSegmentBinder logic for Oracle MergeStatementBinder
(#29561)
* Add ParameterMarkerSegmentBinder logic for Oracle MergeStatementBinder
* Modify merge getInsert and getUpdate method return type to Optional
* fix unit test
* fix unit test
* fix checkstyle
---
.../parameter/ParameterMarkerSegmentBinder.java | 59 +++++++++++++++++
.../ParameterMarkerExpressionSegmentBinder.java | 76 ++++++++++++++++++++++
.../binder/statement/dml/MergeStatementBinder.java | 69 +++++++++++++++-----
.../binder/statement/MergeStatementBinderTest.java | 18 +++--
.../statement/merge/MergeStatementConverter.java | 5 +-
.../simple/ParameterMarkerExpressionSegment.java | 20 +++++-
.../limit/ParameterMarkerLimitValueSegment.java | 15 +++++
.../ParameterMarkerRowNumberValueSegment.java | 15 +++++
.../segment/generic/ParameterMarkerSegment.java | 8 +++
.../sql/common/statement/AbstractSQLStatement.java | 3 +-
.../sql/common/statement/dml/MergeStatement.java | 20 ++++++
.../statement/dml/impl/MergeStatementAssert.java | 19 +++---
12 files changed, 289 insertions(+), 38 deletions(-)
diff --git
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/parameter/ParameterMarkerSegmentBinder.java
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/parameter/ParameterMarkerSegmentBinder.java
new file mode 100644
index 00000000000..a036bd697bf
--- /dev/null
+++
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/parameter/ParameterMarkerSegmentBinder.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.infra.binder.segment.parameter;
+
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+import
org.apache.shardingsphere.infra.binder.segment.parameter.impl.ParameterMarkerExpressionSegmentBinder;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.bounded.ColumnSegmentBoundedInfo;
+
+import java.util.Collection;
+import java.util.LinkedList;
+import java.util.Map;
+
+/**
+ * Binder which attaches column bounded info to parameter marker segments.
+ */
+@NoArgsConstructor(access = AccessLevel.PRIVATE)
+public final class ParameterMarkerSegmentBinder {
+    
+    /**
+     * Bind parameter marker segments with metadata.
+     *
+     * @param parameterMarkerSegments parameter marker segments to bind
+     * @param parameterMarkerSegmentBoundedInfos column bounded infos keyed by the original parameter marker segment
+     * @return bound parameter marker segments, in the iteration order of the input
+     */
+    public static Collection<ParameterMarkerSegment> bind(final Collection<ParameterMarkerSegment> parameterMarkerSegments,
+                                                          final Map<ParameterMarkerSegment, ColumnSegmentBoundedInfo> parameterMarkerSegmentBoundedInfos) {
+        Collection<ParameterMarkerSegment> result = new LinkedList<>();
+        for (ParameterMarkerSegment each : parameterMarkerSegments) {
+            result.add(bind(each, parameterMarkerSegmentBoundedInfos));
+        }
+        return result;
+    }
+    
+    private static ParameterMarkerSegment bind(final ParameterMarkerSegment parameterMarkerSegment,
+                                               final Map<ParameterMarkerSegment, ColumnSegmentBoundedInfo> parameterMarkerSegmentBoundedInfos) {
+        // TODO support more ParameterMarkerSegment bind; other kinds are passed through untouched for now
+        if (!(parameterMarkerSegment instanceof ParameterMarkerExpressionSegment)) {
+            return parameterMarkerSegment;
+        }
+        return ParameterMarkerExpressionSegmentBinder.bind((ParameterMarkerExpressionSegment) parameterMarkerSegment, parameterMarkerSegmentBoundedInfos);
+    }
+}
diff --git
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/parameter/impl/ParameterMarkerExpressionSegmentBinder.java
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/parameter/impl/ParameterMarkerExpressionSegmentBinder.java
new file mode 100644
index 00000000000..fbfd4b0a7bb
--- /dev/null
+++
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/parameter/impl/ParameterMarkerExpressionSegmentBinder.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.infra.binder.segment.parameter.impl;
+
+import lombok.AccessLevel;
+import lombok.NoArgsConstructor;
+import org.apache.shardingsphere.infra.binder.enums.SegmentType;
+import
org.apache.shardingsphere.infra.binder.segment.expression.ExpressionSegmentBinder;
+import
org.apache.shardingsphere.infra.binder.segment.from.TableSegmentBinderContext;
+import
org.apache.shardingsphere.infra.binder.statement.SQLStatementBinderContext;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.bounded.ColumnSegmentBoundedInfo;
+
+import java.util.Map;
+
+/**
+ * Parameter marker expression segment binder.
+ */
+@NoArgsConstructor(access = AccessLevel.PRIVATE)
+public final class ParameterMarkerExpressionSegmentBinder {
+    
+    /**
+     * Bind binary operation expression with metadata.
+     *
+     * <p>Both operands are bound through {@link ExpressionSegmentBinder}. NOTE(review): this overload binds binary operation
+     * expressions, not parameter markers, and looks misplaced in this class; kept only so existing callers keep compiling.</p>
+     *
+     * @param segment binary operation expression segment
+     * @param parentSegmentType parent segment type
+     * @param statementBinderContext statement binder context
+     * @param tableBinderContexts table binder contexts
+     * @param outerTableBinderContexts outer table binder contexts
+     * @return bounded binary operation expression segment
+     */
+    public static BinaryOperationExpression bind(final BinaryOperationExpression segment, final SegmentType parentSegmentType, final SQLStatementBinderContext statementBinderContext,
+                                                 final Map<String, TableSegmentBinderContext> tableBinderContexts, final Map<String, TableSegmentBinderContext> outerTableBinderContexts) {
+        ExpressionSegment boundedLeft = ExpressionSegmentBinder.bind(segment.getLeft(), parentSegmentType, statementBinderContext, tableBinderContexts, outerTableBinderContexts);
+        ExpressionSegment boundedRight = ExpressionSegmentBinder.bind(segment.getRight(), parentSegmentType, statementBinderContext, tableBinderContexts, outerTableBinderContexts);
+        return new BinaryOperationExpression(segment.getStartIndex(), segment.getStopIndex(), boundedLeft, boundedRight, segment.getOperator(), segment.getText());
+    }
+    
+    /**
+     * Bind parameter marker expression segment with metadata.
+     *
+     * @param segment parameter marker expression segment
+     * @param boundedInfos column bounded infos keyed by the original parameter marker segment
+     * @return a copy of the segment (alias preserved) carrying its column bounded info,
+     *         or the given segment itself when no bounded info is registered for it
+     */
+    public static ParameterMarkerExpressionSegment bind(final ParameterMarkerExpressionSegment segment,
+                                                        final Map<ParameterMarkerSegment, ColumnSegmentBoundedInfo> boundedInfos) {
+        ColumnSegmentBoundedInfo boundedInfo = boundedInfos.get(segment);
+        if (null == boundedInfo) {
+            return segment;
+        }
+        ParameterMarkerExpressionSegment result = new ParameterMarkerExpressionSegment(segment.getStartIndex(), segment.getStopIndex(), segment.getParameterMarkerIndex(), segment.getParameterMarkerType());
+        segment.getAliasSegment().ifPresent(result::setAlias);
+        result.setBoundedInfo(boundedInfo);
+        return result;
+    }
+}
diff --git
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/MergeStatementBinder.java
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/MergeStatementBinder.java
index 7346d4b4616..53e33e8feb1 100644
---
a/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/MergeStatementBinder.java
+++
b/infra/binder/src/main/java/org/apache/shardingsphere/infra/binder/statement/dml/MergeStatementBinder.java
@@ -20,10 +20,12 @@ package
org.apache.shardingsphere.infra.binder.statement.dml;
import com.cedarsoftware.util.CaseInsensitiveMap;
import lombok.SneakyThrows;
import org.apache.shardingsphere.infra.binder.enums.SegmentType;
+import
org.apache.shardingsphere.infra.binder.segment.column.InsertColumnsSegmentBinder;
import
org.apache.shardingsphere.infra.binder.segment.expression.ExpressionSegmentBinder;
import
org.apache.shardingsphere.infra.binder.segment.expression.impl.ColumnSegmentBinder;
import org.apache.shardingsphere.infra.binder.segment.from.TableSegmentBinder;
import
org.apache.shardingsphere.infra.binder.segment.from.TableSegmentBinderContext;
+import
org.apache.shardingsphere.infra.binder.segment.parameter.ParameterMarkerSegmentBinder;
import org.apache.shardingsphere.infra.binder.segment.where.WhereSegmentBinder;
import org.apache.shardingsphere.infra.binder.statement.SQLStatementBinder;
import
org.apache.shardingsphere.infra.binder.statement.SQLStatementBinderContext;
@@ -33,8 +35,13 @@ import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.Co
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.InsertValuesSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.SetAssignmentSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.InsertColumnsSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionWithParamsSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ColumnProjectionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.bounded.ColumnSegmentBoundedInfo;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.statement.dml.InsertStatement;
@@ -50,7 +57,6 @@ import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
/**
* Merge statement binder.
@@ -77,37 +83,52 @@ public final class MergeStatementBinder implements
SQLStatementBinder<MergeState
Map<String, TableSegmentBinderContext> tableBinderContexts = new
LinkedHashMap<>();
tableBinderContexts.putAll(sourceTableBinderContexts);
tableBinderContexts.putAll(targetTableBinderContexts);
- if (sqlStatement.getExpression() != null) {
+ if (null != sqlStatement.getExpression()) {
ExpressionWithParamsSegment expression = new
ExpressionWithParamsSegment(sqlStatement.getExpression().getStartIndex(),
sqlStatement.getExpression().getStopIndex(),
ExpressionSegmentBinder.bind(sqlStatement.getExpression().getExpr(),
SegmentType.JOIN_ON, statementBinderContext, tableBinderContexts,
Collections.emptyMap()));
expression.getParameterMarkerSegments().addAll(sqlStatement.getExpression().getParameterMarkerSegments());
result.setExpression(expression);
}
-
result.setInsert(Optional.ofNullable(sqlStatement.getInsert()).map(optional ->
bindMergeInsert(optional,
- (SimpleTableSegment) boundedTargetTableSegment,
statementBinderContext, targetTableBinderContexts,
sourceTableBinderContexts)).orElse(null));
-
result.setUpdate(Optional.ofNullable(sqlStatement.getUpdate()).map(optional ->
bindMergeUpdate(optional,
- (SimpleTableSegment) boundedTargetTableSegment,
statementBinderContext, targetTableBinderContexts,
sourceTableBinderContexts)).orElse(null));
-
result.addParameterMarkerSegments(sqlStatement.getParameterMarkerSegments());
+ sqlStatement.getInsert().ifPresent(
+ optional -> result.setInsert(bindMergeInsert(optional,
(SimpleTableSegment) boundedTargetTableSegment, statementBinderContext,
targetTableBinderContexts, sourceTableBinderContexts)));
+ sqlStatement.getUpdate().ifPresent(
+ optional -> result.setUpdate(bindMergeUpdate(optional,
(SimpleTableSegment) boundedTargetTableSegment, statementBinderContext,
targetTableBinderContexts, sourceTableBinderContexts)));
+ addParameterMarkerSegments(sqlStatement, result);
result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
return result;
}
+    /**
+     * Add parameter marker segments to the bound merge statement.
+     *
+     * @param originalSQLStatement merge statement being bound; its own parameter marker segments are copied
+     * @param mergeStatement bound merge statement which receives the copied markers plus those of its bound insert and update
+     */
+    private void addParameterMarkerSegments(final MergeStatement originalSQLStatement, final MergeStatement mergeStatement) {
+        mergeStatement.addParameterMarkerSegments(originalSQLStatement.getParameterMarkerSegments());
+        mergeStatement.getInsert().ifPresent(optional -> mergeStatement.addParameterMarkerSegments(optional.getParameterMarkerSegments()));
+        mergeStatement.getUpdate().ifPresent(optional -> mergeStatement.addParameterMarkerSegments(optional.getParameterMarkerSegments()));
+    }
+
@SneakyThrows
private InsertStatement bindMergeInsert(final InsertStatement
sqlStatement, final SimpleTableSegment tableSegment, final
SQLStatementBinderContext statementBinderContext,
final Map<String,
TableSegmentBinderContext> targetTableBinderContexts, final Map<String,
TableSegmentBinderContext> sourceTableBinderContexts) {
- InsertStatement result =
sqlStatement.getClass().getDeclaredConstructor().newInstance();
- result.setTable(tableSegment);
- sqlStatement.getInsertColumns().ifPresent(result::setInsertColumns);
- sqlStatement.getInsertSelect().ifPresent(result::setInsertSelect);
SQLStatementBinderContext insertStatementBinderContext = new
SQLStatementBinderContext(statementBinderContext.getMetaData(),
statementBinderContext.getDefaultDatabaseName(),
statementBinderContext.getDatabaseType(),
statementBinderContext.getVariableNames());
insertStatementBinderContext.getExternalTableBinderContexts().putAll(statementBinderContext.getExternalTableBinderContexts());
insertStatementBinderContext.getExternalTableBinderContexts().putAll(sourceTableBinderContexts);
+ InsertStatement result =
sqlStatement.getClass().getDeclaredConstructor().newInstance();
+ result.setTable(tableSegment);
+ sqlStatement.getInsertColumns()
+ .ifPresent(optional ->
result.setInsertColumns(InsertColumnsSegmentBinder.bind(sqlStatement.getInsertColumns().get(),
statementBinderContext, targetTableBinderContexts)));
+ sqlStatement.getInsertSelect().ifPresent(result::setInsertSelect);
Collection<InsertValuesSegment> insertValues = new LinkedList<>();
+ Map<ParameterMarkerSegment, ColumnSegmentBoundedInfo>
parameterMarkerSegmentBoundedInfos = new LinkedHashMap<>();
+ List<ColumnSegment> columnSegments = new
ArrayList<>(result.getInsertColumns().map(InsertColumnsSegment::getColumns)
+ .orElseGet(() ->
getVisibleColumns(targetTableBinderContexts.values().iterator().next().getProjectionSegments())));
for (InsertValuesSegment each : sqlStatement.getValues()) {
List<ExpressionSegment> values = new LinkedList<>();
- for (ExpressionSegment value : each.getValues()) {
- values.add(ExpressionSegmentBinder.bind(value,
SegmentType.VALUES, insertStatementBinderContext, targetTableBinderContexts,
sourceTableBinderContexts));
+ int index = 0;
+ for (ExpressionSegment expression : each.getValues()) {
+ values.add(ExpressionSegmentBinder.bind(expression,
SegmentType.VALUES, insertStatementBinderContext, targetTableBinderContexts,
sourceTableBinderContexts));
+ if (expression instanceof ParameterMarkerSegment) {
+
parameterMarkerSegmentBoundedInfos.put((ParameterMarkerSegment) expression,
columnSegments.get(index).getColumnBoundedInfo());
+ }
+ index++;
}
insertValues.add(new InsertValuesSegment(each.getStartIndex(),
each.getStopIndex(), values));
}
@@ -122,11 +143,21 @@ public final class MergeStatementBinder implements
SQLStatementBinder<MergeState
InsertStatementHandler.getReturningSegment(sqlStatement).ifPresent(optional ->
InsertStatementHandler.setReturningSegment(result, optional));
InsertStatementHandler.getWhereSegment(sqlStatement).ifPresent(optional ->
InsertStatementHandler.setWhereSegment(result,
WhereSegmentBinder.bind(optional,
insertStatementBinderContext, targetTableBinderContexts,
sourceTableBinderContexts)));
-
result.addParameterMarkerSegments(sqlStatement.getParameterMarkerSegments());
+
result.addParameterMarkerSegments(ParameterMarkerSegmentBinder.bind(sqlStatement.getParameterMarkerSegments(),
parameterMarkerSegmentBoundedInfos));
result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
return result;
}
+    /**
+     * Collect the columns of the visible column projections.
+     *
+     * @param projectionSegments projection segments to scan
+     * @return columns of visible column projections, in projection order
+     */
+    private Collection<ColumnSegment> getVisibleColumns(final Collection<ProjectionSegment> projectionSegments) {
+        Collection<ColumnSegment> result = new LinkedList<>();
+        projectionSegments.stream()
+                .filter(projection -> projection instanceof ColumnProjectionSegment && projection.isVisible())
+                .map(projection -> ((ColumnProjectionSegment) projection).getColumn())
+                .forEach(result::add);
+        return result;
+    }
+
@SneakyThrows
private UpdateStatement bindMergeUpdate(final UpdateStatement
sqlStatement, final SimpleTableSegment tableSegment, final
SQLStatementBinderContext statementBinderContext,
final Map<String,
TableSegmentBinderContext> targetTableBinderContexts, final Map<String,
TableSegmentBinderContext> sourceTableBinderContexts) {
@@ -137,13 +168,17 @@ public final class MergeStatementBinder implements
SQLStatementBinder<MergeState
statementBinderContext.getDatabaseType(),
statementBinderContext.getVariableNames());
updateStatementBinderContext.getExternalTableBinderContexts().putAll(statementBinderContext.getExternalTableBinderContexts());
updateStatementBinderContext.getExternalTableBinderContexts().putAll(sourceTableBinderContexts);
+ Map<ParameterMarkerSegment, ColumnSegmentBoundedInfo>
parameterMarkerSegmentBoundedInfos = new LinkedHashMap<>();
for (AssignmentSegment each :
sqlStatement.getSetAssignment().getAssignments()) {
List<ColumnSegment> columnSegments = new
ArrayList<>(each.getColumns().size());
each.getColumns().forEach(column -> columnSegments.add(
ColumnSegmentBinder.bind(column,
SegmentType.SET_ASSIGNMENT, updateStatementBinderContext,
targetTableBinderContexts, Collections.emptyMap())));
- ExpressionSegment value =
ExpressionSegmentBinder.bind(each.getValue(), SegmentType.SET_ASSIGNMENT,
updateStatementBinderContext, targetTableBinderContexts,
Collections.emptyMap());
- ColumnAssignmentSegment columnAssignmentSegment = new
ColumnAssignmentSegment(each.getStartIndex(), each.getStopIndex(),
columnSegments, value);
+ ExpressionSegment expression =
ExpressionSegmentBinder.bind(each.getValue(), SegmentType.SET_ASSIGNMENT,
updateStatementBinderContext, targetTableBinderContexts,
Collections.emptyMap());
+ ColumnAssignmentSegment columnAssignmentSegment = new
ColumnAssignmentSegment(each.getStartIndex(), each.getStopIndex(),
columnSegments, expression);
assignments.add(columnAssignmentSegment);
+ if (expression instanceof ParameterMarkerSegment) {
+
parameterMarkerSegmentBoundedInfos.put((ParameterMarkerSegment) expression,
columnAssignmentSegment.getColumns().get(0).getColumnBoundedInfo());
+ }
}
SetAssignmentSegment setAssignmentSegment = new
SetAssignmentSegment(sqlStatement.getSetAssignment().getStartIndex(),
sqlStatement.getSetAssignment().getStopIndex(), assignments);
result.setSetAssignment(setAssignmentSegment);
@@ -153,7 +188,7 @@ public final class MergeStatementBinder implements
SQLStatementBinder<MergeState
UpdateStatementHandler.getOrderBySegment(sqlStatement).ifPresent(optional ->
UpdateStatementHandler.setOrderBySegment(result, optional));
UpdateStatementHandler.getLimitSegment(sqlStatement).ifPresent(optional ->
UpdateStatementHandler.setLimitSegment(result, optional));
UpdateStatementHandler.getWithSegment(sqlStatement).ifPresent(optional
-> UpdateStatementHandler.setWithSegment(result, optional));
-
result.addParameterMarkerSegments(sqlStatement.getParameterMarkerSegments());
+
result.addParameterMarkerSegments(ParameterMarkerSegmentBinder.bind(sqlStatement.getParameterMarkerSegments(),
parameterMarkerSegmentBoundedInfos));
result.getCommentSegments().addAll(sqlStatement.getCommentSegments());
return result;
}
diff --git
a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java
index 07d1d284474..ad3c13e1978 100644
---
a/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java
+++
b/infra/binder/src/test/java/org/apache/shardingsphere/infra/binder/statement/MergeStatementBinderTest.java
@@ -53,6 +53,7 @@ import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
+import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@@ -88,9 +89,11 @@ class MergeStatementBinderTest {
assertThat(actual.getSource(), instanceOf(SimpleTableSegment.class));
assertThat(actual.getTarget(), not(mergeStatement.getTarget()));
assertThat(actual.getTarget(), instanceOf(SimpleTableSegment.class));
- assertThat(actual.getUpdate(), not(mergeStatement.getUpdate()));
-
assertThat(actual.getUpdate().getSetAssignment().getAssignments().iterator().next().getValue(),
instanceOf(ColumnSegment.class));
- assertThat(((ColumnSegment)
actual.getUpdate().getSetAssignment().getAssignments().iterator().next().getValue()).getColumnBoundedInfo().getOriginalTable().getValue(),
is("t_order_item"));
+ assertTrue(actual.getUpdate().isPresent());
+ assertThat(actual.getUpdate().get(), not(mergeStatement.getUpdate()));
+
assertThat(actual.getUpdate().get().getSetAssignment().getAssignments().iterator().next().getValue(),
instanceOf(ColumnSegment.class));
+ assertThat(((ColumnSegment)
actual.getUpdate().get().getSetAssignment().getAssignments().iterator().next().getValue()).getColumnBoundedInfo().getOriginalTable().getValue(),
+ is("t_order_item"));
}
private ShardingSphereMetaData createMetaData() {
@@ -163,10 +166,11 @@ class MergeStatementBinderTest {
new LiteralExpressionSegment(0, 0, 1), "=", "item_id = 1")));
mergeStatement.setUpdate(updateStatement);
MergeStatement actual = new
MergeStatementBinder().bind(mergeStatement, createMetaData(),
DefaultDatabase.LOGIC_NAME);
- assertThat(actual.getUpdate(),
instanceOf(OracleUpdateStatement.class));
- assertThat(((OracleUpdateStatement)
actual.getUpdate()).getDeleteWhere().getExpr(),
instanceOf(BinaryOperationExpression.class));
- assertThat(((BinaryOperationExpression) ((OracleUpdateStatement)
actual.getUpdate()).getDeleteWhere().getExpr()).getLeft(),
instanceOf(ColumnSegment.class));
- assertThat(((ColumnSegment) ((BinaryOperationExpression)
((OracleUpdateStatement)
actual.getUpdate()).getDeleteWhere().getExpr()).getLeft())
+ assertTrue(actual.getUpdate().isPresent());
+ assertThat(actual.getUpdate().get(),
instanceOf(OracleUpdateStatement.class));
+ assertThat(((OracleUpdateStatement)
actual.getUpdate().get()).getDeleteWhere().getExpr(),
instanceOf(BinaryOperationExpression.class));
+ assertThat(((BinaryOperationExpression) ((OracleUpdateStatement)
actual.getUpdate().get()).getDeleteWhere().getExpr()).getLeft(),
instanceOf(ColumnSegment.class));
+ assertThat(((ColumnSegment) ((BinaryOperationExpression)
((OracleUpdateStatement)
actual.getUpdate().get()).getDeleteWhere().getExpr()).getLeft())
.getColumnBoundedInfo().getOriginalTable().getValue(),
is("t_order_item"));
}
}
diff --git
a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/merge/MergeStatementConverter.java
b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/merge/MergeStatementConverter.java
index 1c6708598d7..9c18f88890e 100644
---
a/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/merge/MergeStatementConverter.java
+++
b/kernel/sql-federation/optimizer/src/main/java/org/apache/shardingsphere/sqlfederation/optimizer/converter/statement/merge/MergeStatementConverter.java
@@ -46,10 +46,7 @@ public final class MergeStatementConverter implements
SQLStatementConverter<Merg
SqlNode targetTable =
TableConverter.convert(mergeStatement.getTarget()).orElseThrow(IllegalStateException::new);
SqlNode condition =
ExpressionConverter.convert(mergeStatement.getExpression().getExpr()).orElseThrow(IllegalStateException::new);
SqlNode sourceTable =
TableConverter.convert(mergeStatement.getSource()).orElseThrow(IllegalStateException::new);
- SqlUpdate sqlUpdate = null;
- if (null != mergeStatement.getUpdate()) {
- sqlUpdate = convertUpdate(mergeStatement.getUpdate());
- }
+ SqlUpdate sqlUpdate =
mergeStatement.getUpdate().map(this::convertUpdate).orElse(null);
return new SqlMerge(SqlParserPos.ZERO, targetTable, condition,
sourceTable, sqlUpdate, null, null, null);
}
diff --git
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/expr/simple/ParameterMarkerExpressionSegment.java
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/expr/simple/ParameterMarkerExpressionSegment.java
index 3c26c4ac82e..a48fdf770c0 100644
---
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/expr/simple/ParameterMarkerExpressionSegment.java
+++
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/expr/simple/ParameterMarkerExpressionSegment.java
@@ -26,6 +26,7 @@ import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.Projecti
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.AliasAvailable;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.AliasSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.ParameterMarkerSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.bounded.ColumnSegmentBoundedInfo;
import
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
import java.util.Optional;
@@ -35,7 +36,7 @@ import java.util.Optional;
*/
@RequiredArgsConstructor
@Getter
-@EqualsAndHashCode
+@EqualsAndHashCode(exclude = "boundedInfo")
public class ParameterMarkerExpressionSegment implements
SimpleExpressionSegment, ProjectionSegment, AliasAvailable,
ParameterMarkerSegment {
private final int startIndex;
@@ -49,6 +50,9 @@ public class ParameterMarkerExpressionSegment implements
SimpleExpressionSegment
@Setter
private AliasSegment alias;
+ @Setter
+ private ColumnSegmentBoundedInfo boundedInfo;
+
public ParameterMarkerExpressionSegment(final int startIndex, final int
stopIndex, final int parameterMarkerIndex) {
this.startIndex = startIndex;
this.stopIndex = stopIndex;
@@ -76,6 +80,11 @@ public class ParameterMarkerExpressionSegment implements
SimpleExpressionSegment
return parameterMarkerIndex;
}
+    @Override
+    public ColumnSegmentBoundedInfo getBoundedInfo() {
+        // When no bounded info has been set, report an empty-named column bounded info instead of null.
+        return null == boundedInfo ? new ColumnSegmentBoundedInfo(new IdentifierValue("")) : boundedInfo;
+    }
+
@Override
public int getStopIndex() {
return null == alias ? stopIndex : alias.getStopIndex();
@@ -85,4 +94,13 @@ public class ParameterMarkerExpressionSegment implements
SimpleExpressionSegment
public String getText() {
return parameterMarkerType.getMarker();
}
+
+    /**
+     * Get the alias segment of this parameter marker expression.
+     *
+     * @return alias segment, or empty when no alias has been set
+     */
+    public Optional<AliasSegment> getAliasSegment() {
+        return null == alias ? Optional.empty() : Optional.of(alias);
+    }
}
diff --git
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/pagination/limit/ParameterMarkerLimitValueSegment.java
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/pagination/limit/ParameterMarkerLimitValueSegment.java
index 4f4f04abb40..ac6c4f3629e 100644
---
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/pagination/limit/ParameterMarkerLimitValueSegment.java
+++
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/pagination/limit/ParameterMarkerLimitValueSegment.java
@@ -17,19 +17,34 @@
package
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.pagination.limit;
+import lombok.EqualsAndHashCode;
import lombok.Getter;
+import lombok.Setter;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.pagination.ParameterMarkerPaginationValueSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.bounded.ColumnSegmentBoundedInfo;
+import
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
+
+import java.util.Optional;
/**
* Limit value segment for parameter marker.
*/
@Getter
+@EqualsAndHashCode(exclude = "boundedInfo", callSuper = true)
public final class ParameterMarkerLimitValueSegment extends LimitValueSegment
implements ParameterMarkerPaginationValueSegment {
private final int parameterIndex;
+ @Setter
+ private ColumnSegmentBoundedInfo boundedInfo;
+
public ParameterMarkerLimitValueSegment(final int startIndex, final int
stopIndex, final int paramIndex) {
super(startIndex, stopIndex);
this.parameterIndex = paramIndex;
}
+
+    @Override
+    public ColumnSegmentBoundedInfo getBoundedInfo() {
+        // When no bounded info has been set, report an empty-named column bounded info instead of null.
+        return null == boundedInfo ? new ColumnSegmentBoundedInfo(new IdentifierValue("")) : boundedInfo;
+    }
}
diff --git
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/pagination/rownum/ParameterMarkerRowNumberValueSegment.java
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/pagination/rownum/ParameterMarkerRowNumberValueSegment.java
index b522755f171..f4114dee9a7 100644
---
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/pagination/rownum/ParameterMarkerRowNumberValueSegment.java
+++
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/dml/pagination/rownum/ParameterMarkerRowNumberValueSegment.java
@@ -17,19 +17,34 @@
package
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.pagination.rownum;
+import lombok.EqualsAndHashCode;
import lombok.Getter;
+import lombok.Setter;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.pagination.ParameterMarkerPaginationValueSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.bounded.ColumnSegmentBoundedInfo;
+import
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
+
+import java.util.Optional;
/**
* Row number value segment for parameter marker.
*/
@Getter
+@EqualsAndHashCode(exclude = "boundedInfo", callSuper = true)
public final class ParameterMarkerRowNumberValueSegment extends
RowNumberValueSegment implements ParameterMarkerPaginationValueSegment {
private final int parameterIndex;
+ @Setter
+ private ColumnSegmentBoundedInfo boundedInfo;
+
public ParameterMarkerRowNumberValueSegment(final int startIndex, final
int stopIndex, final int paramIndex, final boolean boundOpened) {
super(startIndex, stopIndex, boundOpened);
this.parameterIndex = paramIndex;
}
+
+ @Override
+ public ColumnSegmentBoundedInfo getBoundedInfo() {
+ return Optional.ofNullable(boundedInfo).orElseGet(() -> new
ColumnSegmentBoundedInfo(new IdentifierValue("")));
+ }
}
diff --git
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/generic/ParameterMarkerSegment.java
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/generic/ParameterMarkerSegment.java
index 4ffa871af49..1f1cb450a7d 100644
---
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/generic/ParameterMarkerSegment.java
+++
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/segment/generic/ParameterMarkerSegment.java
@@ -18,6 +18,7 @@
package org.apache.shardingsphere.sql.parser.sql.common.segment.generic;
import org.apache.shardingsphere.sql.parser.sql.common.segment.SQLSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.bounded.ColumnSegmentBoundedInfo;
/**
* Parameter marker segment.
@@ -30,4 +31,11 @@ public interface ParameterMarkerSegment extends SQLSegment {
* @return parameter index
*/
int getParameterIndex();
+
+ /**
+ * Get bounded info.
+ *
+ * @return bounded info
+ */
+ ColumnSegmentBoundedInfo getBoundedInfo();
}
diff --git
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/AbstractSQLStatement.java
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/AbstractSQLStatement.java
index b65e5670397..c0a56b0602d 100644
---
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/AbstractSQLStatement.java
+++
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/AbstractSQLStatement.java
@@ -23,6 +23,7 @@ import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.Parameter
import java.util.Collection;
import java.util.HashSet;
+import java.util.LinkedHashSet;
import java.util.LinkedList;
/**
@@ -31,7 +32,7 @@ import java.util.LinkedList;
@Getter
public abstract class AbstractSQLStatement implements SQLStatement {
- private final Collection<ParameterMarkerSegment> parameterMarkerSegments =
new LinkedList<>();
+ private final Collection<ParameterMarkerSegment> parameterMarkerSegments =
new LinkedHashSet<>();
private final Collection<Integer> uniqueParameterIndexes = new HashSet<>();
diff --git
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/MergeStatement.java
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/MergeStatement.java
index e43a4fedc5a..08f254b9d18 100644
---
a/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/MergeStatement.java
+++
b/parser/sql/statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/statement/dml/MergeStatement.java
@@ -23,6 +23,8 @@ import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.Expressi
import
org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.TableSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.statement.AbstractSQLStatement;
+import java.util.Optional;
+
/**
* Merge statement.
*/
@@ -39,4 +41,22 @@ public abstract class MergeStatement extends
AbstractSQLStatement implements DML
private UpdateStatement update;
private InsertStatement insert;
+
+ /**
+ * Get update statement.
+ *
+ * @return update statement
+ */
+ public Optional<UpdateStatement> getUpdate() {
+ return Optional.ofNullable(update);
+ }
+
+ /**
+ * Get insert statement.
+ *
+ * @return insert statement
+ */
+ public Optional<InsertStatement> getInsert() {
+ return Optional.ofNullable(insert);
+ }
}
diff --git
a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/MergeStatementAssert.java
b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/MergeStatementAssert.java
index 26a01d1d48c..0275db88236 100644
---
a/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/MergeStatementAssert.java
+++
b/test/it/parser/src/main/java/org/apache/shardingsphere/test/it/sql/parser/internal/asserts/statement/dml/impl/MergeStatementAssert.java
@@ -75,26 +75,29 @@ public final class MergeStatementAssert {
private static void assertSetClause(final SQLCaseAssertContext
assertContext, final MergeStatement actual, final MergeStatementTestCase
expected) {
if (null != expected.getUpdateClause()) {
+ assertTrue(actual.getUpdate().isPresent(),
assertContext.getText("Actual merge update statement should exist."));
if (null == expected.getUpdateClause().getSetClause()) {
- assertNull(actual.getUpdate().getSetAssignment(),
assertContext.getText("Actual assignment should not exist."));
+ assertNull(actual.getUpdate().get().getSetAssignment(),
assertContext.getText("Actual assignment should not exist."));
} else {
- SetClauseAssert.assertIs(assertContext,
actual.getUpdate().getSetAssignment(),
expected.getUpdateClause().getSetClause());
+ SetClauseAssert.assertIs(assertContext,
actual.getUpdate().get().getSetAssignment(),
expected.getUpdateClause().getSetClause());
}
}
}
private static void assertWhereClause(final SQLCaseAssertContext
assertContext, final MergeStatement actual, final MergeStatementTestCase
expected) {
if (null != expected.getUpdateClause()) {
+ assertTrue(actual.getUpdate().isPresent(),
assertContext.getText("Actual merge update statement should exist."));
if (null == expected.getUpdateClause().getWhereClause()) {
- assertFalse(actual.getUpdate().getWhere().isPresent(),
assertContext.getText("Actual update where segment should not exist."));
+ assertFalse(actual.getUpdate().get().getWhere().isPresent(),
assertContext.getText("Actual update where segment should not exist."));
} else {
- assertTrue(actual.getUpdate().getWhere().isPresent(),
assertContext.getText("Actual update where segment should exist."));
- WhereClauseAssert.assertIs(assertContext,
actual.getUpdate().getWhere().get(),
expected.getUpdateClause().getWhereClause());
+ assertTrue(actual.getUpdate().get().getWhere().isPresent(),
assertContext.getText("Actual update where segment should exist."));
+ WhereClauseAssert.assertIs(assertContext,
actual.getUpdate().get().getWhere().get(),
expected.getUpdateClause().getWhereClause());
}
}
- if (null != expected.getInsertClause() && null !=
expected.getInsertClause().getWhereClause() && actual.getInsert() instanceof
OracleInsertStatement) {
- assertTrue(((OracleInsertStatement)
actual.getInsert()).getWhere().isPresent(), assertContext.getText("Actual
insert where segment should exist."));
- WhereClauseAssert.assertIs(assertContext, ((OracleInsertStatement)
actual.getInsert()).getWhere().get(),
expected.getInsertClause().getWhereClause());
+ if (null != expected.getInsertClause() && null !=
expected.getInsertClause().getWhereClause() && actual.getInsert().orElse(null)
instanceof OracleInsertStatement) {
+ assertTrue(actual.getInsert().isPresent(),
assertContext.getText("Actual merge insert statement should exist."));
+ assertTrue(((OracleInsertStatement)
actual.getInsert().get()).getWhere().isPresent(), assertContext.getText("Actual
insert where segment should exist."));
+ WhereClauseAssert.assertIs(assertContext, ((OracleInsertStatement)
actual.getInsert().get()).getWhere().get(),
expected.getInsertClause().getWhereClause());
}
}
}