This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 509d865760 [feature](Nereids): convert CaseWhen to If (#23040)
509d865760 is described below
commit 509d865760ffc9bb82e6a9562391f40fce64e3ca
Author: jakevin <[email protected]>
AuthorDate: Wed Aug 30 15:47:29 2023 +0800
[feature](Nereids): convert CaseWhen to If (#23040)
Add an expression rewrite rule that converts a CASE WHEN expression with exactly one WHEN clause into IF; a missing ELSE branch becomes NULL.
For example:
CASE WHEN a > 1 THEN 1 ELSE 0 END -> IF(a > 1, 1, 0)
---
.../rules/expression/ExpressionNormalization.java | 6 ---
.../rules/expression/rules/CaseWhenToIf.java | 49 ++++++++++++++++++++++
.../rules/expression/ExpressionRewriteTest.java | 20 ++++-----
3 files changed, 59 insertions(+), 16 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
index bd234204b2..6cff3553b4 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
@@ -29,7 +29,6 @@ import
org.apache.doris.nereids.rules.expression.rules.SimplifyArithmeticRule;
import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule;
import org.apache.doris.nereids.rules.expression.rules.SimplifyNotExprRule;
import
org.apache.doris.nereids.rules.expression.rules.SupportJavaDateFormatter;
-import org.apache.doris.nereids.trees.expressions.Expression;
import com.google.common.collect.ImmutableList;
@@ -60,10 +59,5 @@ public class ExpressionNormalization extends
ExpressionRewrite {
public ExpressionNormalization() {
super(new ExpressionRuleExecutor(NORMALIZE_REWRITE_RULES));
}
-
- @Override
- public Expression rewrite(Expression expression, ExpressionRewriteContext
context) {
- return super.rewrite(expression, context);
- }
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java
new file mode 100644
index 0000000000..6372338406
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rules;
+
+import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
+import org.apache.doris.nereids.trees.expressions.CaseWhen;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.WhenClause;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+
+/**
+ * Rewrite rule that converts a CASE WHEN with exactly one WHEN clause into an IF call.
+ *
+ * <p>For example:
+ * <pre>
+ * CASE WHEN a > 1 THEN 1 ELSE 0 END -> IF(a > 1, 1, 0)
+ * CASE WHEN a > 1 THEN 1 END        -> IF(a > 1, 1, NULL)
+ * </pre>
+ *
+ * <p>A CASE WHEN with any other number of WHEN clauses is returned unchanged.
+ */
+public class CaseWhenToIf extends AbstractExpressionRewriteRule {
+
+    /** Shared rule instance; declared final so callers cannot reassign it. */
+    public static final CaseWhenToIf INSTANCE = new CaseWhenToIf();
+
+    /**
+     * Rewrites {@code caseWhen} as {@code IF(condition, result, default)} when it has exactly
+     * one WHEN clause. A missing ELSE branch becomes a NULL literal typed like the THEN result.
+     *
+     * @param caseWhen the CASE WHEN expression to rewrite
+     * @param context the rewrite context (not used by this rule)
+     * @return the equivalent {@link If}, or {@code caseWhen} itself when it does not have
+     *         exactly one WHEN clause
+     */
+    @Override
+    public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) {
+        // TODO: traverse expr in CASE WHEN / If. This method does not visit the operand,
+        //  result or default children, so CASE WHENs nested inside them are not rewritten here.
+        if (caseWhen.getWhenClauses().size() != 1) {
+            return caseWhen;
+        }
+        WhenClause whenClause = caseWhen.getWhenClauses().get(0);
+        Expression operand = whenClause.getOperand();
+        Expression result = whenClause.getResult();
+        // Build the typed NULL only when there is no ELSE branch.
+        Expression defaultValue = caseWhen.getDefaultValue()
+                .orElseGet(() -> new NullLiteral(result.getDataType()));
+        return new If(operand, result, defaultValue);
+    }
+}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
index 5b3cdd7dd1..1010e7df27 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
@@ -59,10 +59,10 @@ import java.math.BigDecimal;
/**
* all expr rewrite rule test case.
*/
-public class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
+class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
@Test
- public void testNotRewrite() {
+ void testNotRewrite() {
executor = new
ExpressionRuleExecutor(ImmutableList.of(SimplifyNotExprRule.INSTANCE));
assertRewrite("not x", "not x");
@@ -87,7 +87,7 @@ public class ExpressionRewriteTest extends
ExpressionRewriteTestHelper {
}
@Test
- public void testNormalizeExpressionRewrite() {
+ void testNormalizeExpressionRewrite() {
executor = new
ExpressionRuleExecutor(ImmutableList.of(NormalizeBinaryPredicatesRule.INSTANCE));
assertRewrite("1 = 1", "1 = 1");
@@ -99,7 +99,7 @@ public class ExpressionRewriteTest extends
ExpressionRewriteTestHelper {
}
@Test
- public void testDistinctPredicatesRewrite() {
+ void testDistinctPredicatesRewrite() {
executor = new
ExpressionRuleExecutor(ImmutableList.of(DistinctPredicatesRule.INSTANCE));
assertRewrite("a = 1", "a = 1");
@@ -111,7 +111,7 @@ public class ExpressionRewriteTest extends
ExpressionRewriteTestHelper {
}
@Test
- public void testExtractCommonFactorRewrite() {
+ void testExtractCommonFactorRewrite() {
executor = new
ExpressionRuleExecutor(ImmutableList.of(ExtractCommonFactorRule.INSTANCE));
assertRewrite("a", "a");
@@ -164,7 +164,7 @@ public class ExpressionRewriteTest extends
ExpressionRewriteTestHelper {
}
@Test
- public void testInPredicateToEqualToRule() {
+ void testInPredicateToEqualToRule() {
executor = new
ExpressionRuleExecutor(ImmutableList.of(InPredicateToEqualToRule.INSTANCE));
assertRewrite("a in (1)", "a = 1");
@@ -180,14 +180,14 @@ public class ExpressionRewriteTest extends
ExpressionRewriteTestHelper {
}
@Test
- public void testInPredicateDedup() {
+ void testInPredicateDedup() {
executor = new
ExpressionRuleExecutor(ImmutableList.of(InPredicateDedup.INSTANCE));
assertRewrite("a in (1, 2, 1, 2)", "a in (1, 2)");
}
@Test
- public void testSimplifyCastRule() {
+ void testSimplifyCastRule() {
executor = new
ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE));
// deduplicate
@@ -219,7 +219,7 @@ public class ExpressionRewriteTest extends
ExpressionRewriteTestHelper {
}
@Test
- public void testSimplifyComparisonPredicateRule() {
+ void testSimplifyComparisonPredicateRule() {
executor = new
ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE,
SimplifyComparisonPredicate.INSTANCE));
Expression dtv2 = new DateTimeV2Literal(1, 1, 1, 1, 1, 1, 0);
@@ -271,7 +271,7 @@ public class ExpressionRewriteTest extends
ExpressionRewriteTestHelper {
}
@Test
- public void testSimplifyDecimalV3Comparison() {
+ void testSimplifyDecimalV3Comparison() {
executor = new
ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE));
// do rewrite
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]