This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 509d865760 [feature](Nereids): convert CaseWhen to If (#23040)
509d865760 is described below

commit 509d865760ffc9bb82e6a9562391f40fce64e3ca
Author: jakevin <[email protected]>
AuthorDate: Wed Aug 30 15:47:29 2023 +0800

    [feature](Nereids): convert CaseWhen to If (#23040)
    
    Add an expression rewrite rule that converts a CASE WHEN expression with a
    single WHEN clause into an IF (a missing ELSE branch becomes NULL).
    
    For example:
    CASE WHEN a > 1 THEN 1 ELSE 0 END -> IF(a > 1, 1, 0)
---
 .../rules/expression/ExpressionNormalization.java  |  6 ---
 .../rules/expression/rules/CaseWhenToIf.java       | 49 ++++++++++++++++++++++
 .../rules/expression/ExpressionRewriteTest.java    | 20 ++++-----
 3 files changed, 59 insertions(+), 16 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
index bd234204b2..6cff3553b4 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
@@ -29,7 +29,6 @@ import 
org.apache.doris.nereids.rules.expression.rules.SimplifyArithmeticRule;
 import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule;
 import org.apache.doris.nereids.rules.expression.rules.SimplifyNotExprRule;
 import 
org.apache.doris.nereids.rules.expression.rules.SupportJavaDateFormatter;
-import org.apache.doris.nereids.trees.expressions.Expression;
 
 import com.google.common.collect.ImmutableList;
 
@@ -60,10 +59,5 @@ public class ExpressionNormalization extends 
ExpressionRewrite {
     public ExpressionNormalization() {
         super(new ExpressionRuleExecutor(NORMALIZE_REWRITE_RULES));
     }
-
-    @Override
-    public Expression rewrite(Expression expression, ExpressionRewriteContext 
context) {
-        return super.rewrite(expression, context);
-    }
 }
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java
new file mode 100644
index 0000000000..6372338406
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/CaseWhenToIf.java
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rules;
+
+import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
+import org.apache.doris.nereids.trees.expressions.CaseWhen;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.WhenClause;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+
+/**
+ * Rewrite rule that converts a CASE WHEN with exactly one WHEN clause into IF.
+ * A missing ELSE branch becomes a NULL literal typed like the THEN result, e.g.
+ * CASE WHEN a > 1 THEN 1 ELSE 0 END -> IF(a > 1, 1, 0); CASE WHEN a > 1 THEN 1 END -> IF(a > 1, 1, NULL)
+ */
+public class CaseWhenToIf extends AbstractExpressionRewriteRule {
+
+    public static final CaseWhenToIf INSTANCE = new CaseWhenToIf();
+
+    @Override
+    public Expression visitCaseWhen(CaseWhen caseWhen, ExpressionRewriteContext context) {
+        Expression expr = caseWhen;
+        if (caseWhen.getWhenClauses().size() == 1) {
+            WhenClause whenClause = caseWhen.getWhenClauses().get(0);
+            Expression operand = whenClause.getOperand();
+            Expression result = whenClause.getResult();
+            expr = new If(operand, result, caseWhen.getDefaultValue().orElse(new NullLiteral(result.getDataType())));
+        }
+        // Any other number of WHEN clauses: returned unchanged. TODO: traverse expr in CASE WHEN / If.
+        return expr;
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
index 5b3cdd7dd1..1010e7df27 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
@@ -59,10 +59,10 @@ import java.math.BigDecimal;
 /**
  * all expr rewrite rule test case.
  */
-public class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
+class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
 
     @Test
-    public void testNotRewrite() {
+    void testNotRewrite() {
         executor = new 
ExpressionRuleExecutor(ImmutableList.of(SimplifyNotExprRule.INSTANCE));
 
         assertRewrite("not x", "not x");
@@ -87,7 +87,7 @@ public class ExpressionRewriteTest extends 
ExpressionRewriteTestHelper {
     }
 
     @Test
-    public void testNormalizeExpressionRewrite() {
+    void testNormalizeExpressionRewrite() {
         executor = new 
ExpressionRuleExecutor(ImmutableList.of(NormalizeBinaryPredicatesRule.INSTANCE));
 
         assertRewrite("1 = 1", "1 = 1");
@@ -99,7 +99,7 @@ public class ExpressionRewriteTest extends 
ExpressionRewriteTestHelper {
     }
 
     @Test
-    public void testDistinctPredicatesRewrite() {
+    void testDistinctPredicatesRewrite() {
         executor = new 
ExpressionRuleExecutor(ImmutableList.of(DistinctPredicatesRule.INSTANCE));
 
         assertRewrite("a = 1", "a = 1");
@@ -111,7 +111,7 @@ public class ExpressionRewriteTest extends 
ExpressionRewriteTestHelper {
     }
 
     @Test
-    public void testExtractCommonFactorRewrite() {
+    void testExtractCommonFactorRewrite() {
         executor = new 
ExpressionRuleExecutor(ImmutableList.of(ExtractCommonFactorRule.INSTANCE));
 
         assertRewrite("a", "a");
@@ -164,7 +164,7 @@ public class ExpressionRewriteTest extends 
ExpressionRewriteTestHelper {
     }
 
     @Test
-    public void testInPredicateToEqualToRule() {
+    void testInPredicateToEqualToRule() {
         executor = new 
ExpressionRuleExecutor(ImmutableList.of(InPredicateToEqualToRule.INSTANCE));
 
         assertRewrite("a in (1)", "a = 1");
@@ -180,14 +180,14 @@ public class ExpressionRewriteTest extends 
ExpressionRewriteTestHelper {
     }
 
     @Test
-    public void testInPredicateDedup() {
+    void testInPredicateDedup() {
         executor = new 
ExpressionRuleExecutor(ImmutableList.of(InPredicateDedup.INSTANCE));
 
         assertRewrite("a in (1, 2, 1, 2)", "a in (1, 2)");
     }
 
     @Test
-    public void testSimplifyCastRule() {
+    void testSimplifyCastRule() {
         executor = new 
ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE));
 
         // deduplicate
@@ -219,7 +219,7 @@ public class ExpressionRewriteTest extends 
ExpressionRewriteTestHelper {
     }
 
     @Test
-    public void testSimplifyComparisonPredicateRule() {
+    void testSimplifyComparisonPredicateRule() {
         executor = new 
ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE, 
SimplifyComparisonPredicate.INSTANCE));
 
         Expression dtv2 = new DateTimeV2Literal(1, 1, 1, 1, 1, 1, 0);
@@ -271,7 +271,7 @@ public class ExpressionRewriteTest extends 
ExpressionRewriteTestHelper {
     }
 
     @Test
-    public void testSimplifyDecimalV3Comparison() {
+    void testSimplifyDecimalV3Comparison() {
         executor = new 
ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE));
 
         // do rewrite


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to