This is an automated email from the ASF dual-hosted git repository.

gabriellee pushed a commit to branch branch-1.2-lts
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-1.2-lts by this push:
     new 3314c48cb1 [Bug](DecimalV3) fix decimalv3 functions (#19801) (#20364)
3314c48cb1 is described below

commit 3314c48cb15328112d8450ac52efc0cb0db12972
Author: Gabriel <[email protected]>
AuthorDate: Fri Jun 2 15:06:43 2023 +0800

    [Bug](DecimalV3) fix decimalv3 functions (#19801) (#20364)
---
 .../aggregate_functions/aggregate_function_topn.h  |  9 +-
 .../main/java/org/apache/doris/analysis/Expr.java  | 15 ++++
 .../apache/doris/analysis/FunctionCallExpr.java    | 95 +++++++++++++++++++---
 3 files changed, 102 insertions(+), 17 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.h b/be/src/vec/aggregate_functions/aggregate_function_topn.h
index cc5d3f6d3d..c4ccbbe1ba 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_topn.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_topn.h
@@ -39,8 +39,7 @@ namespace doris::vectorized {
 // space-saving algorithm
 template <typename T>
 struct AggregateFunctionTopNData {
-    using ColVecType =
-            std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, 
ColumnVector<T>>;
+    using ColVecType = std::conditional_t<IsDecimalNumber<T>, 
ColumnDecimal<T>, ColumnVector<T>>;
     void set_paramenters(int input_top_num, int space_expand_rate = 50) {
         top_num = input_top_num;
         capacity = (uint64_t)top_num * space_expand_rate;
@@ -206,8 +205,7 @@ struct AggregateFunctionTopNImplIntInt {
 //for topn_array agg
 template <typename T, bool has_default_param>
 struct AggregateFunctionTopNImplArray {
-    using ColVecType =
-            std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, 
ColumnVector<T>>;
+    using ColVecType = std::conditional_t<IsDecimalNumber<T>, 
ColumnDecimal<T>, ColumnVector<T>>;
     static void add(AggregateFunctionTopNData<T>& __restrict place, const 
IColumn** columns,
                     size_t row_num) {
         if constexpr (has_default_param) {
@@ -231,8 +229,7 @@ struct AggregateFunctionTopNImplArray {
 //for topn_weighted agg
 template <typename T, bool has_default_param>
 struct AggregateFunctionTopNImplWeight {
-    using ColVecType =
-            std::conditional_t<IsDecimalNumber<T>, ColumnDecimal<Decimal128>, 
ColumnVector<T>>;
+    using ColVecType = std::conditional_t<IsDecimalNumber<T>, 
ColumnDecimal<T>, ColumnVector<T>>;
     static void add(AggregateFunctionTopNData<T>& __restrict place, const 
IColumn** columns,
                     size_t row_num) {
         if constexpr (has_default_param) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
index 4c76d1bade..abed221866 100755
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Expr.java
@@ -965,6 +965,21 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
         }
     }
 
+    // Folds the types of the non-NULL children with ScalarType.getAssignmentCompatibleType(a, b, true);
+    // stops once the result is neither DecimalV3 nor DatetimeV2, and returns INVALID if no child is non-NULL.
+    public static Type getAssignmentCompatibleType(List<Expr> children) {
+        Type assignmentCompatibleType = null; // null: no non-NULL child yet, so INVALID can mean "incompatible"
+        for (int i = 0; i < children.size() && (assignmentCompatibleType == null
+                || assignmentCompatibleType.isDecimalV3() || assignmentCompatibleType.isDatetimeV2()); i++) {
+            if (children.get(i) instanceof NullLiteral) {
+                continue;
+            }
+            assignmentCompatibleType = assignmentCompatibleType == null ? children.get(i).type
+                    : ScalarType.getAssignmentCompatibleType(assignmentCompatibleType, children.get(i).type, true);
+        }
+        return assignmentCompatibleType == null ? Type.INVALID : assignmentCompatibleType;
+    }
+
     // Convert this expr into msg (excluding children), which requires setting
     // msg.op as well as the expr-specific field.
     protected abstract void toThrift(TExprNode msg);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
index 8e3ac7a1f6..e35eec84d2 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java
@@ -140,16 +140,47 @@ public class FunctionCallExpr extends Expr {
         PRECISION_INFER_RULE.put("if", (children, returnType) -> {
             Preconditions.checkArgument(children != null && children.size() == 
3);
             if (children.get(1).getType().isDecimalV3() && 
children.get(2).getType().isDecimalV3()) {
-                return ScalarType.createDecimalV3Type(
-                        Math.max(((ScalarType) 
children.get(1).getType()).decimalPrecision(),
-                                ((ScalarType) 
children.get(2).getType()).decimalPrecision()),
-                        Math.max(((ScalarType) 
children.get(1).getType()).decimalScale(),
-                                ((ScalarType) 
children.get(2).getType()).decimalScale()));
+                return Expr.getAssignmentCompatibleType(children.subList(1, 
children.size()));
             } else if (children.get(1).getType().isDatetimeV2() && 
children.get(2).getType().isDatetimeV2()) {
-                return ((ScalarType) children.get(1).getType())
-                        .decimalScale() > ((ScalarType) 
children.get(2).getType()).decimalScale()
-                                ? children.get(1).getType()
-                                : children.get(2).getType();
+                return Expr.getAssignmentCompatibleType(children.subList(1, 
children.size()));
+            } else {
+                return returnType;
+            }
+        });
+
+        PRECISION_INFER_RULE.put("ifnull", (children, returnType) -> {
+            Preconditions.checkArgument(children != null && children.size() == 
2);
+            if (children.get(0).getType().isDecimalV3() && 
children.get(1).getType().isDecimalV3()) {
+                return Expr.getAssignmentCompatibleType(children);
+            } else if (children.get(0).getType().isDatetimeV2() && 
children.get(1).getType().isDatetimeV2()) {
+                return Expr.getAssignmentCompatibleType(children);
+            } else {
+                return returnType;
+            }
+        });
+
+        PRECISION_INFER_RULE.put("nvl", (children, returnType) -> {
+            Preconditions.checkArgument(children != null && children.size() == 
2);
+            if (children.get(0).getType().isDecimalV3() && 
children.get(1).getType().isDecimalV3()) {
+                return Expr.getAssignmentCompatibleType(children);
+            } else if (children.get(0).getType().isDatetimeV2() && 
children.get(1).getType().isDatetimeV2()) {
+                return Expr.getAssignmentCompatibleType(children);
+            } else {
+                return returnType;
+            }
+        });
+
+        PRECISION_INFER_RULE.put("coalesce", (children, returnType) -> {
+            boolean isDecimalV3 = true;
+            boolean isDateTimeV2 = true;
+
+            Type assignmentCompatibleType = 
Expr.getAssignmentCompatibleType(children);
+            for (Expr child : children) {
+                isDecimalV3 = isDecimalV3 && child.getType().isDecimalV3();
+                isDateTimeV2 = isDateTimeV2 && child.getType().isDatetimeV2();
+            }
+            if ((isDecimalV3 || isDateTimeV2) && 
assignmentCompatibleType.isValid()) {
+                return assignmentCompatibleType;
             } else {
                 return returnType;
             }
@@ -1222,21 +1253,63 @@ public class FunctionCallExpr extends Expr {
             Type[] childTypes = collectChildReturnTypes();
             Type assignmentCompatibleType = 
ScalarType.getAssignmentCompatibleType(childTypes[1], childTypes[2], true);
             if (assignmentCompatibleType.isDecimalV3()) {
-                if (childTypes[1].isDecimalV3() && !((ScalarType) 
childTypes[1]).equals(assignmentCompatibleType)) {
+                if (assignmentCompatibleType.isDecimalV3() && 
!childTypes[1].equals(assignmentCompatibleType)) {
                     uncheckedCastChild(assignmentCompatibleType, 1);
                 }
-                if (childTypes[2].isDecimalV3() && !((ScalarType) 
childTypes[2]).equals(assignmentCompatibleType)) {
+                if (assignmentCompatibleType.isDecimalV3() && 
!childTypes[2].equals(assignmentCompatibleType)) {
                     uncheckedCastChild(assignmentCompatibleType, 2);
                 }
             }
             childTypes[1] = assignmentCompatibleType;
             childTypes[2] = assignmentCompatibleType;
+
+            if (childTypes[1].isDecimalV3() && childTypes[2].isDecimalV3()) {
+                argTypes[1] = assignmentCompatibleType;
+                argTypes[2] = assignmentCompatibleType;
+            }
             fn = getBuiltinFunction(fnName.getFunction(), childTypes,
                     Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
             if (assignmentCompatibleType.isDatetimeV2()) {
                 fn.setReturnType(assignmentCompatibleType);
             }
 
+        } else if (fnName.getFunction().equalsIgnoreCase("ifnull") || 
fnName.getFunction().equalsIgnoreCase("nvl")) {
+            Type[] childTypes = collectChildReturnTypes();
+            Type assignmentCompatibleType = 
ScalarType.getAssignmentCompatibleType(childTypes[0], childTypes[1], true);
+            if (assignmentCompatibleType.isDecimalV3()) {
+                if (assignmentCompatibleType.isDecimalV3() && 
!childTypes[0].equals(assignmentCompatibleType)) {
+                    uncheckedCastChild(assignmentCompatibleType, 0);
+                }
+                if (assignmentCompatibleType.isDecimalV3() && 
!childTypes[1].equals(assignmentCompatibleType)) {
+                    uncheckedCastChild(assignmentCompatibleType, 1);
+                }
+            }
+            childTypes[0] = assignmentCompatibleType;
+            childTypes[1] = assignmentCompatibleType;
+
+            if (childTypes[1].isDecimalV3() && childTypes[0].isDecimalV3()) {
+                argTypes[1] = assignmentCompatibleType;
+                argTypes[0] = assignmentCompatibleType;
+            }
+            fn = getBuiltinFunction(fnName.getFunction(), childTypes,
+                    Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
+        } else if (fnName.getFunction().equalsIgnoreCase("coalesce") && 
children.size() > 1) {
+            Type[] childTypes = collectChildReturnTypes();
+            Type assignmentCompatibleType = childTypes[0];
+            for (int i = 1; i < childTypes.length && 
assignmentCompatibleType.isDecimalV3(); i++) {
+                assignmentCompatibleType =
+                        
ScalarType.getAssignmentCompatibleType(assignmentCompatibleType, childTypes[i], 
true);
+            }
+            if (assignmentCompatibleType.isDecimalV3()) {
+                for (int i = 0; i < childTypes.length; i++) {
+                    if (assignmentCompatibleType.isDecimalV3() && 
!childTypes[i].equals(assignmentCompatibleType)) {
+                        uncheckedCastChild(assignmentCompatibleType, i);
+                        argTypes[i] = assignmentCompatibleType;
+                    }
+                }
+            }
+            fn = getBuiltinFunction(fnName.getFunction(), childTypes,
+                    Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF);
         } else if 
(AggregateFunction.SUPPORT_ORDER_BY_AGGREGATE_FUNCTION_NAME_SET.contains(
                 fnName.getFunction().toLowerCase())) {
             // order by elements add as child like windows function. so if we 
get the


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to