This is an automated email from the ASF dual-hosted git repository.

asolimando pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/calcite.git


The following commit(s) were added to refs/heads/main by this push:
     new f01fb2a4ea [CALCITE-7203] IntersectToSemiJoinRule should compute once the join keys and reuse them to avoid duplicates
f01fb2a4ea is described below

commit f01fb2a4ea7d63eefd2b6613a5d88a9bd1c12f69
Author: Alessandro Solimando <[email protected]>
AuthorDate: Sat Sep 27 12:00:38 2025 +0200

    [CALCITE-7203] IntersectToSemiJoinRule should compute once the join keys and reuse them to avoid duplicates
---
 .../calcite/rel/rules/IntersectToSemiJoinRule.java | 59 ++++++++++++++--------
 .../org/apache/calcite/test/RelOptRulesTest.xml    | 18 ++++---
 core/src/test/resources/sql/planner.iq             | 18 +++----
 3 files changed, 56 insertions(+), 39 deletions(-)

diff --git a/core/src/main/java/org/apache/calcite/rel/rules/IntersectToSemiJoinRule.java b/core/src/main/java/org/apache/calcite/rel/rules/IntersectToSemiJoinRule.java
index 66afcaa88d..06fbccd8a1 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/IntersectToSemiJoinRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/IntersectToSemiJoinRule.java
@@ -66,20 +66,22 @@
  *
  * <p>Plan after conversion:
  * <pre>{@code
- * LogicalProject(ENAME=[CAST($0):VARCHAR])
- *   LogicalAggregate(group=[{0}])
- *     LogicalJoin(condition=[<=>(CAST($0):VARCHAR, CAST($1):VARCHAR)], joinType=[semi])
- *       LogicalJoin(condition=[=(CAST($0):VARCHAR, $1)], joinType=[semi])
+ * LogicalAggregate(group=[{0}])
+ *   LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi])
+ *     LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi])
+ *       LogicalProject(ENAME=[CAST($0):VARCHAR])
  *         LogicalProject(ENAME=[$1])
  *           LogicalFilter(condition=[=($7, 10)])
  *             LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ *       LogicalProject(ENAME=[CAST($0):VARCHAR])
  *         LogicalProject(DEPTNO=[CAST($7):VARCHAR NOT NULL])
  *           LogicalFilter(condition=[OR(=($1, 'a'), =($1, 'b'))])
  *             LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+ *     LogicalProject(ENAME=[CAST($0):VARCHAR])
  *       LogicalProject(ENAME=[$1])
  *         LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
  * }</pre>
- */
+*/
 @Value.Enclosing
 public class IntersectToSemiJoinRule
     extends RelRule<IntersectToSemiJoinRule.Config>
@@ -108,37 +110,50 @@ protected IntersectToSemiJoinRule(Config config) {
 
     final RelDataType leastRowType = intersect.getRowType();
     RelNode current = inputs.get(0);
-    builder.push(current);
 
     for (int i = 1; i < inputs.size(); i++) {
       RelNode next = inputs.get(i);
-      List<RexNode> conditions = new ArrayList<>();
-      int fieldCount = current.getRowType().getFieldCount();
 
-      for (int j = 0; j < fieldCount; j++) {
-        RelDataType leftFieldType = current.getRowType().getFieldList().get(j).getType();
-        RelDataType rightFieldType = next.getRowType().getFieldList().get(j).getType();
-        RelDataType leastFieldType = leastRowType.getFieldList().get(j).getType();
+      // cast columns of the join inputs to the least types (global)
+      final RelNode leftCasted = projectJoinInput(builder, leastRowType, current);
+      final RelNode rightCasted = projectJoinInput(builder, leastRowType, next);
+      builder.push(leftCasted).push(rightCasted);
 
-        conditions.add(
+      // compute the join condition over plain fields from the projections of left/right inputs
+      final int fieldCount = leastRowType.getFieldCount();
+      final List<RexNode> joinPredicates = new ArrayList<>(fieldCount);
+      for (int j = 0; j < fieldCount; j++) {
+        joinPredicates.add(
             builder.isNotDistinctFrom(
-                rexBuilder.makeCast(leastFieldType,
-                    rexBuilder.makeInputRef(leftFieldType, j)),
-                rexBuilder.makeCast(leastFieldType,
-                    rexBuilder.makeInputRef(rightFieldType, j + fieldCount))));
+            builder.field(2, 0, j),
+            builder.field(2, 1, j)));
       }
-      RexNode condition = RexUtil.composeConjunction(rexBuilder, conditions);
 
-      builder.push(next)
-          .join(JoinRelType.SEMI, condition);
+      final RexNode condition = RexUtil.composeConjunction(rexBuilder, joinPredicates);
+      builder.join(JoinRelType.SEMI, condition);
       current = builder.peek();
     }
 
-    builder.distinct()
-        .convert(leastRowType, true);
+    builder.distinct().convert(leastRowType, true);
     call.transformTo(builder.build());
   }
 
