This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 63b170251e [fix](nereids)cast filter and join conjunct's return type
to boolean (#21434)
63b170251e is described below
commit 63b170251eef08091a836e32505faf3a74314e74
Author: starocean999 <[email protected]>
AuthorDate: Mon Jul 3 17:22:46 2023 +0800
[fix](nereids)cast filter and join conjunct's return type to boolean
(#21434)
---
.../doris/nereids/jobs/executor/Rewriter.java | 2 +
.../org/apache/doris/nereids/rules/RuleType.java | 1 +
.../rules/rewrite/AdjustConjunctsReturnType.java | 70 ++++++++++++++++++++
.../suites/nereids_p0/datatype/test_cast.groovy | 77 ++++++++++++++++++++++
4 files changed, 150 insertions(+)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index ba3d71d896..970c5009f9 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -30,6 +30,7 @@ import
org.apache.doris.nereids.rules.expression.CheckLegalityAfterRewrite;
import org.apache.doris.nereids.rules.expression.ExpressionNormalization;
import org.apache.doris.nereids.rules.expression.ExpressionOptimization;
import org.apache.doris.nereids.rules.expression.ExpressionRewrite;
+import org.apache.doris.nereids.rules.rewrite.AdjustConjunctsReturnType;
import org.apache.doris.nereids.rules.rewrite.AdjustNullable;
import
org.apache.doris.nereids.rules.rewrite.AggScalarSubQueryToWindowFunction;
import org.apache.doris.nereids.rules.rewrite.BuildAggForUnion;
@@ -280,6 +281,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
new PushdownFilterThroughProject(),
new MergeProjects()
),
+ custom(RuleType.ADJUST_CONJUNCTS_RETURN_TYPE,
AdjustConjunctsReturnType::new),
custom(RuleType.ADJUST_NULLABLE, AdjustNullable::new),
bottomUp(
new
ExpressionRewrite(CheckLegalityAfterRewrite.INSTANCE),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index b7e75c3659..d70b45d1f0 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -228,6 +228,7 @@ public enum RuleType {
PUSH_LIMIT_INTO_SORT(RuleTypeClass.REWRITE),
// adjust nullable
ADJUST_NULLABLE(RuleTypeClass.REWRITE),
+ ADJUST_CONJUNCTS_RETURN_TYPE(RuleTypeClass.REWRITE),
// ensure having project on the top join
ENSURE_PROJECT_ON_TOP_JOIN(RuleTypeClass.REWRITE),
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustConjunctsReturnType.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustConjunctsReturnType.java
new file mode 100644
index 0000000000..096799bec8
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustConjunctsReturnType.java
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.jobs.JobContext;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
+import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
+import org.apache.doris.nereids.types.BooleanType;
+import org.apache.doris.nereids.util.TypeCoercionUtils;
+
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Casts the return type of every filter conjunct and every join conjunct (hash and other) to BOOLEAN.
+ *
+ * <p>Rewriting may leave a conjunct with a non-boolean type, e.g. a WHERE predicate such as
+ * {@code CASE WHEN k0 = 101 THEN 1 ELSE 0 END}. Each conjunct is therefore passed through
+ * {@link TypeCoercionUtils#castIfNotSameType} with {@link BooleanType#INSTANCE} as the target type.
+ */
+public class AdjustConjunctsReturnType extends DefaultPlanRewriter<Void> implements CustomRewriter {
+
+    @Override
+    public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+        return plan.accept(this, null);
+    }
+
+    @Override
+    public Plan visit(Plan plan, Void context) {
+        // the cast fails fast (ClassCastException) if a non-logical plan ever reaches this rewriter
+        return (LogicalPlan) super.visit(plan, context);
+    }
+
+    @Override
+    public Plan visitLogicalFilter(LogicalFilter<? extends Plan> filter, Void context) {
+        filter = (LogicalFilter<? extends Plan>) super.visit(filter, context);
+        // LinkedHashSet keeps the conjuncts in their original iteration order; Collectors.toSet()
+        // makes no ordering guarantee and may reorder the predicates of the rewritten filter.
+        Set<Expression> conjuncts = filter.getConjuncts().stream()
+                .map(AdjustConjunctsReturnType::castToBoolean)
+                .collect(Collectors.toCollection(LinkedHashSet::new));
+        return filter.withConjuncts(conjuncts);
+    }
+
+    @Override
+    public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Void context) {
+        join = (LogicalJoin<? extends Plan, ? extends Plan>) super.visit(join, context);
+        List<Expression> hashConjuncts = join.getHashJoinConjuncts().stream()
+                .map(AdjustConjunctsReturnType::castToBoolean)
+                .collect(Collectors.toList());
+        List<Expression> otherConjuncts = join.getOtherJoinConjuncts().stream()
+                .map(AdjustConjunctsReturnType::castToBoolean)
+                .collect(Collectors.toList());
+        return join.withJoinConjuncts(hashConjuncts, otherConjuncts);
+    }
+
+    /** Casts {@code expr} to BOOLEAN via {@link TypeCoercionUtils#castIfNotSameType}. */
+    private static Expression castToBoolean(Expression expr) {
+        return TypeCoercionUtils.castIfNotSameType(expr, BooleanType.INSTANCE);
+    }
+}
diff --git a/regression-test/suites/nereids_p0/datatype/test_cast.groovy
b/regression-test/suites/nereids_p0/datatype/test_cast.groovy
new file mode 100644
index 0000000000..f9b6ee6e35
--- /dev/null
+++ b/regression-test/suites/nereids_p0/datatype/test_cast.groovy
@@ -0,0 +1,56 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Regression test for Nereids: WHERE predicates whose CASE WHEN branches return numbers or strings
+// must be treated as BOOLEAN filters; each case checks which rows survive the filter.
+suite("test_cast") {
+    sql 'set enable_nereids_planner=true'
+    sql 'set enable_fallback_to_original_planner=false'
+
+    def tbl = "test_cast"
+
+    sql """ DROP TABLE IF EXISTS ${tbl}"""
+    sql """
+ CREATE TABLE IF NOT EXISTS ${tbl} (
+ `k0` int
+ )
+ DISTRIBUTED BY HASH(`k0`) BUCKETS 5 properties("replication_num" = "1")
+ """
+    sql """ INSERT INTO ${tbl} VALUES (101);"""
+
+    // Each entry is [WHERE predicate, expected rows]; the table holds the single row k0 = 101.
+    def cases = [
+        ["case when k0 = 101 then 1 else 0 end", [[101]]],
+        ["case when k0 = 101 then 12 else 0 end", [[101]]],
+        ["case when k0 = 101 then -12 else 0 end", [[101]]],
+        ["case when k0 = 101 then 0 else 1 end", []],
+        ["case when k0 != 101 then 0 else 1 end", [[101]]],
+        ["case when k0 = 101 then '1' else 0 end", [[101]]],
+        ["case when k0 = 101 then '12' else 0 end", []],
+        ["case when k0 = 101 then 'false' else 0 end", []],
+        ["case when k0 = 101 then 'true' else 1 end", [[101]]]
+    ]
+
+    for (def c : cases) {
+        def predicate = c[0]
+        def expected = c[1]
+        test {
+            sql "select * from ${tbl} where ${predicate}"
+            result(expected)
+        }
+    }
+}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]