This is an automated email from the ASF dual-hosted git repository.

godfrey pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 263555c9adcca0abe194e9a6c1d85ec591c304e4
Author: fengli <ldliu...@163.com>
AuthorDate: Mon Feb 27 17:15:47 2023 +0800

    [FLINK-31239][hive] Fix native sum function not getting the correct value when the argument type is string
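    
    For illustration only (not part of the patch): a minimal SQL sketch of the
    behavior the new ITCase covers, assuming the Hive dialect with native
    aggregate functions enabled (HiveOptions.TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED);
    the table and column names below are hypothetical.
    
        -- hypothetical table: s never parses as a number, d is always NULL
        CREATE TABLE t (s STRING, d DECIMAL(10, 5));
        INSERT INTO t VALUES ('b', NULL), ('b', NULL);
        SELECT SUM(s) FROM t;  -- expected 0.0: unparsable strings are summed as 0
        SELECT SUM(d) FROM t;  -- expected NULL: no non-null value was accumulated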
    
    This closes #22031
---
 .../table/functions/hive/HiveSumAggFunction.java   | 55 +++++++++++++-----
 .../connectors/hive/HiveDialectAggITCase.java      | 66 +++++++++++++++++-----
 .../resources/explain/testSumAggFunctionPlan.out   |  8 +--
 .../planner/expressions/ExpressionBuilder.java     | 10 ++++
 4 files changed, 106 insertions(+), 33 deletions(-)

diff --git a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSumAggFunction.java b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSumAggFunction.java
index 610a1d93239..48470f997df 100644
--- a/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSumAggFunction.java
+++ b/flink-connectors/flink-connector-hive/src/main/java/org/apache/flink/table/functions/hive/HiveSumAggFunction.java
@@ -22,13 +22,20 @@ import org.apache.flink.table.api.DataTypes;
 import org.apache.flink.table.api.TableException;
 import org.apache.flink.table.expressions.Expression;
 import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
+import org.apache.flink.table.expressions.ValueLiteralExpression;
 import org.apache.flink.table.types.DataType;
 import org.apache.flink.table.types.inference.CallContext;
 
+import java.math.BigDecimal;
+
 import static org.apache.flink.connectors.hive.HiveOptions.TABLE_EXEC_HIVE_NATIVE_AGG_FUNCTION_ENABLED;
 import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef;
+import static org.apache.flink.table.expressions.ApiExpressionUtils.valueLiteral;
+import static org.apache.flink.table.planner.expressions.ExpressionBuilder.and;
+import static org.apache.flink.table.planner.expressions.ExpressionBuilder.coalesce;
 import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse;
 import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull;
+import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isTrue;
 import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf;
 import static org.apache.flink.table.planner.expressions.ExpressionBuilder.tryCast;
 import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral;
@@ -40,7 +47,10 @@ import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getSc
 public class HiveSumAggFunction extends HiveDeclarativeAggregateFunction {
 
     private final UnresolvedReferenceExpression sum = unresolvedRef("sum");
+    private final UnresolvedReferenceExpression isEmpty = unresolvedRef("isEmpty");
+
     private DataType resultType;
+    private ValueLiteralExpression zero;
 
     @Override
     public int operandCount() {
@@ -49,12 +59,12 @@ public class HiveSumAggFunction extends HiveDeclarativeAggregateFunction {
 
     @Override
     public UnresolvedReferenceExpression[] aggBufferAttributes() {
-        return new UnresolvedReferenceExpression[] {sum};
+        return new UnresolvedReferenceExpression[] {sum, isEmpty};
     }
 
     @Override
     public DataType[] getAggBufferTypes() {
-        return new DataType[] {getResultType()};
+        return new DataType[] {getResultType(), DataTypes.BOOLEAN()};
     }
 
     @Override
@@ -64,20 +74,19 @@ public class HiveSumAggFunction extends HiveDeclarativeAggregateFunction {
 
     @Override
     public Expression[] initialValuesExpressions() {
-        return new Expression[] {/* sum = */ nullOf(getResultType())};
+        return new Expression[] {/* sum = */ nullOf(getResultType()), valueLiteral(true)};
     }
 
     @Override
     public Expression[] accumulateExpressions() {
         Expression tryCastOperand = tryCast(operand(0), typeLiteral(getResultType()));
+        Expression coalesceSum = coalesce(sum, zero);
         return new Expression[] {
             /* sum = */ ifThenElse(
                     isNull(tryCastOperand),
-                    sum,
-                    ifThenElse(
-                            isNull(sum),
-                            tryCastOperand,
-                            adjustedPlus(getResultType(), sum, tryCastOperand)))
+                    coalesceSum,
+                    adjustedPlus(getResultType(), coalesceSum, tryCastOperand)),
+            and(isEmpty, isNull(operand(0)))
         };
     }
 
@@ -88,20 +97,19 @@ public class HiveSumAggFunction extends HiveDeclarativeAggregateFunction {
 
     @Override
     public Expression[] mergeExpressions() {
+        Expression coalesceSum = coalesce(sum, zero);
         return new Expression[] {
             /* sum = */ ifThenElse(
                     isNull(mergeOperand(sum)),
-                    sum,
-                    ifThenElse(
-                            isNull(sum),
-                            mergeOperand(sum),
-                            adjustedPlus(getResultType(), sum, mergeOperand(sum))))
+                    coalesceSum,
+                    adjustedPlus(getResultType(), coalesceSum, mergeOperand(sum))),
+            and(isEmpty, mergeOperand(isEmpty))
         };
     }
 
     @Override
     public Expression getValueExpression() {
-        return sum;
+        return ifThenElse(isTrue(isEmpty), nullOf(getResultType()), sum);
     }
 
     @Override
@@ -109,6 +117,7 @@ public class HiveSumAggFunction extends HiveDeclarativeAggregateFunction {
         if (resultType == null) {
             checkArgumentNum(callContext.getArgumentDataTypes());
             resultType = initResultType(callContext.getArgumentDataTypes().get(0));
+            zero = defaultValue(resultType);
         }
     }
 
@@ -141,4 +150,22 @@ public class HiveSumAggFunction extends HiveDeclarativeAggregateFunction {
                                 argsType));
         }
     }
+
+    private ValueLiteralExpression defaultValue(DataType dataType) {
+        switch (dataType.getLogicalType().getTypeRoot()) {
+            case BIGINT:
+                return valueLiteral(0L);
+            case DOUBLE:
+                return valueLiteral(0.0);
+            case DECIMAL:
+                return valueLiteral(
+                        BigDecimal.valueOf(0, getScale(dataType.getLogicalType())),
+                        dataType.notNull());
+            default:
+                throw new TableException(
+                        String.format(
+                                "Unsupported type %s is passed when initialize the default value.",
+                                dataType));
+        }
+    }
 }
diff --git a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java
index 11f80f04dcc..896d348cf5c 100644
--- a/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java
+++ b/flink-connectors/flink-connector-hive/src/test/java/org/apache/flink/connectors/hive/HiveDialectAggITCase.java
@@ -76,12 +76,12 @@ public class HiveDialectAggITCase {
     @Test
     public void testSimpleSumAggFunction() throws Exception {
         tableEnv.executeSql(
-                "create table test_sum(x string, y string, z int, d decimal(10,5), e float, f double, ts timestamp)");
+                "create table test_sum(x string, y string, g string, z int, d decimal(10,5), e float, f double, ts timestamp)");
         tableEnv.executeSql(
-                        "insert into test_sum values (NULL, '2', 1, 1.11, 1.2, 1.3, '2021-08-04 16:26:33.4'), "
-                                + "(NULL, 'b', 2, 2.22, 2.3, 2.4, '2021-08-07 16:26:33.4'), "
-                                + "(NULL, '4', 3, 3.33, 3.5, 3.6, '2021-08-08 16:26:33.4'), "
-                                + "(NULL, NULL, 4, 4.45, 4.7, 4.8, '2021-08-09 16:26:33.4')")
+                        "insert into test_sum values (NULL, '2', 'b', 1, 1.11, 1.2, 1.3, '2021-08-04 16:26:33.4'), "
+                                + "(NULL, 'b', 'b', 2, 2.22, 2.3, 2.4, '2021-08-07 16:26:33.4'), "
+                                + "(NULL, '4', 'b', 3, 3.33, 3.5, 3.6, '2021-08-08 16:26:33.4'), "
+                                + "(NULL, NULL, 'b', 4, 4.45, 4.7, 4.8, '2021-08-09 16:26:33.4')")
                 .await();
 
         // test sum with all elements are null
@@ -96,37 +96,43 @@ public class HiveDialectAggITCase {
                         tableEnv.executeSql("select sum(y) from 
test_sum").collect());
         assertThat(result2.toString()).isEqualTo("[+I[6.0]]");
 
-        // test decimal type
+        // test sum string type with all elements can't convert to double, result type is double
         List<Row> result3 =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql("select sum(g) from test_sum").collect());
+        assertThat(result3.toString()).isEqualTo("[+I[0.0]]");
+
+        // test decimal type
+        List<Row> result4 =
                 CollectionUtil.iteratorToList(
                         tableEnv.executeSql("select sum(d) from 
test_sum").collect());
-        assertThat(result3.toString()).isEqualTo("[+I[11.11000]]");
+        assertThat(result4.toString()).isEqualTo("[+I[11.11000]]");
 
         // test sum int, result type is bigint
-        List<Row> result4 =
+        List<Row> result5 =
                 CollectionUtil.iteratorToList(
                         tableEnv.executeSql("select sum(z) from 
test_sum").collect());
-        assertThat(result4.toString()).isEqualTo("[+I[10]]");
+        assertThat(result5.toString()).isEqualTo("[+I[10]]");
 
         // test float type
-        List<Row> result5 =
+        List<Row> result6 =
                 CollectionUtil.iteratorToList(
                         tableEnv.executeSql("select sum(e) from 
test_sum").collect());
-        float actualFloatValue = ((Double) 
result5.get(0).getField(0)).floatValue();
+        float actualFloatValue = ((Double) 
result6.get(0).getField(0)).floatValue();
         assertThat(actualFloatValue).isEqualTo(11.7f);
 
         // test double type
-        List<Row> result6 =
+        List<Row> result7 =
                 CollectionUtil.iteratorToList(
                         tableEnv.executeSql("select sum(f) from 
test_sum").collect());
-        actualFloatValue = ((Double) result6.get(0).getField(0)).floatValue();
+        actualFloatValue = ((Double) result7.get(0).getField(0)).floatValue();
         assertThat(actualFloatValue).isEqualTo(12.1f);
 
         // test sum string&int type simultaneously
-        List<Row> result7 =
+        List<Row> result8 =
                 CollectionUtil.iteratorToList(
                         tableEnv.executeSql("select sum(y), sum(z) from 
test_sum").collect());
-        assertThat(result7.toString()).isEqualTo("[+I[6.0, 10]]");
+        assertThat(result8.toString()).isEqualTo("[+I[6.0, 10]]");
 
         // test unsupported timestamp type
         String expectedMessage =
@@ -137,6 +143,36 @@ public class HiveDialectAggITCase {
         tableEnv.executeSql("drop table test_sum");
     }
 
+    @Test
+    public void testSumDecimal() throws Exception {
+        tableEnv.executeSql(
+                "create table test_sum_dec(a int, x string, z decimal(10, 5), g decimal(18, 5))");
+        tableEnv.executeSql(
+                        "insert into test_sum_dec values (1, 'b', null, null), "
+                                + "(1, 'b', 1.2, null), "
+                                + "(2, 'b', null, null), "
+                                + "(2, 'b', null, null),"
+                                + "(4, '1', null, null),"
+                                + "(4, 'b', null, null)")
+                .await();
+
+        List<Row> result =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql("select a, sum(z), sum(g) from test_sum_dec group by a")
+                                .collect());
+        assertThat(result.toString())
+                .isEqualTo("[+I[1, 1.20000, null], +I[2, null, null], +I[4, null, null]]");
+
+        List<Row> result2 =
+                CollectionUtil.iteratorToList(
+                        tableEnv.executeSql(
+                                        "select a, sum(cast(x as decimal(10, 3))) from test_sum_dec group by a")
+                                .collect());
+        assertThat(result2.toString()).isEqualTo("[+I[1, 0.000], +I[2, 0.000], +I[4, 1.000]]");
+
+        tableEnv.executeSql("drop table test_sum_dec");
+    }
+
     @Test
     public void testSumAggWithGroupKey() throws Exception {
         tableEnv.executeSql(
diff --git a/flink-connectors/flink-connector-hive/src/test/resources/explain/testSumAggFunctionPlan.out b/flink-connectors/flink-connector-hive/src/test/resources/explain/testSumAggFunctionPlan.out
index 702e09fb3f2..95be2ba4e72 100644
--- a/flink-connectors/flink-connector-hive/src/test/resources/explain/testSumAggFunctionPlan.out
+++ b/flink-connectors/flink-connector-hive/src/test/resources/explain/testSumAggFunctionPlan.out
@@ -5,13 +5,13 @@ LogicalProject(x=[$0], _o__c1=[$1])
       +- LogicalTableScan(table=[[test-catalog, default, foo]])
 
 == Optimized Physical Plan ==
-HashAggregate(isMerge=[true], groupBy=[x], select=[x, Final_sum(sum$0) AS $f1])
+HashAggregate(isMerge=[true], groupBy=[x], select=[x, Final_sum(sum$0, isEmpty$1) AS $f1])
 +- Exchange(distribution=[hash[x]])
-   +- LocalHashAggregate(groupBy=[x], select=[x, Partial_sum(y) AS sum$0])
+   +- LocalHashAggregate(groupBy=[x], select=[x, Partial_sum(y) AS (sum$0, isEmpty$1)])
       +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y])
 
 == Optimized Execution Plan ==
-HashAggregate(isMerge=[true], groupBy=[x], select=[x, Final_sum(sum$0) AS $f1])
+HashAggregate(isMerge=[true], groupBy=[x], select=[x, Final_sum(sum$0, isEmpty$1) AS $f1])
 +- Exchange(distribution=[hash[x]])
-   +- LocalHashAggregate(groupBy=[x], select=[x, Partial_sum(y) AS sum$0])
+   +- LocalHashAggregate(groupBy=[x], select=[x, Partial_sum(y) AS (sum$0, isEmpty$1)])
       +- TableSourceScan(table=[[test-catalog, default, foo]], fields=[x, y])
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java
index 77cfbaa4baa..bb98614d561 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java
@@ -33,6 +33,7 @@ import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AGG_DE
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AGG_DECIMAL_PLUS;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AND;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.CAST;
+import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.COALESCE;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.CONCAT;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.DIVIDE;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.EQUALS;
@@ -40,6 +41,7 @@ import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.GREATE
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.HIVE_AGG_DECIMAL_PLUS;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.IF;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.IS_NULL;
+import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.IS_TRUE;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.LESS_THAN;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.LESS_THAN_OR_EQUAL;
 import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.MINUS;
@@ -97,6 +99,14 @@ public class ExpressionBuilder {
         return call(IS_NULL, input);
     }
 
+    public static UnresolvedCallExpression isTrue(Expression input) {
+        return call(IS_TRUE, input);
+    }
+
+    public static UnresolvedCallExpression coalesce(Expression... args) {
+        return call(COALESCE, args);
+    }
+
     public static UnresolvedCallExpression ifThenElse(
             Expression condition, Expression ifTrue, Expression ifFalse) {
         return call(IF, condition, ifTrue, ifFalse);