+  private RelNode projectJoinInput(
+      RelBuilder builder, RelDataType leastRowType, RelNode joinInput) {
+    builder.push(joinInput);
+
+    final int fieldCount = joinInput.getRowType().getFieldCount();
+    final List<String> names = leastRowType.getFieldNames();
+    final List<RexNode> joinKeys = new ArrayList<>(fieldCount);
+    final RexBuilder rexBuilder = builder.getRexBuilder();
+    for (int j = 0; j < fieldCount; j++) {
+      final RelDataType leastType = leastRowType.getFieldList().get(j).getType();
+      joinKeys.add(rexBuilder.makeCast(leastType, builder.field(j)));
+    }
+
+    return builder.project(joinKeys, names).build();
+  }
+
   /** Rule configuration. */
   @Value.Immutable
   public interface Config extends RelRule.Config {
diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
index 0a4acbaa3e..bd9de19292 100644
--- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
+++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml
@@ -6941,7 +6941,7 @@ LogicalIntersect(all=[false])
     <Resource name="planAfter">
       <![CDATA[
 LogicalAggregate(group=[{0}])
-  LogicalJoin(condition=[=($0, $1)], joinType=[semi])
+  LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi])
     LogicalProject(ENAME=[$1])
       LogicalFilter(condition=[=($7, 10)])
         LogicalTableScan(table=[[CATALOG, SALES, EMP]])
@@ -6975,16 +6975,18 @@ LogicalIntersect(all=[false])
     </Resource>
     <Resource name="planAfter">
       <![CDATA[
-LogicalProject(ENAME=[CAST($0):VARCHAR])
-  LogicalAggregate(group=[{0}])
-    LogicalJoin(condition=[IS NOT DISTINCT FROM(CAST($0):VARCHAR, CAST($1):VARCHAR)], joinType=[semi])
-      LogicalJoin(condition=[=(CAST($0):VARCHAR, $1)], joinType=[semi])
+LogicalAggregate(group=[{0}])
+  LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi])
+    LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi])
+      LogicalProject(ENAME=[CAST($0):VARCHAR])
         LogicalProject(ENAME=[$1])
           LogicalFilter(condition=[=($7, 10)])
             LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+      LogicalProject(ENAME=[CAST($0):VARCHAR])
         LogicalProject(DEPTNO=[CAST($7):VARCHAR NOT NULL])
           LogicalFilter(condition=[OR(=($1, 'a'), =($1, 'b'))])
             LogicalTableScan(table=[[CATALOG, SALES, EMP]])
+    LogicalProject(ENAME=[CAST($0):VARCHAR])
       LogicalProject(ENAME=[$1])
         LogicalTableScan(table=[[CATALOG, SALES, EMPNULLABLES]])
 ]]>
@@ -7018,7 +7020,7 @@ LogicalIntersect(all=[true])
       <![CDATA[
 LogicalIntersect(all=[true])
   LogicalAggregate(group=[{0}])
-    LogicalJoin(condition=[=($0, $1)], joinType=[semi])
+    LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi])
       LogicalProject(ENAME=[$1])
         LogicalFilter(condition=[=($7, 10)])
           LogicalTableScan(table=[[CATALOG, SALES, EMP]])
