This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 3047d7dd07 [fix](Nereids) fix or to in rule (#23940)
3047d7dd07 is described below
commit 3047d7dd078a38815866032030f9d8262fab927f
Author: 谢健 <[email protected]>
AuthorDate: Wed Sep 6 14:58:20 2023 +0800
[fix](Nereids) fix or to in rule (#23940)
The OR-to-IN rewrite context must not propagate across OR expressions; each OR expression must collect its own slot-to-literal map.
For example:
```
select (a = 1 or a = 2 or a = 3) + (a = 4 or a = 5 or a = 6)
-- should be rewritten to:
select (a in (1, 2, 3)) + (a in (4, 5, 6))
-- but was previously (incorrectly) rewritten to:
select (a in (1, 2, 3)) + (a in (1, 2, 3, 4, 5, 6))
```
---
.../nereids/rules/expression/rules/OrToIn.java | 55 +++++++++-------------
.../doris/nereids/rules/rewrite/OrToInTest.java | 27 +++++++++--
2 files changed, 44 insertions(+), 38 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java
index a54d5f5369..aaa077d199 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java
@@ -19,13 +19,11 @@ package org.apache.doris.nereids.rules.expression.rules;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteRule;
-import org.apache.doris.nereids.rules.expression.rules.OrToIn.OrToInContext;
-import org.apache.doris.nereids.trees.expressions.And;
-import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
@@ -57,7 +55,7 @@ import java.util.Set;
* adding any additional rule-specific fields to the default
ExpressionRewriteContext. However, the entire expression
* rewrite framework always passes an ExpressionRewriteContext of type context
to all rules.
*/
-public class OrToIn extends DefaultExpressionRewriter<OrToInContext> implements
+public class OrToIn extends
DefaultExpressionRewriter<ExpressionRewriteContext> implements
ExpressionRewriteRule<ExpressionRewriteContext> {
public static final OrToIn INSTANCE = new OrToIn();
@@ -66,25 +64,20 @@ public class OrToIn extends
DefaultExpressionRewriter<OrToInContext> implements
@Override
public Expression rewrite(Expression expr, ExpressionRewriteContext ctx) {
- return expr.accept(this, new OrToInContext());
+ return expr.accept(this, null);
}
@Override
- public Expression visitCompoundPredicate(CompoundPredicate
compoundPredicate, OrToInContext context) {
- if (compoundPredicate instanceof And) {
- return
compoundPredicate.withChildren(compoundPredicate.child(0).accept(new OrToIn(),
- new OrToInContext()),
- compoundPredicate.child(1).accept(new OrToIn(),
- new OrToInContext()));
- }
- List<Expression> expressions =
ExpressionUtils.extractDisjunction(compoundPredicate);
+ public Expression visitOr(Or or, ExpressionRewriteContext ctx) {
+ Map<NamedExpression, Set<Literal>> slotNameToLiteral = new HashMap<>();
+ List<Expression> expressions = ExpressionUtils.extractDisjunction(or);
for (Expression expression : expressions) {
if (expression instanceof EqualTo) {
- addSlotToLiteralMap((EqualTo) expression, context);
+ addSlotToLiteralMap((EqualTo) expression, slotNameToLiteral);
}
}
List<Expression> rewrittenOr = new ArrayList<>();
- for (Map.Entry<NamedExpression, Set<Literal>> entry :
context.slotNameToLiteral.entrySet()) {
+ for (Map.Entry<NamedExpression, Set<Literal>> entry :
slotNameToLiteral.entrySet()) {
Set<Literal> literals = entry.getValue();
if (literals.size() >= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
InPredicate inPredicate = new InPredicate(entry.getKey(),
ImmutableList.copyOf(entry.getValue()));
@@ -92,26 +85,26 @@ public class OrToIn extends
DefaultExpressionRewriter<OrToInContext> implements
}
}
for (Expression expression : expressions) {
- if (!ableToConvertToIn(expression, context)) {
- rewrittenOr.add(expression);
+ if (!ableToConvertToIn(expression, slotNameToLiteral)) {
+ rewrittenOr.add(expression.accept(this, null));
}
}
return ExpressionUtils.or(rewrittenOr);
}
- private void addSlotToLiteralMap(EqualTo equal, OrToInContext context) {
+ private void addSlotToLiteralMap(EqualTo equal, Map<NamedExpression,
Set<Literal>> slotNameToLiteral) {
Expression left = equal.left();
Expression right = equal.right();
if (left instanceof NamedExpression && right instanceof Literal) {
- addSlotToLiteral((NamedExpression) left, (Literal) right, context);
+ addSlotToLiteral((NamedExpression) left, (Literal) right,
slotNameToLiteral);
}
if (right instanceof NamedExpression && left instanceof Literal) {
- addSlotToLiteral((NamedExpression) right, (Literal) left, context);
+ addSlotToLiteral((NamedExpression) right, (Literal) left,
slotNameToLiteral);
}
}
- private boolean ableToConvertToIn(Expression expression, OrToInContext
context) {
+ private boolean ableToConvertToIn(Expression expression,
Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
if (!(expression instanceof EqualTo)) {
return false;
}
@@ -126,24 +119,18 @@ public class OrToIn extends
DefaultExpressionRewriter<OrToInContext> implements
namedExpression = (NamedExpression) right;
}
return namedExpression != null
- && findSizeOfLiteralThatEqualToSameSlotInOr(namedExpression,
context)
+ && findSizeOfLiteralThatEqualToSameSlotInOr(namedExpression,
slotNameToLiteral)
>= REWRITE_OR_TO_IN_PREDICATE_THRESHOLD;
}
- public void addSlotToLiteral(NamedExpression namedExpression, Literal
literal, OrToInContext context) {
- Set<Literal> literals =
context.slotNameToLiteral.computeIfAbsent(namedExpression, k -> new
HashSet<>());
+ public void addSlotToLiteral(NamedExpression namedExpression, Literal
literal,
+ Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
+ Set<Literal> literals =
slotNameToLiteral.computeIfAbsent(namedExpression, k -> new HashSet<>());
literals.add(literal);
}
- public int findSizeOfLiteralThatEqualToSameSlotInOr(NamedExpression
namedExpression, OrToInContext context) {
- return context.slotNameToLiteral.getOrDefault(namedExpression,
Collections.emptySet()).size();
- }
-
- /**
- * Context of OrToIn
- */
- public static class OrToInContext {
- public final Map<NamedExpression, Set<Literal>> slotNameToLiteral =
new HashMap<>();
-
+ public int findSizeOfLiteralThatEqualToSameSlotInOr(NamedExpression
namedExpression,
+ Map<NamedExpression, Set<Literal>> slotNameToLiteral) {
+ return slotNameToLiteral.getOrDefault(namedExpression,
Collections.emptySet()).size();
}
}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java
index 651c330c55..f77a66dd88 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/OrToInTest.java
@@ -33,10 +33,10 @@ import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Set;
-public class OrToInTest extends ExpressionRewriteTestHelper {
+class OrToInTest extends ExpressionRewriteTestHelper {
@Test
- public void test1() {
+ void test1() {
String expr = "col1 = 1 or col1 = 2 or col1 = 3 and (col2 = 4)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new
ExpressionRewriteContext(null));
@@ -59,7 +59,7 @@ public class OrToInTest extends ExpressionRewriteTestHelper {
}
@Test
- public void test2() {
+ void test2() {
String expr = "col1 = 1 and col1 = 3 and col2 = 3 or col2 = 4";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new
ExpressionRewriteContext(null));
@@ -68,7 +68,7 @@ public class OrToInTest extends ExpressionRewriteTestHelper {
}
@Test
- public void test3() {
+ void test3() {
String expr = "(col1 = 1 or col1 = 2) and (col2 = 3 or col2 = 4)";
Expression expression = PARSER.parseExpression(expr);
Expression rewritten = new OrToIn().rewrite(expression, new
ExpressionRewriteContext(null));
@@ -90,4 +90,23 @@ public class OrToInTest extends ExpressionRewriteTestHelper {
}
}
+ @Test
+ void test4() {
+ String expr = "case when col = 1 or col = 2 or col = 3 then 1"
+ + " when col = 4 or col = 5 or col = 6 then 1 else 0
end";
+ Expression expression = PARSER.parseExpression(expr);
+ Expression rewritten = new OrToIn().rewrite(expression, new
ExpressionRewriteContext(null));
+ Assertions.assertEquals("CASE WHEN col IN (1, 2, 3) THEN 1 WHEN col IN
(4, 5, 6) THEN 1 ELSE 0 END",
+ rewritten.toSql());
+ }
+
+ @Test
+ void test5() {
+ String expr = "col = 1 or (col = 2 and (col = 3 or col = 4 or col =
5))";
+ Expression expression = PARSER.parseExpression(expr);
+ Expression rewritten = new OrToIn().rewrite(expression, new
ExpressionRewriteContext(null));
+ Assertions.assertEquals("((col = 1) OR ((col = 2) AND col IN (3, 4,
5)))",
+ rewritten.toSql());
+ }
+
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]