This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new e2d37394a7f branch-4.1: [fix](aggregate) Fix nullable aggregate visitor dispatch #64885 (#65023)
e2d37394a7f is described below
commit e2d37394a7ff2b9d38af034e6b72baea1161f92f
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Wed Jul 1 17:43:52 2026 +0800
branch-4.1: [fix](aggregate) Fix nullable aggregate visitor dispatch #64885 (#65023)
Cherry-picked from #64885
Co-authored-by: morrySnow <[email protected]>
---
.../visitor/AggregateFunctionVisitor.java | 6 +-
.../visitor/AggregateFunctionVisitorTest.java | 114 +++++++++++++++++++++
2 files changed, 117 insertions(+), 3 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
index 22fb409caed..c4d1ef0d4c9 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitor.java
@@ -149,15 +149,15 @@ public interface AggregateFunctionVisitor<R, C> {
}
default R visitBoolAnd(BoolAnd boolAnd, C context) {
- return visitAggregateFunction(boolAnd, context);
+ return visitNullableAggregateFunction(boolAnd, context);
}
default R visitBoolOr(BoolOr boolOr, C context) {
- return visitAggregateFunction(boolOr, context);
+ return visitNullableAggregateFunction(boolOr, context);
}
default R visitBoolXor(BoolXor boolXor, C context) {
- return visitAggregateFunction(boolXor, context);
+ return visitNullableAggregateFunction(boolXor, context);
}
default R visitCollectList(CollectList collectList, C context) {
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitorTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitorTest.java
new file mode 100644
index 00000000000..d77f1007a80
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/visitor/AggregateFunctionVisitorTest.java
@@ -0,0 +1,114 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.visitor;
+
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.NullableAggregateFunction;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.lang.reflect.Method;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+
+class AggregateFunctionVisitorTest {
+ private static final String AGG_PACKAGE = "org.apache.doris.nereids.trees.expressions.functions.agg.";
+ private static final String COMBINATOR_PACKAGE = "org.apache.doris.nereids.trees.expressions.functions.combinator.";
+
+ @Test
+ void testNullableAggregateFunctionsVisitNullableDefault() throws Exception {
+ AggregateFunctionVisitor<String, Void> visitor = new AggregateFunctionVisitor<String, Void>() {
+ @Override
+ public String visitAggregateFunction(AggregateFunction function, Void context) {
+ return "aggregate";
+ }
+
+ @Override
+ public String visitNullableAggregateFunction(NullableAggregateFunction nullableAggregateFunction,
+ Void context) {
+ return "nullable";
+ }
+ };
+
+ for (Class<?> functionClass : nullableAggregateFunctionClasses()) {
+ List<Method> visitorMethods = Arrays.stream(AggregateFunctionVisitor.class.getMethods())
+ .filter(method -> method.getParameterCount() == 2)
+ .filter(method -> method.getParameterTypes()[0].equals(functionClass))
+ .collect(Collectors.toList());
+ Assertions.assertEquals(1, visitorMethods.size(), functionClass.getName());
+ Assertions.assertEquals("nullable", visitorMethods.get(0).invoke(visitor, null, null),
+ functionClass.getName());
+ }
+ }
+
+ private static List<Class<?>> nullableAggregateFunctionClasses() throws ClassNotFoundException {
+ return Arrays.asList(
+ aggregateClass("AIAgg"),
+ aggregateClass("AnyValue"),
+ aggregateClass("Avg"),
+ aggregateClass("AvgWeighted"),
+ aggregateClass("BoolAnd"),
+ aggregateClass("BoolOr"),
+ aggregateClass("BoolXor"),
+ aggregateClass("Corr"),
+ aggregateClass("CorrWelford"),
+ aggregateClass("Covar"),
+ aggregateClass("CovarSamp"),
+ aggregateClass("ExponentialMovingAverage"),
+ aggregateClass("GroupBitAnd"),
+ aggregateClass("GroupBitOr"),
+ aggregateClass("GroupBitXor"),
+ aggregateClass("GroupBitmapXor"),
+ aggregateClass("GroupConcat"),
+ aggregateClass("Max"),
+ aggregateClass("MaxBy"),
+ aggregateClass("Median"),
+ aggregateClass("Min"),
+ aggregateClass("MinBy"),
+ aggregateClass("MultiDistinctGroupConcat"),
+ aggregateClass("MultiDistinctSum"),
+ aggregateClass("Percentile"),
+ aggregateClass("PercentileApprox"),
+ aggregateClass("PercentileApproxWeighted"),
+ aggregateClass("PercentileReservoir"),
+ aggregateClass("Retention"),
+ aggregateClass("Sem"),
+ aggregateClass("SequenceMatch"),
+ aggregateClass("Stddev"),
+ aggregateClass("StddevSamp"),
+ aggregateClass("Sum"),
+ aggregateClass("TopN"),
+ aggregateClass("TopNArray"),
+ aggregateClass("TopNWeighted"),
+ aggregateClass("Variance"),
+ aggregateClass("VarianceSamp"),
+ aggregateClass("WindowFunnel"),
+ aggregateClass("WindowFunnelV2"),
+ combinatorClass("ForEachCombinator"));
+ }
+
+ private static Class<?> aggregateClass(String name) throws ClassNotFoundException {
+ return Class.forName(AGG_PACKAGE + name);
+ }
+
+ private static Class<?> combinatorClass(String name) throws ClassNotFoundException {
+ return Class.forName(COMBINATOR_PACKAGE + name);
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]