@@ -7052,7 +7054,7 @@ LogicalIntersect(all=[false])
     <Resource name="planAfter">
       <![CDATA[
 LogicalAggregate(group=[{0}])
-  LogicalJoin(condition=[=($0, $1)], joinType=[semi])
+  LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi])
     LogicalProject(ENAME=[$1])
       LogicalFilter(condition=[=($7, 10)])
         LogicalTableScan(table=[[CATALOG, SALES, EMP]])
@@ -7083,7 +7085,7 @@ LogicalIntersect(all=[false])
     <Resource name="planAfter">
       <![CDATA[
 LogicalAggregate(group=[{0, 1}])
-  LogicalJoin(condition=[AND(=($0, $2), =($1, $3))], joinType=[semi])
+  LogicalJoin(condition=[AND(IS NOT DISTINCT FROM($0, $2), IS NOT DISTINCT FROM($1, $3))], joinType=[semi])
     LogicalProject(DEPTNO=[$7], ENAME=[$1])
       LogicalFilter(condition=[=($7, 10)])
         LogicalTableScan(table=[[CATALOG, SALES, EMP]])
diff --git a/core/src/test/resources/sql/planner.iq b/core/src/test/resources/sql/planner.iq
index e9e48f6c7a..787ec7fb1a 100644
--- a/core/src/test/resources/sql/planner.iq
+++ b/core/src/test/resources/sql/planner.iq
@@ -54,7 +54,7 @@ select * from t as t2 where t2.i > 0;
 
 !ok
 
-EnumerableHashJoin(condition=[=($0, $1)], joinType=[semi])
+EnumerableNestedLoopJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi])
   EnumerableValues(tuples=[[{ 0 }, { 1 }]])
   EnumerableCalc(expr#0=[{inputs}], expr#1=[0], expr#2=[>($t0, $t1)], EXPR$0=[$t0], $condition=[$t2])
     EnumerableValues(tuples=[[{ 0 }, { 1 }]])
@@ -74,7 +74,7 @@ select * from t as t2 where t2.i > 0;
 
 !ok
 
-EnumerableHashJoin(condition=[=($0, $1)], joinType=[semi])
+EnumerableNestedLoopJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi])
   EnumerableValues(tuples=[[{ 0 }, { 1 }]])
   EnumerableCalc(expr#0=[{inputs}], expr#1=[0], expr#2=[>($t0, $t1)], EXPR$0=[$t0], $condition=[$t2])
     EnumerableValues(tuples=[[{ 0 }, { 1 }]])
@@ -166,16 +166,16 @@ select a from (values (1.0), (4.0), (null)) as t3 (a);
 
 !ok
 
-EnumerableCalc(expr#0..1=[{inputs}], expr#2=[CAST($t0):DECIMAL(11, 1)], A=[$t2])
-  EnumerableHashJoin(condition=[=($1, $3)], joinType=[semi])
-    EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1)], proj#0..1=[{exprs}])
+EnumerableAggregate(group=[{0}])
+  EnumerableNestedLoopJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi])
+    EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1)], A=[$t1])
       EnumerableAggregate(group=[{0}])
-        EnumerableHashJoin(condition=[=($1, $3)], joinType=[semi])
-          EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1) NOT NULL], A=[$t1], A0=[$t1])
+        EnumerableNestedLoopJoin(condition=[IS NOT DISTINCT FROM($0, $1)], joinType=[semi])
+          EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1) NOT NULL], A=[$t1])
             EnumerableValues(tuples=[[{ 1.0 }, { 2.0 }, { 3.0 }, { 4.0 }, { 5.0 }]])
-          EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1) NOT NULL], A=[$t1], A0=[$t1])
+          EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1) NOT NULL], A=[$t1])
             EnumerableValues(tuples=[[{ 1 }, { 2 }]])
-    EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1)], A=[$t1], A0=[$t1])
+    EnumerableCalc(expr#0=[{inputs}], expr#1=[CAST($t0):DECIMAL(11, 1)], A=[$t1])
       EnumerableValues(tuples=[[{ 1.0 }, { 4.0 }, { null }]])
 !plan
 !set planner-rules original

Reply via email to