This is an automated email from the ASF dual-hosted git repository.
duanzhengqiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/shardingsphere.git
The following commit(s) were added to refs/heads/master by this push:
new a5b7aff fix mysql insert sql bug; sql convert to a tree, if
assignment is ins… (#12295)
a5b7aff is described below
commit a5b7aff8702777e1ccfc97c66ebc7588cbf5b51c
Author: CodingBingo <[email protected]>
AuthorDate: Fri Sep 17 07:46:37 2021 +0800
fix mysql insert sql bug; sql convert to a tree, if assignment is ins…
(#12295)
* fix MySQL INSERT SQL bug: the SQL is converted to a tree, and if an
assignment is an instance of BinaryOperationExpression, the parameter markers
it contains should be added to the parameter count.
for more detail: https://github.com/apache/shardingsphere/issues/12272
* adjust code to pass checkstyle (调整过 checkstyle)
* add unit test code to keep this logic right.
* 1. replace assertEquals with assertThat
2. extract common method from InsertValueContext and OnDuplicateUpdateContext
* add header licence
* move ExpressionSegmentUtil logic to ExpressionExtractUtil
* remove import
* move logic to special class
* add java doc, and delete blank line.
---
.../segment/insert/values/InsertValueContext.java | 37 +++++---------
.../insert/values/OnDuplicateUpdateContext.java | 33 ++++---------
.../insert/values/InsertValueContextTest.java | 22 +++++++--
.../values/OnDuplicateUpdateContextTest.java | 56 +++++++++++++---------
.../statement/InsertContextExpressSegmentUtil.java | 54 +++++++++++++++++++++
.../sql/common/util/ExpressionExtractUtil.java | 2 +-
6 files changed, 127 insertions(+), 77 deletions(-)
diff --git
a/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/insert/values/InsertValueContext.java
b/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/insert/values/InsertValueContext.java
index 1c2f270..c4a2be7 100644
---
a/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/insert/values/InsertValueContext.java
+++
b/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/insert/values/InsertValueContext.java
@@ -19,6 +19,7 @@ package
org.apache.shardingsphere.infra.binder.segment.insert.values;
import lombok.Getter;
import lombok.ToString;
+import
org.apache.shardingsphere.infra.statement.InsertContextExpressSegmentUtil;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
@@ -38,25 +39,18 @@ public final class InsertValueContext {
private final int parameterCount;
private final List<ExpressionSegment> valueExpressions;
+
+ private final List<ParameterMarkerExpressionSegment>
parametersValueExpressions;
private final List<Object> parameters;
public InsertValueContext(final Collection<ExpressionSegment> assignments,
final List<Object> parameters, final int parametersOffset) {
- parameterCount = calculateParameterCount(assignments);
valueExpressions = getValueExpressions(assignments);
+ parametersValueExpressions =
InsertContextExpressSegmentUtil.extractParameterMarkerExpressionSegment(assignments);
+ parameterCount = parametersValueExpressions.size();
this.parameters = getParameters(parameters, parametersOffset);
}
- private int calculateParameterCount(final Collection<ExpressionSegment>
assignments) {
- int result = 0;
- for (ExpressionSegment each : assignments) {
- if (each instanceof ParameterMarkerExpressionSegment) {
- result++;
- }
- }
- return result;
- }
-
private List<ExpressionSegment> getValueExpressions(final
Collection<ExpressionSegment> assignments) {
List<ExpressionSegment> result = new ArrayList<>(assignments.size());
result.addAll(assignments);
@@ -80,7 +74,11 @@ public final class InsertValueContext {
*/
public Object getValue(final int index) {
ExpressionSegment valueExpression = valueExpressions.get(index);
- return valueExpression instanceof ParameterMarkerExpressionSegment ?
parameters.get(getParameterIndex(valueExpression)) :
((LiteralExpressionSegment) valueExpression).getLiterals();
+ if (parametersValueExpressions.contains(valueExpression)) {
+ return
parameters.get(parametersValueExpressions.indexOf(valueExpression));
+ } else {
+ return ((LiteralExpressionSegment) valueExpression).getLiterals();
+ }
}
/**
@@ -91,19 +89,6 @@ public final class InsertValueContext {
*/
public int getParameterIndex(final int index) {
ExpressionSegment valueExpression = valueExpressions.get(index);
- return getParameterIndex(valueExpression);
- }
-
- private int getParameterIndex(final ExpressionSegment valueExpression) {
- int result = 0;
- for (ExpressionSegment each : valueExpressions) {
- if (valueExpression == each) {
- return result;
- }
- if (each instanceof ParameterMarkerExpressionSegment) {
- result++;
- }
- }
- throw new IllegalArgumentException("Can not get parameter index.");
+ return parametersValueExpressions.indexOf(valueExpression);
}
}
diff --git
a/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/insert/values/OnDuplicateUpdateContext.java
b/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/insert/values/OnDuplicateUpdateContext.java
index b70a105..e382586 100644
---
a/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/insert/values/OnDuplicateUpdateContext.java
+++
b/shardingsphere-infra/shardingsphere-infra-binder/src/main/java/org/apache/shardingsphere/infra/binder/segment/insert/values/OnDuplicateUpdateContext.java
@@ -19,6 +19,7 @@ package
org.apache.shardingsphere.infra.binder.segment.insert.values;
import lombok.Getter;
import lombok.ToString;
+import
org.apache.shardingsphere.infra.statement.InsertContextExpressSegmentUtil;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
@@ -38,6 +39,8 @@ public final class OnDuplicateUpdateContext {
private final int parameterCount;
private final List<ExpressionSegment> valueExpressions;
+
+ private final List<ParameterMarkerExpressionSegment>
parametersValueExpressions;
private final List<Object> parameters;
@@ -45,22 +48,13 @@ public final class OnDuplicateUpdateContext {
public OnDuplicateUpdateContext(final Collection<AssignmentSegment>
assignments, final List<Object> parameters, final int parametersOffset) {
List<ExpressionSegment> expressionSegments =
assignments.stream().map(AssignmentSegment::getValue).collect(Collectors.toList());
- parameterCount = calculateParameterCount(expressionSegments);
valueExpressions = getValueExpressions(expressionSegments);
+ parametersValueExpressions =
InsertContextExpressSegmentUtil.extractParameterMarkerExpressionSegment(expressionSegments);
+ parameterCount = parametersValueExpressions.size();
this.parameters = getParameters(parameters, parametersOffset);
columns = assignments.stream().map(assignment ->
assignment.getColumns().get(0)).collect(Collectors.toList());
}
- private int calculateParameterCount(final Collection<ExpressionSegment>
assignments) {
- int result = 0;
- for (ExpressionSegment each : assignments) {
- if (each instanceof ParameterMarkerExpressionSegment) {
- result++;
- }
- }
- return result;
- }
-
private List<ExpressionSegment> getValueExpressions(final
Collection<ExpressionSegment> assignments) {
List<ExpressionSegment> result = new ArrayList<>(assignments.size());
result.addAll(assignments);
@@ -84,20 +78,11 @@ public final class OnDuplicateUpdateContext {
*/
public Object getValue(final int index) {
ExpressionSegment valueExpression = valueExpressions.get(index);
- return valueExpression instanceof ParameterMarkerExpressionSegment ?
parameters.get(getParameterIndex(valueExpression)) :
((LiteralExpressionSegment) valueExpression).getLiterals();
- }
-
- private int getParameterIndex(final ExpressionSegment valueExpression) {
- int result = 0;
- for (ExpressionSegment each : valueExpressions) {
- if (valueExpression == each) {
- return result;
- }
- if (each instanceof ParameterMarkerExpressionSegment) {
- result++;
- }
+ if (parametersValueExpressions.contains(valueExpression)) {
+ return
parameters.get(parametersValueExpressions.indexOf(valueExpression));
+ } else {
+ return ((LiteralExpressionSegment) valueExpression).getLiterals();
}
- throw new IllegalArgumentException("Can not get parameter index.");
}
/**
diff --git
a/shardingsphere-infra/shardingsphere-infra-binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/insert/values/InsertValueContextTest.java
b/shardingsphere-infra/shardingsphere-infra-binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/insert/values/InsertValueContextTest.java
index 28d05b2..cc824c8 100644
---
a/shardingsphere-infra/shardingsphere-infra-binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/insert/values/InsertValueContextTest.java
+++
b/shardingsphere-infra/shardingsphere-infra-binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/insert/values/InsertValueContextTest.java
@@ -17,13 +17,18 @@
package org.apache.shardingsphere.infra.binder.segment.insert.values;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ExpressionProjectionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
import org.junit.Test;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
+import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
@@ -40,10 +45,6 @@ public final class InsertValueContextTest {
List<Object> parameters = Collections.emptyList();
int parametersOffset = 0;
InsertValueContext insertValueContext = new
InsertValueContext(assignments, parameters, parametersOffset);
- Method calculateParameterCountMethod =
InsertValueContext.class.getDeclaredMethod("calculateParameterCount",
Collection.class);
- calculateParameterCountMethod.setAccessible(true);
- int calculateParameterCountResult = (int)
calculateParameterCountMethod.invoke(insertValueContext, new Object[]
{assignments});
- assertThat(insertValueContext.getParameterCount(),
is(calculateParameterCountResult));
Method getValueExpressionsMethod =
InsertValueContext.class.getDeclaredMethod("getValueExpressions",
Collection.class);
getValueExpressionsMethod.setAccessible(true);
List<ExpressionSegment> getValueExpressionsResult =
(List<ExpressionSegment>) getValueExpressionsMethod.invoke(insertValueContext,
new Object[] {assignments});
@@ -82,4 +83,17 @@ public final class InsertValueContextTest {
private Collection<ExpressionSegment> makeLiteralExpressionSegment(final
Object literalObject) {
return Collections.singleton(new LiteralExpressionSegment(0, 10,
literalObject));
}
+
+ @Test
+ public void assertGetParameterCount() {
+ Collection<ExpressionSegment> assignments = Arrays.asList(
+ new LiteralExpressionSegment(0, 10, null),
+ new ExpressionProjectionSegment(0, 10, ""),
+ new ParameterMarkerExpressionSegment(0, 10, 5),
+ new BinaryOperationExpression(0, 0, new ColumnSegment(0, 0,
new IdentifierValue("")), new ParameterMarkerExpressionSegment(0, 10, 5), "=",
"")
+ );
+ List<Object> parameters = Arrays.asList("", "");
+ InsertValueContext insertValueContext = new
InsertValueContext(assignments, parameters, 0);
+ assertThat(insertValueContext.getParameterCount(), is(2));
+ }
}
diff --git
a/shardingsphere-infra/shardingsphere-infra-binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/insert/values/OnDuplicateUpdateContextTest.java
b/shardingsphere-infra/shardingsphere-infra-binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/insert/values/OnDuplicateUpdateContextTest.java
index afc7055..8bc72dc 100644
---
a/shardingsphere-infra/shardingsphere-infra-binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/insert/values/OnDuplicateUpdateContextTest.java
+++
b/shardingsphere-infra/shardingsphere-infra-binder/src/test/java/org/apache/shardingsphere/infra/binder/segment/insert/values/OnDuplicateUpdateContextTest.java
@@ -20,6 +20,7 @@ package
org.apache.shardingsphere.infra.binder.segment.insert.values;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.AssignmentSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.assignment.ColumnAssignmentSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.column.ColumnSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.LiteralExpressionSegment;
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
@@ -37,7 +38,6 @@ import java.util.List;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertThat;
-import static org.junit.Assert.assertTrue;
public final class OnDuplicateUpdateContextTest {
@@ -48,10 +48,6 @@ public final class OnDuplicateUpdateContextTest {
List<Object> parameters = Collections.emptyList();
int parametersOffset = 0;
OnDuplicateUpdateContext onDuplicateUpdateContext = new
OnDuplicateUpdateContext(assignments, parameters, parametersOffset);
- Method calculateParameterCountMethod =
OnDuplicateUpdateContext.class.getDeclaredMethod("calculateParameterCount",
Collection.class);
- calculateParameterCountMethod.setAccessible(true);
- int calculateParameterCountResult = (int)
calculateParameterCountMethod.invoke(onDuplicateUpdateContext, new
Object[]{assignments});
- assertThat(onDuplicateUpdateContext.getParameterCount(),
is(calculateParameterCountResult));
Method getValueExpressionsMethod =
OnDuplicateUpdateContext.class.getDeclaredMethod("getValueExpressions",
Collection.class);
getValueExpressionsMethod.setAccessible(true);
List<ExpressionSegment> getValueExpressionsResult =
(List<ExpressionSegment>)
getValueExpressionsMethod.invoke(onDuplicateUpdateContext, new
Object[]{assignments});
@@ -99,6 +95,15 @@ public final class OnDuplicateUpdateContextTest {
AssignmentSegment assignmentSegment =
makeAssignmentSegment(parameterLiteralExpression);
return Collections.singleton(assignmentSegment);
}
+
+ private BinaryOperationExpression makeBinaryOperationExpression() {
+ int doesNotMatterIndex = 0;
+ String doesNotMatterColumnName = "columnNameStr";
+ String doesNotMatterColumnText = "columnNameStr=?";
+ ExpressionSegment left = new ColumnSegment(doesNotMatterIndex,
doesNotMatterIndex, new IdentifierValue(doesNotMatterColumnName));
+ ExpressionSegment right = new
ParameterMarkerExpressionSegment(doesNotMatterIndex, doesNotMatterIndex,
doesNotMatterIndex);
+ return new BinaryOperationExpression(doesNotMatterIndex,
doesNotMatterIndex, left, right, "=", doesNotMatterColumnText);
+ }
private AssignmentSegment makeAssignmentSegment(final
SimpleExpressionSegment expressionSegment) {
int doesNotMatterLexicalIndex = 0;
@@ -109,23 +114,15 @@ public final class OnDuplicateUpdateContextTest {
AssignmentSegment result = new
ColumnAssignmentSegment(doesNotMatterLexicalIndex, doesNotMatterLexicalIndex,
columnSegments, expressionSegment);
return result;
}
-
- @Test
- public void assertGetParameterIndex() throws NoSuchMethodException,
IllegalAccessException {
- Collection<AssignmentSegment> assignments = Collections.emptyList();
- List<Object> parameters = Collections.emptyList();
- int parametersOffset = 0;
- OnDuplicateUpdateContext onDuplicateUpdateContext = new
OnDuplicateUpdateContext(assignments, parameters, parametersOffset);
- Method getParameterIndexMethod =
OnDuplicateUpdateContext.class.getDeclaredMethod("getParameterIndex",
ExpressionSegment.class);
- getParameterIndexMethod.setAccessible(true);
- ParameterMarkerExpressionSegment notExistsExpressionSegment = new
ParameterMarkerExpressionSegment(0, 0, 0);
- Throwable targetException = null;
- try {
- getParameterIndexMethod.invoke(onDuplicateUpdateContext,
notExistsExpressionSegment);
- } catch (final InvocationTargetException ex) {
- targetException = ex.getTargetException();
- }
- assertTrue("expected throw IllegalArgumentException", targetException
instanceof IllegalArgumentException);
+
+ private AssignmentSegment makeAssignmentSegment(final
BinaryOperationExpression binaryOperationExpression) {
+ int doesNotMatterLexicalIndex = 0;
+ String doesNotMatterColumnName = "columnNameStr";
+ ColumnSegment column = new ColumnSegment(doesNotMatterLexicalIndex,
doesNotMatterLexicalIndex, new IdentifierValue(doesNotMatterColumnName));
+ List<ColumnSegment> columnSegments = new LinkedList<>();
+ columnSegments.add(column);
+ AssignmentSegment result = new
ColumnAssignmentSegment(doesNotMatterLexicalIndex, doesNotMatterLexicalIndex,
columnSegments, binaryOperationExpression);
+ return result;
}
@Test
@@ -137,4 +134,19 @@ public final class OnDuplicateUpdateContextTest {
ColumnSegment column = onDuplicateUpdateContext.getColumn(0);
assertThat(column,
is(assignments.iterator().next().getColumns().get(0)));
}
+
+ @Test
+ public void assertParameterCount() {
+ List<AssignmentSegment> assignments = Arrays.asList(
+ makeAssignmentSegment(makeBinaryOperationExpression()),
+ makeAssignmentSegment(new ParameterMarkerExpressionSegment(0,
10, 5)),
+ makeAssignmentSegment(new LiteralExpressionSegment(0, 10, new
Object()))
+ );
+ int doestNotMatterParametersOffset = 0;
+ String doesNotMatterParameterValue = "";
+ List<Object> parameters = Arrays.asList(doesNotMatterParameterValue,
doesNotMatterParameterValue);
+ OnDuplicateUpdateContext onDuplicateUpdateContext = new
OnDuplicateUpdateContext(assignments, parameters,
doestNotMatterParametersOffset);
+ assertThat(onDuplicateUpdateContext.getParameterCount(), is(2));
+ }
+
}
diff --git
a/shardingsphere-infra/shardingsphere-infra-common/src/main/java/org/apache/shardingsphere/infra/statement/InsertContextExpressSegmentUtil.java
b/shardingsphere-infra/shardingsphere-infra-common/src/main/java/org/apache/shardingsphere/infra/statement/InsertContextExpressSegmentUtil.java
new file mode 100644
index 0000000..ce65d1a
--- /dev/null
+++
b/shardingsphere-infra/shardingsphere-infra-common/src/main/java/org/apache/shardingsphere/infra/statement/InsertContextExpressSegmentUtil.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.shardingsphere.infra.statement;
+
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.BinaryOperationExpression;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.ExpressionSegment;
+import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.simple.ParameterMarkerExpressionSegment;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+/**
+ * Insert context expression segment util.
+ */
+public final class InsertContextExpressSegmentUtil {
+ /**
+ * Extract all ParameterMarkerExpressionSegment from ExpressionSegment
list.
+ *
+ * @param expressions ExpressionSegment list
+ * @return ParameterMarkerExpressionSegment list
+ */
+ public static List<ParameterMarkerExpressionSegment>
extractParameterMarkerExpressionSegment(final Collection<ExpressionSegment>
expressions) {
+ List<ParameterMarkerExpressionSegment> result = new ArrayList<>();
+ for (ExpressionSegment each : expressions) {
+ if (each instanceof ParameterMarkerExpressionSegment) {
+ result.add((ParameterMarkerExpressionSegment) each);
+ } else if (each instanceof BinaryOperationExpression) {
+ if (((BinaryOperationExpression) each).getLeft() instanceof
ParameterMarkerExpressionSegment) {
+ result.add((ParameterMarkerExpressionSegment)
((BinaryOperationExpression) each).getLeft());
+ }
+ if (((BinaryOperationExpression) each).getRight() instanceof
ParameterMarkerExpressionSegment) {
+ result.add((ParameterMarkerExpressionSegment)
((BinaryOperationExpression) each).getRight());
+ }
+ }
+ }
+ return result;
+ }
+}
diff --git
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ExpressionExtractUtil.java
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ExpressionExtractUtil.java
index fd1f146..e082a4c 100644
---
a/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ExpressionExtractUtil.java
+++
b/shardingsphere-sql-parser/shardingsphere-sql-parser-statement/src/main/java/org/apache/shardingsphere/sql/parser/sql/common/util/ExpressionExtractUtil.java
@@ -25,8 +25,8 @@ import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.expr.Expressi
import
org.apache.shardingsphere.sql.parser.sql.common.segment.dml.predicate.AndPredicate;
import java.util.Collection;
-import java.util.Collections;
import java.util.LinkedList;
+import java.util.Collections;
import java.util.Optional;
/**