This is an automated email from the ASF dual-hosted git repository.

morrySnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 896bf298209 [fix](array function) Support boolean cast for array first 
and last (#64847)
896bf298209 is described below

commit 896bf298209b0388b203fb30e8af699257e4b74d
Author: morrySnow <[email protected]>
AuthorDate: Fri Jun 26 16:39:45 2026 +0800

    [fix](array function) Support boolean cast for array first and last (#64847)
    
    ### What problem does this PR solve?
    
    Problem Summary: `array_first` and `array_last` were rewritten through
    `array_filter` before their lambda result type was checked against the
    boolean filter argument. When the lambda returned a type that can be
    implicitly cast to boolean, analysis reported an error instead of
    applying the normal function coercion. This change gives `array_first`
    and `array_last` their own signatures so the `array_map` result is
    coerced to `array<boolean>` before the functions are rewritten to
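    `element_at(array_filter(...))`. Because the two classes no longer extend
    `ElementAt`, `ScalarFunctionVisitor` now dispatches `visitArrayFirst` and
    `visitArrayLast` to `visitScalarFunction` instead of `visitElementAt`.
    Argument type-mismatch messages in `Expression` now print the full
    expected and actual data types rather than their `simpleString()` form.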
    
    ### Release note
    
    Fix `array_first` and `array_last` to accept lambdas whose result can be
    implicitly cast to boolean; for example, `array_first(x -> x, [0, 1, 2])`
    now returns `1` instead of failing analysis.
---
 .../nereids/trees/expressions/Expression.java      |  2 +-
 .../expressions/functions/scalar/ArrayFirst.java   | 33 +++++++++++--
 .../expressions/functions/scalar/ArrayLast.java    | 33 +++++++++++--
 .../expressions/visitor/ScalarFunctionVisitor.java |  4 +-
 .../functions/scalar/ArrayFirstLastTest.java       | 56 ++++++++++++++++++++++
 .../array_functions/test_array_first.out           |  3 ++
 .../array_functions/test_array_last.out            |  3 ++
 .../array_functions/test_array_first.groovy        |  3 +-
 .../array_functions/test_array_last.groovy         |  3 +-
 9 files changed, 125 insertions(+), 15 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
index 3ae49261636..5c900ea40cd 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
@@ -260,7 +260,7 @@ public abstract class Expression extends 
AbstractTreeNode<Expression> implements
             DataType expected = expectedTypes.get(i);
             if (!checkInputDataTypesWithExpectType(input.getDataType(), 
expected)) {
                 errorMessages.add(String.format("argument %d requires %s type, 
however '%s' is of %s type",
-                        i + 1, expected.simpleString(), input.toSql(), 
input.getDataType().simpleString()));
+                        i + 1, expected, input.toSql(), input.getDataType()));
             }
         }
         if (!errorMessages.isEmpty()) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirst.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirst.java
index 5410de371a7..3d186812d42 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirst.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirst.java
@@ -18,24 +18,41 @@
 package org.apache.doris.nereids.trees.expressions.functions.scalar;
 
 import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import 
org.apache.doris.nereids.trees.expressions.functions.PropagateNullLiteral;
+import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
+import org.apache.doris.nereids.trees.expressions.functions.RewriteWhenAnalyze;
 import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.ArrayType;
+import org.apache.doris.nereids.types.BooleanType;
+import org.apache.doris.nereids.types.coercion.AnyDataType;
+
+import com.google.common.collect.ImmutableList;
 
 import java.util.List;
 
 /**
  * ScalarFunction 'array_first'.
  */
-public class ArrayFirst extends ElementAt
-        implements HighOrderFunction {
+public class ArrayFirst extends ScalarFunction
+        implements HighOrderFunction, PropagateNullLiteral, PropagateNullable, 
RewriteWhenAnalyze {
+
+    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+            
FunctionSignature.retArgType(0).args(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX),
+                    ArrayType.of(BooleanType.INSTANCE))
+    );
 
     /**
      * constructor with arguments.
-     * array_first(lambda, a1, ...) = element_at(array_filter(lambda, a1, 
...), 1)
      */
     public ArrayFirst(Expression arg) {
-        super(new ArrayFilter(arg), new BigIntLiteral(1));
+        super("array_first", arg instanceof Lambda ? arg.child(1).child(0) : 
arg, new ArrayMap(arg));
+        if (!(arg instanceof Lambda)) {
+            throw new AnalysisException(
+                    String.format("The 1st arg of %s must be lambda but is 
%s", getName(), arg));
+        }
     }
 
     /** constructor for withChildren and reuse signature */
@@ -44,7 +61,7 @@ public class ArrayFirst extends ElementAt
     }
 
     @Override
-    public ElementAt withChildren(List<Expression> children) {
+    public ArrayFirst withChildren(List<Expression> children) {
         return new ArrayFirst(getFunctionParams(children));
     }
 
@@ -57,4 +74,10 @@ public class ArrayFirst extends ElementAt
     public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
         return visitor.visitArrayFirst(this, context);
     }
+
+    // array_first(lambda, a1, ...) = element_at(array_filter(lambda, a1, 
...), 1)
+    @Override
+    public Expression rewriteWhenAnalyze() {
+        return new ElementAt(new ArrayFilter(getArgument(0), getArgument(1)), 
new BigIntLiteral(1));
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayLast.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayLast.java
index b9f5650156f..45ab415b6ff 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayLast.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayLast.java
@@ -18,24 +18,41 @@
 package org.apache.doris.nereids.trees.expressions.functions.scalar;
 
 import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import 
org.apache.doris.nereids.trees.expressions.functions.PropagateNullLiteral;
+import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
+import org.apache.doris.nereids.trees.expressions.functions.RewriteWhenAnalyze;
 import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.ArrayType;
+import org.apache.doris.nereids.types.BooleanType;
+import org.apache.doris.nereids.types.coercion.AnyDataType;
+
+import com.google.common.collect.ImmutableList;
 
 import java.util.List;
 
 /**
  * ScalarFunction 'array_last'.
  */
-public class ArrayLast extends ElementAt
-        implements HighOrderFunction {
+public class ArrayLast extends ScalarFunction
+        implements HighOrderFunction, PropagateNullLiteral, PropagateNullable, 
RewriteWhenAnalyze {
+
+    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
+            
FunctionSignature.retArgType(0).args(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX),
+                    ArrayType.of(BooleanType.INSTANCE))
+    );
 
     /**
      * constructor with arguments.
-     * array_last(lambda, a1, ...) = element_at(array_filter(lambda, a1, ...), 
-1)
      */
     public ArrayLast(Expression arg) {
-        super(new ArrayFilter(arg), new BigIntLiteral(-1));
+        super("array_last", arg instanceof Lambda ? arg.child(1).child(0) : 
arg, new ArrayMap(arg));
+        if (!(arg instanceof Lambda)) {
+            throw new AnalysisException(
+                    String.format("The 1st arg of %s must be lambda but is 
%s", getName(), arg));
+        }
     }
 
     /** constructor for withChildren and reuse signature */
@@ -49,7 +66,7 @@ public class ArrayLast extends ElementAt
     }
 
     @Override
-    public ElementAt withChildren(List<Expression> children) {
+    public ArrayLast withChildren(List<Expression> children) {
         return new ArrayLast(getFunctionParams(children));
     }
 
@@ -57,4 +74,10 @@ public class ArrayLast extends ElementAt
     public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
         return visitor.visitArrayLast(this, context);
     }
+
+    // array_last(lambda, a1, ...) = element_at(array_filter(lambda, a1, ...), 
-1)
+    @Override
+    public Expression rewriteWhenAnalyze() {
+        return new ElementAt(new ArrayFilter(getArgument(0), getArgument(1)), 
new BigIntLiteral(-1));
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
index 499eed6b11b..ed6a8bcf939 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/ScalarFunctionVisitor.java
@@ -713,7 +713,7 @@ public interface ScalarFunctionVisitor<R, C> {
     }
 
     default R visitArrayFirst(ArrayFirst arrayFirst, C context) {
-        return visitElementAt(arrayFirst, context);
+        return visitScalarFunction(arrayFirst, context);
     }
 
     default R visitArrayFirstIndex(ArrayFirstIndex arrayFirstIndex, C context) 
{
@@ -729,7 +729,7 @@ public interface ScalarFunctionVisitor<R, C> {
     }
 
     default R visitArrayLast(ArrayLast arrayLast, C context) {
-        return visitElementAt(arrayLast, context);
+        return visitScalarFunction(arrayLast, context);
     }
 
     default R visitArrayLastIndex(ArrayLastIndex arrayLastIndex, C context) {
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirstLastTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirstLastTest.java
new file mode 100644
index 00000000000..1d7b285fdd6
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayFirstLastTest.java
@@ -0,0 +1,56 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.scalar;
+
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
+import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
+import org.apache.doris.nereids.types.ArrayType;
+import org.apache.doris.nereids.types.BooleanType;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+public class ArrayFirstLastTest extends ExpressionRewriteTestHelper {
+
+    @Test
+    public void testArrayFirstLambdaResultCanCastToBoolean() {
+        assertLambdaResultCastToBoolean("array_first(x -> x, [0, 1])", 1);
+    }
+
+    @Test
+    public void testArrayLastLambdaResultCanCastToBoolean() {
+        assertLambdaResultCastToBoolean("array_last(x -> x, [0, 1])", -1);
+    }
+
+    private void assertLambdaResultCastToBoolean(String sql, long 
expectedIndex) {
+        Expression analyzed = typeCoercion(PARSER.parseExpression(sql));
+        Assertions.assertTrue(analyzed instanceof ElementAt);
+
+        ElementAt elementAt = (ElementAt) analyzed;
+        Assertions.assertTrue(elementAt.left() instanceof ArrayFilter);
+        Assertions.assertEquals(expectedIndex, ((BigIntLiteral) 
elementAt.right()).getValue());
+
+        ArrayFilter arrayFilter = (ArrayFilter) elementAt.left();
+        Expression filterResult = arrayFilter.child(1);
+        Assertions.assertTrue(filterResult instanceof Cast);
+        Assertions.assertEquals(ArrayType.of(BooleanType.INSTANCE), 
filterResult.getDataType());
+        Assertions.assertTrue(((Cast) filterResult).child() instanceof 
ArrayMap);
+    }
+}
diff --git 
a/regression-test/data/query_p0/sql_functions/array_functions/test_array_first.out
 
b/regression-test/data/query_p0/sql_functions/array_functions/test_array_first.out
index bc4f80a9576..24a5dc6f88e 100644
--- 
a/regression-test/data/query_p0/sql_functions/array_functions/test_array_first.out
+++ 
b/regression-test/data/query_p0/sql_functions/array_functions/test_array_first.out
@@ -14,6 +14,9 @@ b
 -- !select_05 --
 10.2
 
+-- !select_lambda_result_cast --
+1
+
 -- !select_06 --
 0      [2]     ["123", "124", "125"]
 1      [1, 2, 3, 4, 5] ["234", "124", "125"]
diff --git 
a/regression-test/data/query_p0/sql_functions/array_functions/test_array_last.out
 
b/regression-test/data/query_p0/sql_functions/array_functions/test_array_last.out
index b00916ba420..45c56736f0f 100644
--- 
a/regression-test/data/query_p0/sql_functions/array_functions/test_array_last.out
+++ 
b/regression-test/data/query_p0/sql_functions/array_functions/test_array_last.out
@@ -14,6 +14,9 @@ c
 -- !select_05 --
 5.3
 
+-- !select_lambda_result_cast --
+2
+
 -- !select_06 --
 0      [2]     ["123", "124", "125"]
 1      [1, 2, 3, 4, 5] ["234", "124", "125"]
diff --git 
a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_first.groovy
 
b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_first.groovy
index 2b4fc078602..ee11b7ab108 100644
--- 
a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_first.groovy
+++ 
b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_first.groovy
@@ -45,9 +45,10 @@ suite("test_array_first") {
         qt_select_03 " select array_first(x -> x>=5,[1,2,3,4,5]);"
         qt_select_04 " select array_first(x -> x > 'abc', ['a','b','c']);"
         qt_select_05 " select array_first(x -> x > 5.2 , [10.2, 5.3, 4]);"
+        qt_select_lambda_result_cast " select array_first(x -> x, [0, 1, 2]);"
 
         qt_select_06  "select * from ${tableName} order by id;"
         
         qt_select_07 " select array_first(x->x>3,c_array1), array_first(x-> 
x>'124',c_array2) from test_array_first order by id;"        
         sql "DROP TABLE IF EXISTS ${tableName}"
-}
\ No newline at end of file
+}
diff --git 
a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_last.groovy
 
b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_last.groovy
index fe5fdec9ffc..82df24c8eab 100644
--- 
a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_last.groovy
+++ 
b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_last.groovy
@@ -45,9 +45,10 @@ suite("test_array_last") {
         qt_select_03 " select array_last(x -> x>=5,[1,2,3,4,5]);"
         qt_select_04 " select array_last(x -> x > 'abc', ['a','b','c']);"
         qt_select_05 " select array_last(x -> x > 5.2 , [10.2, 5.3, 4]);"
+        qt_select_lambda_result_cast " select array_last(x -> x, [0, 1, 2]);"
 
         qt_select_06  "select * from ${tableName} order by id;"
         
         qt_select_07 " select array_last(x->x>3,c_array1), array_last(x-> 
x>'124',c_array2) from test_array_last order by id;"        
         sql "DROP TABLE IF EXISTS ${tableName}"
-}
\ No newline at end of file
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to