This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new afc059e002c branch-3.1: [fix](nereids) fix scalar subquery output nullable #51928 (#52492)
afc059e002c is described below

commit afc059e002c100c219d71d2a80e43de81f41afe5
Author: yujun <[email protected]>
AuthorDate: Mon Jun 30 19:20:51 2025 +0800

    branch-3.1: [fix](nereids) fix scalar subquery output nullable #51928 (#52492)
    
    cherry pick from #51928
---
 .../nereids/rules/analysis/SubqueryToApply.java    |  15 ++--
 .../rewrite/AggScalarSubQueryToWindowFunction.java |  11 +--
 .../trees/copier/LogicalPlanDeepCopier.java        |   4 +-
 .../nereids/trees/expressions/ScalarSubquery.java  |  72 ++++++++++++----
 .../nereids/trees/plans/logical/LogicalApply.java  |  29 ++++---
 .../java/org/apache/doris/nereids/util/Utils.java  |  25 +-----
 .../rules/analysis/AnalyzeSubQueryTest.java        |  93 +++++++++++++++++++++
 .../adjust_nullable/test_subquery_nullable.out     | Bin 0 -> 6295 bytes
 .../adjust_nullable/test_subquery_nullable.groovy  |  92 ++++++++++++++++++++
 9 files changed, 278 insertions(+), 63 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
