This is an automated email from the ASF dual-hosted git repository. twalthr pushed a commit to branch release-1.14 in repository https://gitbox.apache.org/repos/asf/flink.git
commit bafb0b4c2377d6d502ed9dba8853631ebf16cfb7 Author: Marios Trivyzas <mat...@gmail.com> AuthorDate: Thu Nov 4 13:03:26 2021 +0100 [FLINK-24691][table-planner] Fix decimal precision for SUM Since SUM internally uses the `plus()` operator to implement the sum aggregation, the decimal return type calculated by `LogicalTypeMerging#findSumAggType()` gets overridden by the calculation for the `plus()` operator done by `LogicalTypeMerging#findAdditionDecimalType()`. To prevent this, add a special `aggDecimalPlus()` operator to be used exclusively by aggregate functions, to avoid overriding their calculated precision. This closes #17634. --- .../functions/BuiltInFunctionDefinitions.java | 19 +++++++ .../strategies/AggDecimalPlusTypeStrategy.java | 66 ++++++++++++++++++++++ .../strategies/SpecificTypeStrategies.java | 3 + .../types/logical/utils/LogicalTypeMerging.java | 14 +++-- .../logical/utils/LogicalTypeMergingTest.java | 10 +++- .../planner/expressions/ExpressionBuilder.java | 9 +++ .../functions/aggfunctions/SumAggFunction.java | 41 +++++++++----- .../table/planner/codegen/ExprCodeGenerator.scala | 19 ++++--- .../runtime/stream/sql/AggregateITCase.scala | 28 +++++++++ .../runtime/stream/table/AggregateITCase.scala | 23 ++++++++ 10 files changed, 204 insertions(+), 28 deletions(-) diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java index 022b7b1..d1695f6 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java @@ -29,6 +29,7 @@ import org.apache.flink.table.types.inference.InputTypeStrategies; import org.apache.flink.table.types.inference.TypeStrategies; import 
org.apache.flink.table.types.inference.strategies.SpecificInputTypeStrategies; import org.apache.flink.table.types.inference.strategies.SpecificTypeStrategies; +import org.apache.flink.table.types.logical.LogicalType; import org.apache.flink.table.types.logical.LogicalTypeFamily; import org.apache.flink.table.types.logical.LogicalTypeRoot; import org.apache.flink.table.types.logical.StructuredType.StructuredComparison; @@ -743,6 +744,24 @@ public final class BuiltInFunctionDefinitions { explicit(DataTypes.STRING())))) .build(); + /** + * Special "+" operator used internally by {@code SumAggFunction} to implement SUM aggregation + * on a Decimal type. Uses the {@link LogicalTypeMerging#findSumAggType(LogicalType)} to avoid + * the normal {@link #PLUS} override the special calculation for precision and scale needed by + * SUM. + */ + public static final BuiltInFunctionDefinition AGG_DECIMAL_PLUS = + BuiltInFunctionDefinition.newBuilder() + .name("AGG_DECIMAL_PLUS") + .kind(SCALAR) + .inputTypeStrategy( + sequence( + logical(LogicalTypeRoot.DECIMAL), + logical(LogicalTypeRoot.DECIMAL))) + .outputTypeStrategy(SpecificTypeStrategies.AGG_DECIMAL_PLUS) + .runtimeProvided() + .build(); + /** Combines numeric subtraction and "datetime - interval" arithmetic. */ public static final BuiltInFunctionDefinition MINUS = BuiltInFunctionDefinition.newBuilder() diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/AggDecimalPlusTypeStrategy.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/AggDecimalPlusTypeStrategy.java new file mode 100644 index 0000000..23be242 --- /dev/null +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/AggDecimalPlusTypeStrategy.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.types.inference.strategies; + +import org.apache.flink.annotation.Internal; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.inference.CallContext; +import org.apache.flink.table.types.inference.TypeStrategy; +import org.apache.flink.table.types.logical.DecimalType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.utils.LogicalTypeChecks; +import org.apache.flink.table.types.logical.utils.LogicalTypeMerging; +import org.apache.flink.table.types.utils.TypeConversions; +import org.apache.flink.util.Preconditions; + +import java.util.List; +import java.util.Optional; + +/** + * Type strategy that returns the result decimal addition used, internally by {@code SumAggFunction} + * to implement SUM aggregation on a Decimal type. Uses the {@link + * LogicalTypeMerging#findSumAggType(LogicalType)} and prevents the {@link DecimalPlusTypeStrategy} + * from overriding the special calculation for precision and scale needed by SUM. 
+ */ +@Internal +class AggDecimalPlusTypeStrategy implements TypeStrategy { + + private static final String ERROR_MSG = + "Both args of " + + AggDecimalPlusTypeStrategy.class.getSimpleName() + + " should be of type[" + + DecimalType.class.getSimpleName() + + "]"; + + @Override + public Optional<DataType> inferType(CallContext callContext) { + final List<DataType> argumentDataTypes = callContext.getArgumentDataTypes(); + final LogicalType addend1 = argumentDataTypes.get(0).getLogicalType(); + final LogicalType addend2 = argumentDataTypes.get(1).getLogicalType(); + + Preconditions.checkArgument( + LogicalTypeChecks.hasRoot(addend1, LogicalTypeRoot.DECIMAL), ERROR_MSG); + Preconditions.checkArgument( + LogicalTypeChecks.hasRoot(addend2, LogicalTypeRoot.DECIMAL), ERROR_MSG); + + return Optional.of( + TypeConversions.fromLogicalToDataType(LogicalTypeMerging.findSumAggType(addend2))); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java index 69b1e5f..c0e4dee 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/inference/strategies/SpecificTypeStrategies.java @@ -61,6 +61,9 @@ public final class SpecificTypeStrategies { /** See {@link DecimalPlusTypeStrategy}. */ public static final TypeStrategy DECIMAL_PLUS = new DecimalPlusTypeStrategy(); + /** See {@link AggDecimalPlusTypeStrategy}. */ + public static final TypeStrategy AGG_DECIMAL_PLUS = new AggDecimalPlusTypeStrategy(); + /** See {@link DecimalScale0TypeStrategy}. 
*/ public static final TypeStrategy DECIMAL_SCALE_0 = new DecimalScale0TypeStrategy(); diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java index 8cf5d96..e7aab32 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/logical/utils/LogicalTypeMerging.java @@ -304,15 +304,19 @@ public final class LogicalTypeMerging { * * <p>https://docs.microsoft.com/en-us/sql/t-sql/data-types/precision-scale-and-length-transact-sql * - * <p>The rules although inspired by SQL Server they are not followed 100%, instead the approach - * of Spark/Hive is followed for adjusting the precision. + * <p>The rules (although inspired by SQL Server) are not followed 100%, instead the approach of + * Spark/Hive is followed for adjusting the precision. * * <p>http://www.openkb.info/2021/05/understand-decimal-precision-and-scale.html * - * <p>For (38, 8) + (32, 8) -> (39, 8) (If precision is infinite) // integral part: 31 + * <p>For (38, 8) + (32, 8) -> (39, 8) (The rules for addition, initially calculate a decimal + * type, assuming its precision is infinite) results in a decimal with integral part of 31 + * digits. * - * <p>The rounding for SQL Server would be: (39, 8) -> (38, 8) // integral part: 30, but instead - * we follow the Hive/Spark approach which gives: (39, 8) -> (38, 7) // integral part: 31 + * <p>This method is called subsequently to adjust the resulting decimal since the maximum + * allowed precision is 38 (so far a precision of 39 is calculated in the first step). 
So, the + * rounding for SQL Server would be: (39, 8) -> (38, 8) // integral part: 30, but instead we + * follow the Hive/Spark approach which gives: (39, 8) -> (38, 7) // integral part: 31 */ private static DecimalType adjustPrecisionScale(int precision, int scale) { if (precision <= DecimalType.MAX_PRECISION) { diff --git a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java index 22c4bf8..8217130 100644 --- a/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java +++ b/flink-table/flink-table-common/src/test/java/org/apache/flink/table/types/logical/utils/LogicalTypeMergingTest.java @@ -19,6 +19,7 @@ package org.apache.flink.table.types.logical.utils; import org.apache.flink.table.types.logical.DecimalType; +import org.apache.flink.table.types.logical.LogicalType; import org.junit.Test; @@ -27,7 +28,14 @@ import java.util.List; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; -/** Tests for {@link LogicalTypeMerging#findCommonType(List)}. */ +/** + * Tests for {@link LogicalTypeMerging} for finding the result decimal type for the various + * operations, e.g.: {@link LogicalTypeMerging#findSumAggType(LogicalType)}, {@link + * LogicalTypeMerging#findAdditionDecimalType(int, int, int, int)}, etc. 
+ * + * <p>For {@link LogicalTypeMerging#findCommonType(List)} tests please check {@link + * org.apache.flink.table.types.LogicalCommonTypeTest} + */ public class LogicalTypeMergingTest { @Test diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java index 9effb05..1e58d93 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java @@ -18,6 +18,7 @@ package org.apache.flink.table.planner.expressions; +import org.apache.flink.annotation.Internal; import org.apache.flink.table.expressions.ApiExpressionUtils; import org.apache.flink.table.expressions.Expression; import org.apache.flink.table.expressions.TypeLiteralExpression; @@ -28,6 +29,7 @@ import org.apache.flink.table.types.DataType; import java.util.List; +import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AGG_DECIMAL_PLUS; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.AND; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.CAST; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.CONCAT; @@ -100,6 +102,13 @@ public class ExpressionBuilder { return call(PLUS, input1, input2); } + // Used only for implementing the SumAggFunction to avoid overriding decimal precision/scale + // calculation for sum with the rules applied for the normal plus + @Internal + public static UnresolvedCallExpression aggDecimalPlus(Expression input1, Expression input2) { + return call(AGG_DECIMAL_PLUS, input1, input2); + } + public static UnresolvedCallExpression minus(Expression input1, Expression input2) { return call(MINUS, input1, input2); } diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java index 4e93800..ba3ebbb 100644 --- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java +++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SumAggFunction.java @@ -28,16 +28,16 @@ import org.apache.flink.table.types.logical.DecimalType; import org.apache.flink.table.types.logical.utils.LogicalTypeMerging; import static org.apache.flink.table.expressions.ApiExpressionUtils.unresolvedRef; -import static org.apache.flink.table.planner.expressions.ExpressionBuilder.cast; +import static org.apache.flink.table.planner.expressions.ExpressionBuilder.aggDecimalPlus; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.ifThenElse; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.isNull; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus; -import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral; /** built-in sum aggregate function. 
*/ public abstract class SumAggFunction extends DeclarativeAggregateFunction { - private UnresolvedReferenceExpression sum = unresolvedRef("sum"); + + protected UnresolvedReferenceExpression sum = unresolvedRef("sum"); @Override public int operandCount() { @@ -62,11 +62,13 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction { @Override public Expression[] accumulateExpressions() { return new Expression[] { - /* sum = */ adjustSumType( + /* sum = */ ifThenElse( + isNull(operand(0)), + sum, ifThenElse( isNull(operand(0)), sum, - ifThenElse(isNull(sum), operand(0), plus(sum, operand(0))))) + ifThenElse(isNull(sum), operand(0), doPlus(sum, operand(0))))) }; } @@ -79,17 +81,16 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction { @Override public Expression[] mergeExpressions() { return new Expression[] { - /* sum = */ adjustSumType( - ifThenElse( - isNull(mergeOperand(sum)), - sum, - ifThenElse( - isNull(sum), mergeOperand(sum), plus(sum, mergeOperand(sum))))) + /* sum = */ ifThenElse( + isNull(mergeOperand(sum)), + sum, + ifThenElse(isNull(sum), mergeOperand(sum), doPlus(sum, mergeOperand(sum)))) }; } - private UnresolvedCallExpression adjustSumType(UnresolvedCallExpression sumExpr) { - return cast(sumExpr, typeLiteral(getResultType())); + protected UnresolvedCallExpression doPlus( + UnresolvedReferenceExpression arg1, UnresolvedReferenceExpression arg2) { + return plus(arg1, arg2); } @Override @@ -149,6 +150,7 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction { /** Built-in Decimal Sum aggregate function. 
*/ public static class DecimalSumAggFunction extends SumAggFunction { private DecimalType decimalType; + private DataType returnType; public DecimalSumAggFunction(DecimalType decimalType) { this.decimalType = decimalType; @@ -156,8 +158,17 @@ public abstract class SumAggFunction extends DeclarativeAggregateFunction { @Override public DataType getResultType() { - DecimalType sumType = (DecimalType) LogicalTypeMerging.findSumAggType(decimalType); - return DataTypes.DECIMAL(sumType.getPrecision(), sumType.getScale()); + if (returnType == null) { + DecimalType sumType = (DecimalType) LogicalTypeMerging.findSumAggType(decimalType); + returnType = DataTypes.DECIMAL(sumType.getPrecision(), sumType.getScale()); + } + return returnType; + } + + @Override + protected UnresolvedCallExpression doPlus( + UnresolvedReferenceExpression arg1, UnresolvedReferenceExpression arg2) { + return aggDecimalPlus(arg1, arg2); } } } diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala index 4cea7d2..32cfe0f 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala @@ -806,21 +806,26 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean) case bsf: BridgingSqlFunction => bsf.getDefinition match { - case functionDefinition : FunctionDefinition - if functionDefinition eq BuiltInFunctionDefinitions.CURRENT_WATERMARK => + case BuiltInFunctionDefinitions.CURRENT_WATERMARK => generateWatermark(ctx, contextTerm, resultType) - case functionDefinition : FunctionDefinition - if functionDefinition eq BuiltInFunctionDefinitions.GREATEST => + + case BuiltInFunctionDefinitions.GREATEST => operands.foreach { operand => 
requireComparable(operand) } generateGreatestLeast(resultType, operands) - case functionDefinition : FunctionDefinition - if functionDefinition eq BuiltInFunctionDefinitions.LEAST => + + case BuiltInFunctionDefinitions.LEAST => operands.foreach { operand => requireComparable(operand) } - generateGreatestLeast(resultType, operands, false) + generateGreatestLeast(resultType, operands, greatest = false) + + case BuiltInFunctionDefinitions.AGG_DECIMAL_PLUS => + val left = operands.head + val right = operands(1) + generateBinaryArithmeticOperator(ctx, "+", resultType, left, right) + case _ => new BridgingSqlFunctionCallGen(call).generate(ctx, operands, resultType) } diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala index 8db0b97..64bec0b 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala @@ -397,6 +397,34 @@ class AggregateITCase( } @Test + def testPrecisionForSumAggregationOnDecimal(): Unit = { + var t = tEnv.sqlQuery( + "select cast(sum(1.03520274) as DECIMAL(32, 8)), " + + "cast(sum(12345.035202748654) AS DECIMAL(30, 20)), " + + "cast(sum(12.345678901234567) AS DECIMAL(25, 22))") + var sink = new TestingRetractSink + t.toRetractStream[Row].addSink(sink).setParallelism(1) + env.execute() + var expected = List("1.03520274,12345.03520274865400000000,12.3456789012345670000000") + assertEquals(expected, sink.getRetractResults) + + val data = new mutable.MutableList[(Double, Int)] + data .+= ((1.11111111, 1)) + data .+= ((1.11111111, 2)) + env.setParallelism(1) + + t = failingDataSource(data).toTable(tEnv, 'a, 'b) + tEnv.registerTable("T", t) + + t = tEnv.sqlQuery("select 
sum(cast(a as decimal(32, 8))) from T") + sink = new TestingRetractSink + t.toRetractStream[Row].addSink(sink) + env.execute() + expected = List("2.22222222") + assertEquals(expected, sink.getRetractResults) + } + + @Test def testGroupByAgg(): Unit = { val data = new mutable.MutableList[(Int, Long, String)] data.+=((1, 1L, "A")) diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala index 89454c4..0c3a99a 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/table/AggregateITCase.scala @@ -24,6 +24,7 @@ import org.apache.flink.api.scala._ import org.apache.flink.table.api._ import org.apache.flink.table.api.bridge.scala._ import org.apache.flink.table.api.internal.TableEnvironmentInternal +import org.apache.flink.table.api.DataTypes.DECIMAL import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.{CountDistinct, DataViewTestAgg, WeightedAvg} import org.apache.flink.table.planner.runtime.utils.StreamingWithStateTestBase.StateBackendMode import org.apache.flink.table.planner.runtime.utils.TestData._ @@ -395,4 +396,26 @@ class AggregateITCase(mode: StateBackendMode) extends StreamingWithStateTestBase val expected = mutable.MutableList("1,1", "2,3", "3,6", "4,10", "5,15", "6,21") assertEquals(expected.sorted, sink.getRetractResults.sorted) } + + @Test + def testPrecisionForSumAggregationOnDecimal(): Unit = { + val data = new mutable.MutableList[(Double, Double, Double, Double)] + data.+=((1.03520274, 12345.035202748654, 12.345678901234567, 1.11111111)) + data.+=((0, 0, 0, 1.11111111)) + val t = failingDataSource(data).toTable(tEnv, 'a, 'b, 'c, 'd) + + val results = t + 
.select('a.cast(DECIMAL(32, 8)).sum as 'a, + 'b.cast(DECIMAL(30, 20)).sum as 'b, + 'c.cast(DECIMAL(25, 20)).sum as 'c, + 'd.cast(DECIMAL(32, 8)).sum as 'd) + .toRetractStream[Row] + + val sink = new TestingRetractSink + results.addSink(sink).setParallelism(1) + env.execute() + + val expected = List("1.03520274,12345.03520274865300000000,12.34567890123456700000,2.22222222") + assertEquals(expected, sink.getRetractResults) + } }