This is an automated email from the ASF dual-hosted git repository.

morrySnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 28ef1b29daf [fix](fe) Prevent cast project pushdown through union distinct (#64080)
28ef1b29daf is described below

commit 28ef1b29dafd2b5576910cbfc28a6558c44e7434
Author: morrySnow <[email protected]>
AuthorDate: Tue Jun 16 15:22:44 2026 +0800

    [fix](fe) Prevent cast project pushdown through union distinct (#64080)
    
    ### What problem does this PR solve?
    
    Related PR: #21168
    
    Problem Summary: PushProjectThroughUnion previously pushed cast
    projections below UNION DISTINCT without checking whether the cast
    preserved distinctness. For non-injective casts, distinct was then
    evaluated after the cast instead of before it, which could
    incorrectly collapse rows that were distinct in the original union
    output. This PR adds DataType.isInjectiveCastTo to model casts that
    are safe for this rewrite. It uses it to allow cast pushdown through
    UNION DISTINCT only when every cast is distinctness-preserving and
    the projection references each union output slot exactly once
    (otherwise distinct would be computed over a different set of
    columns). UNION ALL pushdown is unchanged. The implementation covers
    primitive and complex types, treats casts between character
    (string-like) types as injective regardless of declared length, and
    fixes edge cases found during review for integral, map, and struct
    types.
---
 .../rules/rewrite/PushProjectThroughUnion.java     | 33 +++++++++--
 .../org/apache/doris/nereids/types/ArrayType.java  |  9 +++
 .../apache/doris/nereids/types/BooleanType.java    |  8 +++
 .../org/apache/doris/nereids/types/DataType.java   |  4 ++
 .../apache/doris/nereids/types/DateTimeType.java   |  9 +++
 .../apache/doris/nereids/types/DateTimeV2Type.java | 13 +++++
 .../org/apache/doris/nereids/types/DateType.java   |  6 ++
 .../apache/doris/nereids/types/DecimalV2Type.java  | 14 +++++
 .../apache/doris/nereids/types/DecimalV3Type.java  | 14 +++++
 .../org/apache/doris/nereids/types/MapType.java    | 10 ++++
 .../org/apache/doris/nereids/types/StructType.java | 18 ++++++
 .../doris/nereids/types/TimeStampTzType.java       | 10 ++++
 .../org/apache/doris/nereids/types/TimeV2Type.java | 10 ++++
 .../apache/doris/nereids/types/VariantType.java    |  5 ++
 .../nereids/types/coercion/CharacterType.java      |  5 ++
 .../doris/nereids/types/coercion/IntegralType.java | 15 +++++
 .../apache/doris/nereids/util/ExpressionUtils.java | 14 +++++
 .../rules/rewrite/PushProjectThroughUnionTest.java | 57 +++++++++++++++++++
 .../apache/doris/nereids/types/DataTypeTest.java   | 66 ++++++++++++++++++++++
 .../data/nereids_syntax_p0/set_operation.out       |  9 ++-
 .../suites/nereids_syntax_p0/set_operation.groovy  | 21 +++++++
 21 files changed, 343 insertions(+), 7 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java
index dfaef55b56c..87b239556ec 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java
@@ -20,12 +20,14 @@ package org.apache.doris.nereids.rules.rewrite;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
 import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
 import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
@@ -35,9 +37,11 @@ import org.apache.doris.nereids.util.ExpressionUtils;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableList.Builder;
 import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
 
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 /**
  * this rule push down the project through union to let MergeUnion could do 
better
@@ -56,14 +60,31 @@ public class PushProjectThroughUnion extends 
OneRewriteRuleFactory {
 
     /** canPushProject */
     public static boolean canPushProject(List<NamedExpression> projects, 
LogicalSetOperation logicalSetOperation) {
-        return projects.size() == logicalSetOperation.getOutput().size() && 
projects.stream().allMatch(e -> {
-            if (e instanceof SlotReference) {
-                return true;
+        if (projects.size() != logicalSetOperation.getOutput().size()) {
+            return false;
+        }
+        boolean isAll = 
logicalSetOperation.getQualifier().equals(Qualifier.ALL);
+        Set<ExprId> projectInputExprIds = 
Sets.newHashSetWithExpectedSize(projects.size());
+        for (NamedExpression project : projects) {
+            Expression input;
+            if (project instanceof SlotReference) {
+                input = project;
+            } else if (isAll) {
+                input = 
ExpressionUtils.getExpressionCoveredByCast(project.child(0));
             } else {
-                Expression expr = 
ExpressionUtils.getExpressionCoveredByCast(e.child(0));
-                return expr instanceof SlotReference;
+                input = 
ExpressionUtils.getExpressionCoveredBySafetyCast(project.child(0));
+            }
+            if (!(input instanceof SlotReference)) {
+                return false;
             }
-        });
+            projectInputExprIds.add(((SlotReference) input).getExprId());
+        }
+        if (isAll) {
+            return true;
+        }
+        Set<ExprId> outputExprIds = 
Sets.newHashSetWithExpectedSize(logicalSetOperation.getOutput().size());
+        logicalSetOperation.getOutput().forEach(output -> 
outputExprIds.add(output.getExprId()));
+        return projectInputExprIds.equals(outputExprIds);
     }
 
     /** doPushProject */
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java
index 084756303b6..015690e6434 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java
@@ -18,6 +18,7 @@
 package org.apache.doris.nereids.types;
 
 import org.apache.doris.catalog.Type;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.ComplexDataType;
 
 import java.util.Objects;
@@ -57,6 +58,14 @@ public class ArrayType extends DataType implements 
ComplexDataType, NestedColumn
         return itemType;
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof ArrayType) {
+            return itemType.isInjectiveCastTo(((ArrayType) target).itemType);
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         // Catalog ArrayType defaults containsNull to true via single-arg 
constructor
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java
index 49b2a6e72d7..708801a883f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java
@@ -31,6 +31,14 @@ public class BooleanType extends PrimitiveType {
     private BooleanType() {
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        return target instanceof BooleanType || target.isIntegralType() || 
target.isFloatLikeType()
+                || (target instanceof DecimalV2Type && ((DecimalV2Type) 
target).getRange() >= 1)
+                || (target instanceof DecimalV3Type && ((DecimalV3Type) 
target).getRange() >= 1)
+                || target.isStringLikeType();
+    }
+
     @Override
     public Type toCatalogDataType() {
         return Type.BOOLEAN;
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java
index 8ded09c4b63..14295979dae 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java
@@ -813,6 +813,10 @@ public abstract class DataType {
 
     public abstract int width();
 
+    public boolean isInjectiveCastTo(DataType target) {
+        return this.equals(target);
+    }
+
     public static List<DataType> trivialTypes() {
         return Type.getTrivialTypes()
                 .stream()
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java
index 8a0250d7b44..a93bfda0364 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.types;
 
 import org.apache.doris.catalog.Type;
 import org.apache.doris.common.Config;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.DateLikeType;
 
 import java.time.DateTimeException;
@@ -45,6 +46,14 @@ public class DateTimeType extends DateLikeType {
         this.shouldConversion = shouldConversion;
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof DateTimeType || target instanceof DateTimeV2Type 
|| target instanceof CharacterType) {
+            return true;
+        }
+        return false;
+    }
+
     @Override
     public DataType conversion() {
         if (Config.enable_date_conversion && shouldConversion) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java
index f56b4662f8b..13097339554 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java
@@ -23,6 +23,7 @@ import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
 import 
org.apache.doris.nereids.trees.expressions.literal.format.DateTimeChecker;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.DateLikeType;
 import org.apache.doris.nereids.types.coercion.IntegralType;
 import org.apache.doris.nereids.types.coercion.ScaleTimeType;
@@ -128,6 +129,18 @@ public class DateTimeV2Type extends DateLikeType 
implements ScaleTimeType {
         return super.toSql() + "(" + scale + ")";
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof DateTimeV2Type) {
+            DateTimeV2Type t2 = (DateTimeV2Type) target;
+            return this.scale <= t2.scale;
+        }
+        if (target instanceof DateTimeType) {
+            return this.scale == 0;
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         return ScalarType.createDatetimeV2Type(scale);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java
index d127ab16069..c6ce702ebe7 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.types;
 
 import org.apache.doris.catalog.Type;
 import org.apache.doris.common.Config;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.DateLikeType;
 
 import java.time.DateTimeException;
@@ -45,6 +46,11 @@ public class DateType extends DateLikeType {
         this.shouldConversion = shouldConversion;
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        return target instanceof DateType || target instanceof DateV2Type || 
target instanceof CharacterType;
+    }
+
     @Override
     public DataType conversion() {
         if (Config.enable_date_conversion && shouldConversion) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java
index b601aaa9f13..b055172f262 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java
@@ -21,6 +21,7 @@ import org.apache.doris.catalog.PrimitiveType;
 import org.apache.doris.catalog.ScalarType;
 import org.apache.doris.catalog.Type;
 import org.apache.doris.common.Config;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.FractionalType;
 
 import com.google.common.base.Preconditions;
@@ -159,6 +160,19 @@ public class DecimalV2Type extends FractionalType {
         return DecimalV2Type.createDecimalV2Type(range + scale, scale);
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof DecimalV2Type) {
+            DecimalV2Type decimalV2Type = (DecimalV2Type) target;
+            return decimalV2Type.getRange() >= this.getRange() && 
decimalV2Type.getScale() >= this.getScale();
+        }
+        if (target instanceof DecimalV3Type) {
+            DecimalV3Type decimalV3Type = (DecimalV3Type) target;
+            return decimalV3Type.getRange() >= this.getRange() && 
decimalV3Type.getScale() >= this.getScale();
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         return ScalarType.createDecimalType(PrimitiveType.DECIMALV2, 
precision, scale);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java
index 3c0a83e95c4..b366568cb35 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java
@@ -22,6 +22,7 @@ import org.apache.doris.catalog.Type;
 import org.apache.doris.nereids.annotation.Developing;
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.exceptions.NotSupportedException;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.FractionalType;
 import org.apache.doris.qe.ConnectContext;
 import org.apache.doris.qe.SessionVariable;
@@ -214,6 +215,19 @@ public class DecimalV3Type extends FractionalType {
         }
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof DecimalV2Type) {
+            DecimalV2Type decimalV2Type = (DecimalV2Type) target;
+            return decimalV2Type.getRange() >= this.getRange() && 
decimalV2Type.getScale() >= this.getScale();
+        }
+        if (target instanceof DecimalV3Type) {
+            DecimalV3Type decimalV3Type = (DecimalV3Type) target;
+            return decimalV3Type.getRange() >= this.getRange() && 
decimalV3Type.getScale() >= this.getScale();
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         return ScalarType.createDecimalV3Type(precision, scale);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java
index 176c1db1d0d..fc6e9ba2f94 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.types;
 
 import org.apache.doris.catalog.Type;
 import org.apache.doris.nereids.annotation.Developing;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.ComplexDataType;
 
 import java.util.Objects;
@@ -63,6 +64,15 @@ public class MapType extends DataType implements 
ComplexDataType, NestedColumnPr
         return MapType.of(keyType.conversion(), valueType.conversion());
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof MapType) {
+            MapType mapType = (MapType) target;
+            return keyType.isInjectiveCastTo(mapType.keyType) && 
valueType.isInjectiveCastTo(mapType.valueType);
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         return new 
org.apache.doris.catalog.MapType(keyType.toCatalogDataType(), 
valueType.toCatalogDataType());
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java
index 0c33a6d2dec..13f28c2e06e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.types;
 import org.apache.doris.catalog.Type;
 import org.apache.doris.nereids.annotation.Developing;
 import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.ComplexDataType;
 
 import com.google.common.collect.ImmutableList;
@@ -84,6 +85,23 @@ public class StructType extends DataType implements 
ComplexDataType, NestedColum
         return new 
StructType(fields.stream().map(StructField::conversion).collect(Collectors.toList()));
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof StructType) {
+            StructType structType = (StructType) target;
+            if (this.fields.size() != structType.fields.size()) {
+                return false;
+            }
+            for (int i = 0; i < fields.size(); i++) {
+                if 
(!this.fields.get(i).getDataType().isInjectiveCastTo(structType.fields.get(i).getDataType()))
 {
+                    return false;
+                }
+            }
+            return true;
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         return new org.apache.doris.catalog.StructType(fields.stream()
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java
index c3c99cad6fc..4f9c09b5f7f 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java
@@ -21,6 +21,7 @@ import org.apache.doris.catalog.ScalarType;
 import org.apache.doris.catalog.Type;
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.DateLikeType;
 import org.apache.doris.nereids.types.coercion.ScaleTimeType;
 
@@ -46,6 +47,15 @@ public class TimeStampTzType extends DateLikeType implements 
ScaleTimeType {
         this.scale = scale;
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof TimeStampTzType) {
+            TimeStampTzType timeStampTzType = (TimeStampTzType) target;
+            return timeStampTzType.getScale() >= this.scale;
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         return ScalarType.createTimeStampTzType(scale);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java
index 39f420e6931..af758c02cb5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java
@@ -22,6 +22,7 @@ import org.apache.doris.catalog.Type;
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.TimeV2Literal;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.IntegralType;
 import org.apache.doris.nereids.types.coercion.PrimitiveType;
 import org.apache.doris.nereids.types.coercion.RangeScalable;
@@ -48,6 +49,15 @@ public class TimeV2Type extends PrimitiveType implements 
RangeScalable, ScaleTim
         scale = 0;
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof TimeV2Type) {
+            TimeV2Type timeV2Type = (TimeV2Type) target;
+            return timeV2Type.scale >= scale;
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         return ScalarType.createTimeV2Type(scale);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java
index 0a745d9cb7f..1ba69fb4f50 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java
@@ -116,6 +116,11 @@ public class VariantType extends PrimitiveType {
         this.enableNestedGroup = enableNestedGroup;
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        return target.equals(this) || target instanceof VariantType;
+    }
+
     @Override
     public DataType conversion() {
         return new 
VariantType(predefinedFields.stream().map(VariantField::conversion)
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java
index 781b1257028..3d8590534f5 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java
@@ -42,6 +42,11 @@ public abstract class CharacterType extends PrimitiveType {
         return len;
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         throw new RuntimeException("CharacterType is only used for implicit 
cast.");
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java
index b1e58805388..fe625fa34bc 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java
@@ -19,6 +19,8 @@ package org.apache.doris.nereids.types.coercion;
 
 import org.apache.doris.nereids.types.BigIntType;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DecimalV3Type;
+import org.apache.doris.nereids.types.LargeIntType;
 
 import org.apache.commons.lang3.NotImplementedException;
 
@@ -44,6 +46,19 @@ public class IntegralType extends NumericType {
         return "integral";
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof IntegralType) {
+            return this.equals(target) || ((IntegralType) 
target).widerThan(this);
+        }
+        if (target instanceof DecimalV3Type && !(this instanceof 
LargeIntType)) {
+            DecimalV3Type other = (DecimalV3Type) target;
+            DecimalV3Type self = DecimalV3Type.forType(this);
+            return other.getRange() >= self.getRange();
+        }
+        return target instanceof CharacterType;
+    }
+
     public boolean widerThan(IntegralType other) {
         return this.width() > other.width();
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index 53ccdb2d403..7fa7acfa869 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -1135,6 +1135,20 @@ public class ExpressionUtils {
         return expression;
     }
 
+    /**
+     * Strip only casts that preserve distinctness of the child expression.
+     */
+    public static Expression getExpressionCoveredBySafetyCast(Expression 
expression) {
+        while (expression instanceof Cast) {
+            if (((Cast) 
expression).child().getDataType().isInjectiveCastTo(expression.getDataType())) {
+                expression = ((Cast) expression).child();
+            } else {
+                break;
+            }
+        }
+        return expression;
+    }
+
     /**
      * the expressions can be used as runtime filter targets
      */
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java
index 328c390d52f..a877348491c 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java
@@ -28,6 +28,8 @@ import 
org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
 import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.DateTimeType;
+import org.apache.doris.nereids.types.DateType;
 import org.apache.doris.nereids.types.IntegerType;
 import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.nereids.util.PlanChecker;
@@ -107,6 +109,61 @@ public class PushProjectThroughUnionTest {
         }
     }
 
+    @Test
+    public void testCastProjectPushThroughUnionByQualifierAndSafety() {
+        SlotReference unionOutput = new SlotReference(new ExprId(10), "s",
+                IntegerType.INSTANCE, true, ImmutableList.of());
+        Alias castProject = new Alias(new ExprId(100),
+                new Cast(unionOutput, BigIntType.INSTANCE), "n");
+        ImmutableList<NamedExpression> projects = 
ImmutableList.of(castProject);
+
+        LogicalUnion unionAll = new LogicalUnion(Qualifier.ALL,
+                ImmutableList.of(unionOutput), ImmutableList.of(), 
ImmutableList.of(), false, ImmutableList.of());
+        Assertions.assertTrue(PushProjectThroughUnion.canPushProject(projects, 
unionAll));
+
+        LogicalUnion unionDistinct = new LogicalUnion(Qualifier.DISTINCT,
+                ImmutableList.of(unionOutput), ImmutableList.of(), 
ImmutableList.of(), false, ImmutableList.of());
+        Assertions.assertTrue(PushProjectThroughUnion.canPushProject(projects, 
unionDistinct));
+
+        SlotReference dateTimeOutput = new SlotReference(new ExprId(11), "dt",
+                DateTimeType.INSTANCE, true, ImmutableList.of());
+        Alias unsafeCastProject = new Alias(new ExprId(101),
+                new Cast(dateTimeOutput, DateType.INSTANCE), "d");
+        ImmutableList<NamedExpression> unsafeProjects = 
ImmutableList.of(unsafeCastProject);
+
+        LogicalUnion unionAllWithUnsafeCast = new LogicalUnion(Qualifier.ALL,
+                ImmutableList.of(dateTimeOutput), ImmutableList.of(), 
ImmutableList.of(), false, ImmutableList.of());
+        
Assertions.assertTrue(PushProjectThroughUnion.canPushProject(unsafeProjects, 
unionAllWithUnsafeCast));
+
+        LogicalUnion unionDistinctWithUnsafeCast = new 
LogicalUnion(Qualifier.DISTINCT,
+                ImmutableList.of(dateTimeOutput), ImmutableList.of(), 
ImmutableList.of(), false, ImmutableList.of());
+        
Assertions.assertFalse(PushProjectThroughUnion.canPushProject(unsafeProjects, 
unionDistinctWithUnsafeCast));
+    }
+
+    @Test
+    public void testDistinctProjectRequiresAllOutputSlotsExactlyOnce() {
+        SlotReference firstOutput = new SlotReference(new ExprId(10), "a",
+                IntegerType.INSTANCE, true, ImmutableList.of());
+        SlotReference secondOutput = new SlotReference(new ExprId(11), "b",
+                IntegerType.INSTANCE, true, ImmutableList.of());
+        ImmutableList<NamedExpression> outputs = ImmutableList.of(firstOutput, 
secondOutput);
+        LogicalUnion unionAll = new LogicalUnion(Qualifier.ALL,
+                outputs, ImmutableList.of(), ImmutableList.of(), false, 
ImmutableList.of());
+        LogicalUnion unionDistinct = new LogicalUnion(Qualifier.DISTINCT,
+                outputs, ImmutableList.of(), ImmutableList.of(), false, 
ImmutableList.of());
+
+        ImmutableList<NamedExpression> duplicateProjects = ImmutableList.of(
+                new Alias(new ExprId(100), new Cast(firstOutput, 
BigIntType.INSTANCE), "a1"),
+                new Alias(new ExprId(101), new Cast(firstOutput, 
BigIntType.INSTANCE), "a2"));
+        
Assertions.assertTrue(PushProjectThroughUnion.canPushProject(duplicateProjects, 
unionAll));
+        
Assertions.assertFalse(PushProjectThroughUnion.canPushProject(duplicateProjects,
 unionDistinct));
+
+        ImmutableList<NamedExpression> permutationProjects = ImmutableList.of(
+                new Alias(new ExprId(102), new Cast(secondOutput, 
BigIntType.INSTANCE), "b"),
+                new Alias(new ExprId(103), new Cast(firstOutput, 
BigIntType.INSTANCE), "a"));
+        
Assertions.assertTrue(PushProjectThroughUnion.canPushProject(permutationProjects,
 unionDistinct));
+    }
+
     private LogicalUnion findUnion(Plan p) {
         if (p instanceof LogicalUnion) {
             return (LogicalUnion) p;
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java
index 59509fce805..9720fa1c550 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java
@@ -138,6 +138,78 @@ public class DataTypeTest {
 
     }
 
+    /**
+     * Checks {@link DataType#isInjectiveCastTo(DataType)} for primitive types: integral widening,
+     * integral/boolean to decimal, decimal precision and scale changes, datetime precision changes
+     * and string-like length changes.
+     */
+    @Test
+    public void testIsInjectiveCastToForPrimitiveTypes() {
+        assertSafeCast(IntegerType.INSTANCE, IntegerType.INSTANCE);
+        assertSafeCast(IntegerType.INSTANCE, BigIntType.INSTANCE);
+        assertUnsafeCast(BigIntType.INSTANCE, IntegerType.INSTANCE);
+        // INT needs up to 10 integral digits and LARGEINT up to 39, so narrower decimals are lossy.
+        assertSafeCast(IntegerType.INSTANCE, DecimalV3Type.createDecimalV3Type(10, 0));
+        assertUnsafeCast(IntegerType.INSTANCE, DecimalV3Type.createDecimalV3Type(9, 0));
+        assertUnsafeCast(LargeIntType.INSTANCE, DecimalV3Type.createDecimalV3Type(38, 0));
+
+        // BOOLEAN needs one integral digit, which DECIMAL(1, 1) does not have.
+        assertSafeCast(BooleanType.INSTANCE, DecimalV3Type.createDecimalV3Type(1, 0));
+        assertUnsafeCast(BooleanType.INSTANCE, DecimalV3Type.createDecimalV3Type(1, 1));
+
+        // A decimal cast must lose neither scale nor integral digits.
+        assertSafeCast(DecimalV3Type.createDecimalV3Type(6, 2), DecimalV3Type.createDecimalV3Type(8, 3));
+        assertUnsafeCast(DecimalV3Type.createDecimalV3Type(6, 2), DecimalV3Type.createDecimalV3Type(6, 1));
+        assertUnsafeCast(DecimalV3Type.createDecimalV3Type(6, 2), DecimalV3Type.createDecimalV3Type(5, 2));
+
+        // DATETIME round-trips with DATETIMEV2(0); dropping fractional seconds or the time part is lossy.
+        assertSafeCast(DateTimeType.INSTANCE, DateTimeV2Type.of(0));
+        assertSafeCast(DateTimeV2Type.of(0), DateTimeType.INSTANCE);
+        assertSafeCast(DateTimeV2Type.of(3), DateTimeV2Type.of(6));
+        assertUnsafeCast(DateTimeV2Type.of(3), DateTimeType.INSTANCE);
+        assertUnsafeCast(DateTimeType.INSTANCE, DateType.INSTANCE);
+
+        // String-like casts are treated as length-insensitive, even when the declared length shrinks.
+        assertSafeCast(VarcharType.createVarcharType(10), VarcharType.createVarcharType(20));
+        assertSafeCast(VarcharType.createVarcharType(10), StringType.INSTANCE);
+        assertSafeCast(VarcharType.createVarcharType(20), VarcharType.createVarcharType(10));
+        assertSafeCast(StringType.INSTANCE, VarcharType.createVarcharType(10));
+    }
+
+    /**
+     * Checks {@link DataType#isInjectiveCastTo(DataType)} for ARRAY, MAP and STRUCT: every nested type
+     * must cast injectively, struct field counts must match, and casts to STRING are treated as injective.
+     */
+    @Test
+    public void testIsInjectiveCastToForComplexTypes() {
+        assertSafeCast(ArrayType.of(IntegerType.INSTANCE), ArrayType.of(BigIntType.INSTANCE));
+        assertUnsafeCast(ArrayType.of(BigIntType.INSTANCE), ArrayType.of(IntegerType.INSTANCE));
+
+        // A lossy key cast makes the whole MAP cast lossy.
+        assertSafeCast(MapType.of(IntegerType.INSTANCE, VarcharType.createVarcharType(10)),
+                MapType.of(BigIntType.INSTANCE, StringType.INSTANCE));
+        assertUnsafeCast(MapType.of(BigIntType.INSTANCE, VarcharType.createVarcharType(10)),
+                MapType.of(IntegerType.INSTANCE, StringType.INSTANCE));
+
+        StructType intStringStruct = new StructType(ImmutableList.of(
+                new StructField("a", IntegerType.INSTANCE, true, ""),
+                new StructField("b", VarcharType.createVarcharType(10), true, "")));
+        StructType bigintStringStruct = new StructType(ImmutableList.of(
+                new StructField("a", BigIntType.INSTANCE, true, ""),
+                new StructField("b", StringType.INSTANCE, true, "")));
+        StructType intOnlyStruct = new StructType(ImmutableList.of(
+                new StructField("a", IntegerType.INSTANCE, true, "")));
+
+        // Every field must cast injectively; casting to a struct with a different field count is rejected.
+        assertSafeCast(intStringStruct, bigintStringStruct);
+        assertUnsafeCast(bigintStringStruct, intStringStruct);
+        assertUnsafeCast(intOnlyStruct, intStringStruct);
+
+        assertSafeCast(ArrayType.of(IntegerType.INSTANCE), StringType.INSTANCE);
+        assertSafeCast(MapType.of(IntegerType.INSTANCE, StringType.INSTANCE), StringType.INSTANCE);
+        assertSafeCast(intStringStruct, StringType.INSTANCE);
+    }
+
     @Test
     public void testAnyAccept() {
         AnyDataType dateType = AnyDataType.INSTANCE_WITHOUT_INDEX;
@@ -654,4 +710,16 @@ public class DataTypeTest {
         DataType type = ArrayType.of(MapType.of(VarcharType.SYSTEM_DEFAULT, IntegerType.INSTANCE));
         Assertions.assertDoesNotThrow(type::validateDataType);
     }
+
+    /** Asserts that {@code source} casts injectively to {@code target}: distinct values stay distinct. */
+    private void assertSafeCast(DataType source, DataType target) {
+        Assertions.assertTrue(source.isInjectiveCastTo(target), source.toSql() + " should safely cast to "
+                + target.toSql());
+    }
+
+    /** Asserts that casting {@code source} to {@code target} is not reported as injective. */
+    private void assertUnsafeCast(DataType source, DataType target) {
+        Assertions.assertFalse(source.isInjectiveCastTo(target), source.toSql() + " should not safely cast to "
+                + target.toSql());
+    }
 }
diff --git a/regression-test/data/nereids_syntax_p0/set_operation.out b/regression-test/data/nereids_syntax_p0/set_operation.out
index 5afc4fac2ad..b21f59c0f65 100644
--- a/regression-test/data/nereids_syntax_p0/set_operation.out
+++ b/regression-test/data/nereids_syntax_p0/set_operation.out
@@ -592,6 +592,14 @@ hell0
 -- !union45 --
 2
 
+-- !union46 --
+2020-01-01
+2020-01-01
+
+-- !union47 --
+1      1
+1      1
+
 -- !check_child_col_order --
 205548764.21875        3601
 53950855.65625 3602
@@ -599,4 +607,3 @@ hell0
 -- !intersect_case --
 0
 1
-
diff --git a/regression-test/suites/nereids_syntax_p0/set_operation.groovy b/regression-test/suites/nereids_syntax_p0/set_operation.groovy
index 360dfd55d93..e9cddece35f 100644
--- a/regression-test/suites/nereids_syntax_p0/set_operation.groovy
+++ b/regression-test/suites/nereids_syntax_p0/set_operation.groovy
@@ -291,6 +291,27 @@ suite("set_operation") {
         select count(*) from (select 1, 2 union select 1,1 ) a;
     """
 
+    // Do not push a non-injective cast project below UNION DISTINCT.
+    // The two datetime values are distinct before the outer cast, but become
+    // equal after casting to date. The correct result keeps both rows.
+    order_qt_union46 """
+        select cast(dt as date) from (
+            select cast('2020-01-01 00:00:00' as datetime) dt
+            union
+            select cast('2020-01-01 01:00:00' as datetime) dt
+        ) t
+    """
+
+    // The project reads a twice and drops b, so (1, 2) and (1, 3) both become (1, 1).
+    // Pushed below UNION DISTINCT, the distinct would collapse them into a single row.
+    order_qt_union47 """
+        select cast(a as bigint), cast(a as bigint) from (
+            select 1 a, 2 b
+            union
+            select 1 a, 3 b
+        ) t
+    """
+
     def tables = [
             "dwd_daytable",
     ]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to