index 559cec207e9..c6cf687819a 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubqueryToApply.java
@@ -414,6 +414,10 @@ public class SubqueryToApply implements 
AnalysisRuleFactory {
         // if needRuntimeAnyValue is true, we will add it to the project list
         boolean needRuntimeAnyValue = false;
         NamedExpression oldSubqueryOutput = 
subquery.getQueryPlan().getOutput().get(0);
+        if (subquery instanceof ScalarSubquery) {
+            // a scalar subquery may adjust the nullability of its output slot.
+            oldSubqueryOutput = ((ScalarSubquery) 
subquery).getOutputSlotAdjustNullable();
+        }
         Slot countSlot = null;
         Slot anyValueSlot = null;
         Optional<Expression> newConjunct = conjunct;
@@ -427,9 +431,10 @@ public class SubqueryToApply implements 
AnalysisRuleFactory {
                 // but COUNT function is always not nullable.
                 // so wrap COUNT with Nvl to ensure its result is 0 instead of 
null to get the correct result
                 if (conjunct.isPresent()) {
-                    Map<Expression, Expression> replaceMap = new HashMap<>();
-                    NamedExpression agg = ((ScalarSubquery) 
subquery).getTopLevelScalarAggFunction().get();
+                    NamedExpression agg = 
ScalarSubquery.getTopLevelScalarAggFunction(
+                            subquery.getQueryPlan(), 
subquery.getCorrelateSlots()).get();
                     if (agg instanceof Alias) {
+                        Map<Expression, Expression> replaceMap = new 
HashMap<>();
                         if (((Alias) agg).child() instanceof 
NotNullableAggregateFunction) {
                             NotNullableAggregateFunction notNullableAggFunc =
                                     (NotNullableAggregateFunction) ((Alias) 
agg).child();
@@ -451,9 +456,9 @@ public class SubqueryToApply implements AnalysisRuleFactory 
{
                                 replaceMap.put(oldSubqueryOutput, new 
Nvl(oldSubqueryOutput,
                                         
notNullableAggFunc.resultForEmptyInput()));
                             }
-                        }
-                        if (!replaceMap.isEmpty()) {
-                            newConjunct = 
Optional.of(ExpressionUtils.replace(conjunct.get(), replaceMap));
+                            if (!replaceMap.isEmpty()) {
+                                newConjunct = 
Optional.of(ExpressionUtils.replace(conjunct.get(), replaceMap));
+                            }
                         }
                     }
                 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunction.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunction.java
index ce93d25cb97..d823f0af88e 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunction.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggScalarSubQueryToWindowFunction.java
@@ -246,7 +246,7 @@ public class AggScalarSubQueryToWindowFunction extends 
DefaultPlanRewriter<JobCo
      * 2. outer table list - inner table list should only remain 1 table
      * 3. the remaining table in step 2 should be correlated table for inner 
plan
      */
-    private boolean checkRelation(List<Expression> correlatedSlots) {
+    private boolean checkRelation(List<Slot> correlatedSlots) {
         List<CatalogRelation> outerTables = 
outerPlans.stream().filter(CatalogRelation.class::isInstance)
                 .map(CatalogRelation.class::cast)
                 .collect(Collectors.toList());
@@ -274,9 +274,7 @@ public class AggScalarSubQueryToWindowFunction extends 
DefaultPlanRewriter<JobCo
                 .filter(node -> outerIds.contains(node.getTable().getId()))
                 .map(LogicalRelation.class::cast)
                 
.map(LogicalRelation::getOutputExprIdSet).flatMap(Collection::stream).collect(Collectors.toSet());
-        return ExpressionUtils.collect(correlatedSlots, 
NamedExpression.class::isInstance).stream()
-                .map(NamedExpression.class::cast)
-                .allMatch(e -> 
correlatedRelationOutput.contains(e.getExprId()));
+        return correlatedSlots.stream().allMatch(e -> 
correlatedRelationOutput.contains(e.getExprId()));
     }
 
     private void createSlotMapping(List<CatalogRelation> outerTables, 
List<CatalogRelation> innerTables) {
@@ -366,10 +364,9 @@ public class AggScalarSubQueryToWindowFunction extends 
DefaultPlanRewriter<JobCo
         return windowFilter;
     }
 
-    private WindowExpression createWindowFunction(List<Expression> 
correlatedSlots, AggregateFunction function) {
+    private WindowExpression createWindowFunction(List<Slot> correlatedSlots, 
AggregateFunction function) {
         // partition by clause is set by all the correlated slots.
-        
Preconditions.checkArgument(correlatedSlots.stream().allMatch(Slot.class::isInstance));
-        return new WindowExpression(function, correlatedSlots, 
Collections.emptyList());
+        return new WindowExpression(function, 
ImmutableList.copyOf(correlatedSlots), Collections.emptyList());
     }
 
     private static class ExpressionIdenticalChecker extends 
DefaultExpressionVisitor<Boolean, Expression> {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java
index e7b755a6d53..2dd581e38d7 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java
@@ -125,8 +125,8 @@ public class LogicalPlanDeepCopier extends 
DefaultPlanRewriter<DeepCopierContext
     public Plan visitLogicalApply(LogicalApply<? extends Plan, ? extends Plan> 
apply, DeepCopierContext context) {
         Plan left = apply.left().accept(this, context);
         Plan right = apply.right().accept(this, context);
-        List<Expression> correlationSlot = apply.getCorrelationSlot().stream()
-                .map(s -> ExpressionDeepCopier.INSTANCE.deepCopy(s, context))
+        List<Slot> correlationSlot = apply.getCorrelationSlot().stream()
+                .map(s -> (Slot) ExpressionDeepCopier.INSTANCE.deepCopy(s, 
context))
                 .collect(ImmutableList.toImmutableList());
         Optional<Expression> compareExpr = apply.getCompareExpr()
                 .map(f -> ExpressionDeepCopier.INSTANCE.deepCopy(f, context));
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ScalarSubquery.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ScalarSubquery.java
index a1bfbf49029..4ded2447731 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ScalarSubquery.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ScalarSubquery.java
@@ -18,6 +18,7 @@
 package org.apache.doris.nereids.trees.expressions;
 
 import org.apache.doris.nereids.exceptions.UnboundException;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.NotNullableAggregateFunction;
 import org.apache.doris.nereids.trees.expressions.shape.LeafExpression;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.trees.plans.Plan;
@@ -29,12 +30,14 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
 import org.apache.doris.nereids.types.DataType;
 
 import com.google.common.base.Preconditions;
+import com.google.common.base.Suppliers;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.function.Supplier;
 
 /**
  * A subquery that will return only one row and one column.
@@ -47,6 +50,9 @@ public class ScalarSubquery extends SubqueryExpr implements 
LeafExpression {
     // indicate if the subquery has limit 1 clause but it's been eliminated in 
previous process step
     private final boolean limitOneIsEliminated;
 
+    private final Supplier<Slot> queryPlanSlot = Suppliers.memoize(
+            () -> getScalarQueryOutputAdjustNullable(queryPlan, 
correlateSlots));
+
     public ScalarSubquery(LogicalPlan subquery) {
         this(subquery, ImmutableList.of(), false);
     }
@@ -75,10 +81,10 @@ public class ScalarSubquery extends SubqueryExpr implements 
LeafExpression {
     /**
     * getTopLevelScalarAggFunction
     */
-    public Optional<NamedExpression> getTopLevelScalarAggFunction() {
-        Plan plan = findTopLevelScalarAgg(queryPlan, 
ImmutableSet.copyOf(correlateSlots));
-        if (plan != null) {
-            LogicalAggregate aggregate = (LogicalAggregate) plan;
+    public static Optional<NamedExpression> getTopLevelScalarAggFunction(Plan 
queryPlan,
+            List<Slot> correlateSlots) {
+        LogicalAggregate<?> aggregate = findTopLevelScalarAgg(queryPlan, 
ImmutableSet.copyOf(correlateSlots));
+        if (aggregate != null) {
             Preconditions.checkState(aggregate.getAggregateFunctions().size() 
== 1,
                     "in scalar subquery, should only return 1 column 1 row, "
                             + "but we found multiple columns ", 
aggregate.getOutputExpressions());
@@ -88,10 +94,15 @@ public class ScalarSubquery extends SubqueryExpr implements 
LeafExpression {
         }
     }
 
+    @Override
+    public Expression getSubqueryOutput() {
+        return typeCoercionExpr.orElseGet(this::getOutputSlotAdjustNullable);
+    }
+
     @Override
     public DataType getDataType() throws UnboundException {
         Preconditions.checkArgument(queryPlan.getOutput().size() == 1);
-        return 
typeCoercionExpr.orElse(queryPlan.getOutput().get(0)).getDataType();
+        return getSubqueryOutput().getDataType();
     }
 
     @Override
@@ -110,10 +121,11 @@ public class ScalarSubquery extends SubqueryExpr 
implements LeafExpression {
 
     @Override
     public Expression withTypeCoercion(DataType dataType) {
-        return new ScalarSubquery(queryPlan, correlateSlots,
-                dataType == queryPlan.getOutput().get(0).getDataType()
-                    ? Optional.of(queryPlan.getOutput().get(0))
-                    : Optional.of(new Cast(queryPlan.getOutput().get(0), 
dataType)), limitOneIsEliminated);
+        Optional<Expression> newTypeCoercionExpr = typeCoercionExpr;
+        if (!getDataType().equals(dataType)) {
+            newTypeCoercionExpr = Optional.of(new Cast(getSubqueryOutput(), 
dataType));
+        }
+        return new ScalarSubquery(queryPlan, correlateSlots, 
newTypeCoercionExpr, limitOneIsEliminated);
     }
 
     @Override
@@ -121,22 +133,52 @@ public class ScalarSubquery extends SubqueryExpr 
implements LeafExpression {
         return new ScalarSubquery(subquery, correlateSlots, typeCoercionExpr, 
limitOneIsEliminated);
     }
 
+    public Slot getOutputSlotAdjustNullable() {
+        return queryPlanSlot.get();
+    }
+
+    /**
+     *  Get the query plan's first output slot with its nullability adjusted:
+     *  1. not nullable, when the plan has a top level scalar agg whose output is this slot
+     *     and whose function is a NotNullableAggregateFunction.
+     *     For example: `t1.a = (select count(t2.x) from t2)`: count(t2.x) is never null,
+     *     even if t2 contains 0 rows.
+     *  2. nullable, otherwise.
+     *     For example: `t1.a = (select t2.y from t2 limit 1)`: even if t2.y is not nullable,
+     *     the subquery outputs null when t2 contains 0 rows.
+     */
+    public static Slot getScalarQueryOutputAdjustNullable(Plan queryPlan, 
List<Slot> correlateSlots) {
+        Slot output = queryPlan.getOutput().get(0);
+        boolean nullable = true;
+        Optional<NamedExpression> aggOpt = 
getTopLevelScalarAggFunction(queryPlan, correlateSlots);
+        if (aggOpt.isPresent()) {
+            NamedExpression agg = aggOpt.get();
+            if (agg.getExprId().equals(output.getExprId())
+                    && agg instanceof Alias
+                    && ((Alias) agg).child() instanceof 
NotNullableAggregateFunction) {
+                nullable = false;
+            }
+        }
+
+        return output.withNullable(nullable);
+    }
+
     /**
-     * for correlated subquery, we define top level scalar agg as if it meets 
the both 2 conditions:
-     * 1. The agg or its child contains correlated slots
+     * for a subquery, we define the top level scalar agg as an agg that meets both of these conditions:
+     * 1. The agg or its child contains correlated slots (always satisfied for an uncorrelated subquery,
+     *    whose correlated slot list is empty)
      * 2. only project, sort and subquery alias node can be agg's parent
      */
-    public static Plan findTopLevelScalarAgg(Plan plan, ImmutableSet<Slot> 
slots) {
+    public static LogicalAggregate<?> findTopLevelScalarAgg(Plan plan, 
ImmutableSet<Slot> slots) {
         if (plan instanceof LogicalAggregate) {
-            if (((LogicalAggregate<?>) plan).getGroupByExpressions().isEmpty() 
&& plan.containsSlots(slots)) {
-                return plan;
+            LogicalAggregate<?> agg = (LogicalAggregate<?>) plan;
+            if (agg.getGroupByExpressions().isEmpty()
+                    && (plan.containsSlots(slots) || slots.isEmpty())) {
+                return agg;
             } else {
                 return null;
             }
         } else if (plan instanceof LogicalProject || plan instanceof 
LogicalSubQueryAlias
                 || plan instanceof LogicalSort) {
             for (Plan child : plan.children()) {
-                Plan result = findTopLevelScalarAgg(child, slots);
+                LogicalAggregate<?> result = findTopLevelScalarAgg(child, 
slots);
                 if (result != null) {
                     return result;
                 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
index 0b12e225311..59d72f51c95 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalApply.java
@@ -21,6 +21,7 @@ import org.apache.doris.nereids.memo.GroupExpression;
 import org.apache.doris.nereids.properties.LogicalProperties;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
+import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.PlanType;
@@ -61,7 +62,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, 
RIGHT_CHILD_TYPE extends
     private final Optional<Expression> typeCoercionExpr;
 
     // correlation column
-    private final List<Expression> correlationSlot;
+    private final List<Slot> correlationSlot;
     // correlation Conjunction
     private final Optional<Expression> correlationFilter;
     // The slot replaced by the subquery in MarkJoin
@@ -83,7 +84,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, 
RIGHT_CHILD_TYPE extends
 
     private LogicalApply(Optional<GroupExpression> groupExpression,
             Optional<LogicalProperties> logicalProperties,
-            List<Expression> correlationSlot, SubQueryType subqueryType, 
boolean isNot,
+            List<Slot> correlationSlot, SubQueryType subqueryType, boolean 
isNot,
             Optional<Expression> compareExpr, Optional<Expression> 
typeCoercionExpr,
             Optional<Expression> correlationFilter,
             Optional<MarkJoinSlotReference> markJoinSlotReference,
@@ -107,7 +108,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, 
RIGHT_CHILD_TYPE extends
         this.isMarkJoinSlotNotNull = isMarkJoinSlotNotNull;
     }
 
-    public LogicalApply(List<Expression> correlationSlot, SubQueryType 
subqueryType, boolean isNot,
+    public LogicalApply(List<Slot> correlationSlot, SubQueryType subqueryType, 
boolean isNot,
             Optional<Expression> compareExpr, Optional<Expression> 
typeCoercionExpr,
             Optional<Expression> correlationFilter, 
Optional<MarkJoinSlotReference> markJoinSlotReference,
             boolean needAddSubOutputToProjects, boolean inProject, boolean 
isMarkJoinSlotNotNull,
@@ -129,7 +130,7 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, 
RIGHT_CHILD_TYPE extends
         return typeCoercionExpr.orElseGet(() -> right().getOutput().get(0));
     }
 
-    public List<Expression> getCorrelationSlot() {
+    public List<Slot> getCorrelationSlot() {
         return correlationSlot;
     }
 
@@ -187,13 +188,19 @@ public class LogicalApply<LEFT_CHILD_TYPE extends Plan, 
RIGHT_CHILD_TYPE extends
 
     @Override
     public List<Slot> computeOutput() {
-        return ImmutableList.<Slot>builder()
-                .addAll(left().getOutput())
-                .addAll(markJoinSlotReference.isPresent()
-                    ? ImmutableList.of(markJoinSlotReference.get()) : 
ImmutableList.of())
-                .addAll(needAddSubOutputToProjects
-                    ? ImmutableList.of(right().getOutput().get(0)) : 
ImmutableList.of())
-                .build();
+        ImmutableList.Builder<Slot> builder = ImmutableList.builder();
+        builder.addAll(left().getOutput());
+        if (markJoinSlotReference.isPresent()) {
+            builder.add(markJoinSlotReference.get());
+        }
+        if (needAddSubOutputToProjects) {
+            if (isScalar()) {
+                
builder.add(ScalarSubquery.getScalarQueryOutputAdjustNullable(right(), 
correlationSlot));
+            } else {
+                builder.add(right().getOutput().get(0));
+            }
+        }
+        return builder.build();
     }
 
     @Override
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java
index c111839fc50..0040c35992b 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java
@@ -18,11 +18,9 @@
 package org.apache.doris.nereids.util;
 
 import org.apache.doris.nereids.exceptions.AnalysisException;
-import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.Not;
 import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
 
 import com.google.common.base.CaseFormat;
@@ -202,7 +200,7 @@ public class Utils {
      * return abs(t2.d)
      */
     public static List<Expression> getUnCorrelatedExprs(List<Expression> 
correlatedPredicates,
-                                                        List<Expression> 
correlatedSlots) {
+                                                        List<Slot> 
correlatedSlots) {
         List<Expression> unCorrelatedExprs = new ArrayList<>();
         correlatedPredicates.forEach(predicate -> {
             if (!(predicate instanceof BinaryExpression) && (!(predicate 
instanceof Not)
@@ -240,27 +238,8 @@ public class Utils {
         return unCorrelatedExprs;
     }
 
-    private static List<Expression> collectCorrelatedSlotsFromChildren(
-            BinaryExpression binaryExpression, List<Expression> 
correlatedSlots) {
-        List<Expression> slots = new ArrayList<>();
-        if (binaryExpression.left().anyMatch(correlatedSlots::contains)) {
-            if (binaryExpression.right() instanceof SlotReference) {
-                slots.add(binaryExpression.right());
-            } else if (binaryExpression.right() instanceof Cast) {
-                slots.add(((Cast) binaryExpression.right()).child());
-            }
-        } else {
-            if (binaryExpression.left() instanceof SlotReference) {
-                slots.add(binaryExpression.left());
-            } else if (binaryExpression.left() instanceof Cast) {
-                slots.add(((Cast) binaryExpression.left()).child());
-            }
-        }
-        return slots;
-    }
-
     public static Map<Boolean, List<Expression>> splitCorrelatedConjuncts(
-            Set<Expression> conjuncts, List<Expression> slots) {
+            Set<Expression> conjuncts, List<Slot> slots) {
         return conjuncts.stream().collect(Collectors.partitioningBy(
                 expr -> expr.anyMatch(slots::contains)));
     }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubQueryTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubQueryTest.java
index 138db44863f..d0ce4bedc91 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubQueryTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeSubQueryTest.java
@@ -25,8 +25,12 @@ import org.apache.doris.nereids.parser.NereidsParser;
 import org.apache.doris.nereids.properties.PhysicalProperties;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.ExprId;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
 import org.apache.doris.nereids.types.BigIntType;
 import org.apache.doris.nereids.util.FieldChecker;
@@ -36,6 +40,8 @@ import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.utframe.TestWithFeService;
 
 import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Lists;
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
 import java.util.List;
@@ -76,6 +82,15 @@ public class AnalyzeSubQueryTest extends TestWithFeService 
implements MemoPatter
                         + "DISTRIBUTED BY HASH(id) BUCKETS 1\n"
                         + "PROPERTIES (\n"
                         + "  \"replication_num\" = \"1\"\n"
+                        + ")\n",
+                "CREATE TABLE IF NOT EXISTS T3 (\n"
+                        + "    id bigint not null,\n"
+                        + "    score bigint not null\n"
+                        + ")\n"
+                        + "DUPLICATE KEY(id)\n"
+                        + "DISTRIBUTED BY HASH(id) BUCKETS 1\n"
+                        + "PROPERTIES (\n"
+                        + "  \"replication_num\" = \"1\"\n"
                         + ")\n"
         );
     }
@@ -185,4 +200,82 @@ public class AnalyzeSubQueryTest extends TestWithFeService 
implements MemoPatter
                     )
                 );
     }
+
+    @Test
+    public void testScalarSubquerySlotNullable() {
+        List<String> nullableSqls = ImmutableList.of(
+                // project list
+                "select (select T3.id as k from T3 limit 1) from T1",
+                "select (select T3.id as k from T3 where T3.score = T1.score 
limit 1) from T1",
+                "select (select sum(T3.id) as k from T3) from T1",
+                "select (select sum(T3.id) as k from T3 where T3.score = 
T1.score) from T1",
+                "select (select sum(T3.id) as k from T3 group by T3.score 
limit 1) from T1",
+                "select (select sum(T3.id) as k from T3 group by T3.score 
having T3.score = T1.score + 10 limit 1) from T1",
+                "select (select count(T3.id) as k from T3 group by T3.score 
limit 1) from T1",
+                "select (select count(T3.id) as k from T3 group by T3.score 
having T3.score = T1.score + 10 limit 1) from T1",
+
+                // filter
+                "select * from T1 where T1.id > (select T3.id as k from T3 
limit 1)",
+                "select * from T1 where T1.id > (select T3.id as k from T3 
where T3.score = T1.score limit 1)",
+                "select * from T1 where T1.id > (select sum(T3.id) as k from 
T3)",
+                "select * from T1 where T1.id > (select sum(T3.id) as k from 
T3 where T3.score = T1.score)",
+                "select * from T1 where T1.id > (select sum(T3.id) as k from 
T3 group by T3.score limit 1)",
+                "select * from T1 where T1.id > (select sum(T3.id) as k from 
T3 group by T3.score having T3.score = T1.score + 10 limit 1)",
+                "select * from T1 where T1.id > (select count(T3.id) as k from 
T3 group by T3.score limit 1)",
+                "select * from T1 where T1.id > (select count(T3.id) as k from 
T3 group by T3.score having T3.score = T1.score + 10 limit 1)"
+        );
+
+        List<String> notNullableSqls = ImmutableList.of(
+                // project
+                "select (select count(T3.id) as k from T3) from T1",
+                "select (select count(T3.id) as k from T3 where T3.score = 
T1.score) from T1",
+
+                // filter
+                "select * from T1 where T1.id > (select count(T3.id) as k from 
T3)",
+                "select * from T1 where T1.id > (select count(T3.id) as k from 
T3 where T3.score = T1.score)"
+        );
+
+        for (String sql : nullableSqls) {
+            checkScalarSubquerySlotNullable(sql, true);
+        }
+
+        for (String sql : notNullableSqls) {
+            checkScalarSubquerySlotNullable(sql, false);
+        }
+    }
+
+    private void checkScalarSubquerySlotNullable(String sql, boolean 
outputNullable) {
+        Plan root = PlanChecker.from(connectContext)
+                .analyze(sql)
+                .applyTopDown(new LogicalSubQueryAliasToLogicalProject())
+                .getPlan();
+        List<LogicalProject<?>> projectList = Lists.newArrayList();
+        root.foreach(plan -> {
+            if (plan instanceof LogicalProject && plan.child(0) instanceof 
LogicalApply) {
+                projectList.add((LogicalProject<?>) plan);
+                return true;
+            } else {
+                return false;
+            }
+        });
+
+        Assertions.assertEquals(1, projectList.size());
+        LogicalProject<?> project = projectList.get(0);
+        LogicalApply<?, ?> apply = (LogicalApply<?, ?>) project.child();
+
+        Assertions.assertNotNull(project);
+        Assertions.assertNotNull(apply);
+
+        List<String> slotKName = ImmutableList.of("k", "any_value(k)", 
"ifnull(k, 0)");
+        NamedExpression output = project.getProjects().stream()
+                .filter(e -> slotKName.contains(e.getName()))
+                .findFirst().orElse(null);
+        Assertions.assertNotNull(output);
+        Assertions.assertEquals(outputNullable, output.nullable());
+        output = apply.getOutput().stream()
+                .filter(e -> slotKName.contains(e.getName()))
+                .findFirst().orElse(null);
+        Assertions.assertNotNull(output);
+        Assertions.assertEquals(outputNullable, output.nullable());
+    }
 }
diff --git 
a/regression-test/data/nereids_rules_p0/adjust_nullable/test_subquery_nullable.out
 
b/regression-test/data/nereids_rules_p0/adjust_nullable/test_subquery_nullable.out
new file mode 100644
index 00000000000..c71b3fa7db7
Binary files /dev/null and 
b/regression-test/data/nereids_rules_p0/adjust_nullable/test_subquery_nullable.out
 differ
diff --git 
a/regression-test/suites/nereids_rules_p0/adjust_nullable/test_subquery_nullable.groovy
 
b/regression-test/suites/nereids_rules_p0/adjust_nullable/test_subquery_nullable.groovy
new file mode 100644
index 00000000000..b30fb67b371
--- /dev/null
+++ 
b/regression-test/suites/nereids_rules_p0/adjust_nullable/test_subquery_nullable.groovy
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite('test_subquery_nullable') {
+    // Regression coverage for scalar subquery output nullability (#51928).
+    // All columns of both tables are declared NOT NULL, and every subquery below
+    // filters on `x > 1000`, which matches none of the inserted rows (x is only 1 or 2).
+    // Each case runs twice: order_qt_<case> checks the result set and
+    // qt_<case>_shape checks `explain shape plan`. The expected output in
+    // test_subquery_nullable.out is presumably keyed by these tag names, so keep them stable.
+    sql 'DROP TABLE IF EXISTS test_subquery_nullable_t1 FORCE'
+    sql 'DROP TABLE IF EXISTS test_subquery_nullable_t2 FORCE'
+    sql """CREATE TABLE test_subquery_nullable_t1(a int not null, b int not null, c int not null)
+        distributed by hash(a) properties('replication_num' = '1')"""
+    sql """CREATE TABLE test_subquery_nullable_t2(x int not null, y int not null, z int not null)
+        distributed by hash(x) properties('replication_num' = '1')"""
+    sql 'INSERT INTO test_subquery_nullable_t1 values(1, 1, 1), (2, 2, 2)'
+    sql 'INSERT INTO test_subquery_nullable_t2 values(1, 1, 1), (2, 2, 2)'
+    // NOTE(review): presumably restricts the node detail printed by `explain shape plan`
+    // to PhysicalProject nodes — confirm against the session variable's documentation.
+    sql "SET detail_shape_nodes='PhysicalProject'"
+
+    // Every case shares one outer query: the scalar subquery becomes column x of
+    // cte1, and the outer select evaluates `x > 10 and x < 1` over it. Building the
+    // text once keeps the result check and the shape check on the identical query.
+    def cteQuery = { subquery ->
+        """
+        with cte1 as (select a, (${subquery}) as x from test_subquery_nullable_t1)
+        select a, x > 10 and x < 1 from cte1
+        """.toString()
+    }
+    // Correlation predicate shared by the correlated variants of cases 1-6.
+    def correlation = 'test_subquery_nullable_t1.b = test_subquery_nullable_t2.y'
+
+    // Subquery selecting a plain column (no aggregate), limited to one row.
+    def uncorrelateScalar = cteQuery(
+            'select x from test_subquery_nullable_t2 where x > 1000 limit 1')
+    order_qt_uncorrelate_scalar_subquery(uncorrelateScalar)
+    qt_uncorrelate_scalar_subquery_shape('explain shape plan ' + uncorrelateScalar)
+    def correlateScalar = cteQuery(
+            "select x from test_subquery_nullable_t2 where x > 1000 and ${correlation} limit 1")
+    order_qt_correlate_scalar_subquery(correlateScalar)
+    qt_correlate_scalar_subquery_shape('explain shape plan ' + correlateScalar)
+
+    // Subquery selecting sum(x) with no group by.
+    def uncorrelateTopNullableAgg = cteQuery(
+            'select sum(x) from test_subquery_nullable_t2 where x > 1000')
+    order_qt_uncorrelate_top_nullable_agg_scalar_subquery(uncorrelateTopNullableAgg)
+    qt_uncorrelate_top_nullable_agg_scalar_subquery_shape('explain shape plan ' + uncorrelateTopNullableAgg)
+    def correlateTopNullableAgg = cteQuery(
+            "select sum(x) from test_subquery_nullable_t2 where x > 1000 and ${correlation}")
+    order_qt_correlate_top_nullable_agg_scalar_subquery(correlateTopNullableAgg)
+    qt_correlate_top_nullable_agg_scalar_subquery_shape('explain shape plan ' + correlateTopNullableAgg)
+
+    // Subquery selecting count(x) with no group by.
+    def uncorrelateTopNotNullableAgg = cteQuery(
+            'select count(x) from test_subquery_nullable_t2 where x > 1000')
+    order_qt_uncorrelate_top_notnullable_agg_scalar_subquery(uncorrelateTopNotNullableAgg)
+    qt_uncorrelate_top_notnullable_agg_scalar_subquery_shape('explain shape plan ' + uncorrelateTopNotNullableAgg)
+    def correlateTopNotNullableAgg = cteQuery(
+            "select count(x) from test_subquery_nullable_t2 where x > 1000 and ${correlation}")
+    order_qt_correlate_top_notnullable_agg_scalar_subquery(correlateTopNotNullableAgg)
+    qt_correlate_top_notnullable_agg_scalar_subquery_shape('explain shape plan ' + correlateTopNotNullableAgg)
+
+    // Subquery selecting count(x) grouped by x; the correlated variant correlates
+    // through a having clause instead of the where clause.
+    def uncorrelateNoTopNotNullableAgg = cteQuery(
+            'select count(x) from test_subquery_nullable_t2 where x > 1000 group by x')
+    order_qt_uncorrelate_notop_notnullable_agg_scalar_subquery(uncorrelateNoTopNotNullableAgg)
+    qt_uncorrelate_notop_notnullable_agg_scalar_subquery_shape('explain shape plan ' + uncorrelateNoTopNotNullableAgg)
+    def correlateNoTopNotNullableAgg = cteQuery('''
+            select count(x) from test_subquery_nullable_t2 where x > 1000 group by x
+            having test_subquery_nullable_t1.a + test_subquery_nullable_t1.b = test_subquery_nullable_t2.x''')
+    order_qt_correlate_notop_notnullable_agg_scalar_subquery(correlateNoTopNotNullableAgg)
+    qt_correlate_notop_notnullable_agg_scalar_subquery_shape('explain shape plan ' + correlateNoTopNotNullableAgg)
+
+    sql 'DROP TABLE IF EXISTS test_subquery_nullable_t1 FORCE'
+    sql 'DROP TABLE IF EXISTS test_subquery_nullable_t2 FORCE'
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to