This is an automated email from the ASF dual-hosted git repository.
lincoln pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 4fceb461afd [FLINK-38740][table] Introduce Welford's online algorithm for variance related functions to avoid catastrophic cancellation in naive algorithm
4fceb461afd is described below
commit 4fceb461afd075bdad8269448a2c4cb15fee9ee8
Author: dylanhz <[email protected]>
AuthorDate: Mon Dec 29 15:20:39 2025 +0800
[FLINK-38740][table] Introduce Welford's online algorithm for variance related functions to avoid catastrophic cancellation in naive algorithm
This closes #27325.
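For context: with IEEE doubles the naive rewrite E[x^2] - E[x]^2 can come out slightly negative for near-constant inputs, while Welford's single-pass update cannot. A minimal sketch of the difference, assuming plain Python floats (illustrative only, not part of this commit):

    vals = [0.27] * 3  # the failing input from FLINK-38740
    n = len(vals)
    # Naive rewrite: (SUM(x*x) - SUM(x)*SUM(x)/COUNT(x)) / COUNT(x).
    naive_var = (sum(v * v for v in vals) - sum(vals) ** 2 / n) / n
    # naive_var may be a tiny negative value, so SQRT of it
    # (i.e. STDDEV_POP) breaks even though the true variance is 0.
    # Welford: track running mean and M2 in one pass.
    count, mean, m2 = 0, 0.0, 0.0
    for v in vals:
        count += 1
        delta = v - mean
        mean += delta / count
        m2 += delta * (v - mean)
    welford_var = m2 / count  # exactly 0.0 for constant input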
---
flink-python/pyflink/table/functions.py | 72 ++++++
.../functions/BuiltInFunctionDefinitions.java | 13 +
.../python/BuiltInPythonAggregateFunction.java | 3 +-
.../rel/rules/AggregateReduceFunctionsRule.java | 143 +++++------
.../functions/sql/FlinkSqlOperatorTable.java | 8 +
.../plan/nodes/exec/utils/CommonPythonUtil.java | 4 +
.../planner/plan/utils/AggFunctionFactory.scala | 17 ++
.../BuiltInAggregateFunctionTestBase.java | 17 +-
.../planner/functions/MathAggFunctionITCase.java | 246 +++++++++++++++++++
.../logical/AggregateReduceFunctionsRuleTest.java | 131 +++++++++-
.../batch/sql/agg/AggregateReduceGroupingTest.xml | 24 +-
.../planner/plan/batch/sql/agg/GroupWindowTest.xml | 41 ++--
.../logical/AggregateReduceFunctionsRuleTest.xml | 270 ++++++++++++++++++++-
.../logical/AggregateReduceGroupingRuleTest.xml | 21 +-
.../plan/stream/sql/agg/GroupWindowTest.xml | 6 +-
.../planner/plan/stream/table/GroupWindowTest.xml | 7 +-
.../batch/sql/agg/AggregateITCaseBase.scala | 14 ++
.../runtime/batch/table/AggregationITCase.scala | 8 +-
.../runtime/stream/sql/AggregateITCase.scala | 2 +-
.../functions/aggregate/WelfordM2AggFunction.java | 232 ++++++++++++++++++
20 files changed, 1138 insertions(+), 141 deletions(-)
diff --git a/flink-python/pyflink/table/functions.py b/flink-python/pyflink/table/functions.py
index 378d6057a60..616e99cb2da 100644
--- a/flink-python/pyflink/table/functions.py
+++ b/flink-python/pyflink/table/functions.py
@@ -815,3 +815,75 @@ class SumWithRetractAggFunction(AggregateFunction):
if acc[0] is not None:
accumulator[0] += acc[0]
accumulator[1] += acc[1]
+
+
+@Internal()
+class WelfordM2AggFunction(AggregateFunction):
+ def create_accumulator(self):
+ # [n, mean, m2]
+ return [0, 0.0, 0.0]
+
+ def get_value(self, accumulator):
+ if accumulator[0] <= 0 or accumulator[2] <= 0:
+ return None
+ else:
+ return accumulator[2]
+
+ def accumulate(self, accumulator, *args):
+ if args[0] is None:
+ return
+
+ accumulator[0] += 1
+ if accumulator[0] <= 0:
+ return
+
+ val = float(args[0])
+ delta = val - accumulator[1]
+ accumulator[1] += delta / accumulator[0]
+ delta2 = val - accumulator[1]
+ accumulator[2] += delta * delta2
+
+ def retract(self, accumulator, *args):
+ if args[0] is None:
+ return
+
+ accumulator[0] -= 1
+ if accumulator[0] <= 0:
+ if accumulator[0] == 0:
+ accumulator[1] = 0.0
+ accumulator[2] = 0.0
+ return
+
+ val = float(args[0])
+ delta2 = val - accumulator[1]
+ accumulator[1] -= delta2 / accumulator[0]
+ delta = val - accumulator[1]
+ accumulator[2] -= delta * delta2
+
+ def merge(self, accumulator, accumulators):
+ negative_sum = 0
+ for acc in accumulators:
+ if acc[0] <= 0:
+ negative_sum += acc[0]
+ continue
+
+ if accumulator[0] == 0:
+ accumulator[0] = acc[0]
+ accumulator[1] = acc[1]
+ accumulator[2] = acc[2]
+ continue
+
+ new_n = accumulator[0] + acc[0]
+ delta_mean = acc[1] - accumulator[1]
+ new_mean = accumulator[1] + acc[0] / new_n * delta_mean
+ new_m2 = accumulator[2] + acc[2] + accumulator[0] * acc[0] / new_n * delta_mean * delta_mean
+
+ accumulator[0] = new_n
+ accumulator[1] = new_mean
+ accumulator[2] = new_m2
+
+ accumulator[0] += negative_sum
+ if accumulator[0] <= 0:
+ accumulator[1] = 0.0
+ accumulator[2] = 0.0
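The merge above uses the pairwise-combination form of Welford's algorithm: m2 = m2_a + m2_b + n_a * n_b / (n_a + n_b) * (mean_b - mean_a)^2. A quick way to sanity-check accumulate/merge against the two-pass result (an illustrative snippet, not part of the commit):

    import statistics
    f = WelfordM2AggFunction()
    a = f.create_accumulator()
    b = f.create_accumulator()
    f.accumulate(a, 1.0)
    f.accumulate(a, 1.5)
    f.accumulate(b, 2.0)
    f.merge(a, [b])  # combine partial accumulators
    # m2 / n should equal the population variance of the union
    assert abs(f.get_value(a) / a[0] - statistics.pvariance([1.0, 1.5, 2.0])) < 1e-12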
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
index cea16b08a8e..febd7ac1fa4 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/BuiltInFunctionDefinitions.java
@@ -912,6 +912,19 @@ public final class BuiltInFunctionDefinitions {
TypeStrategies.aggArg0(LogicalTypeMerging::findAvgAggType, true))
.build();
+ public static final BuiltInFunctionDefinition INTERNAL_WELFORD_M2 =
+ BuiltInFunctionDefinition.newBuilder()
+ .name("$WELFORD_M2$1")
+ .kind(AGGREGATE)
+ .inputTypeStrategy(
+ sequence(
+ Collections.singletonList("value"),
+ Collections.singletonList(logical(LogicalTypeFamily.NUMERIC))))
+ .outputTypeStrategy(explicit(DataTypes.DOUBLE()))
+ .runtimeProvided()
+ .internal()
+ .build();
+
public static final BuiltInFunctionDefinition COLLECT =
BuiltInFunctionDefinition.newBuilder()
.name("collect")
diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/BuiltInPythonAggregateFunction.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/BuiltInPythonAggregateFunction.java
index 34b672df329..637e1ae3ddb 100644
--- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/BuiltInPythonAggregateFunction.java
+++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/functions/python/BuiltInPythonAggregateFunction.java
@@ -49,7 +49,8 @@ public enum BuiltInPythonAggregateFunction implements PythonFunction {
FLOAT_SUM0("FloatSum0AggFunction"),
DECIMAL_SUM0("DecimalSum0AggFunction"),
SUM("SumAggFunction"),
- SUM_RETRACT("SumWithRetractAggFunction");
+ SUM_RETRACT("SumWithRetractAggFunction"),
+ WELFORD_M2("WelfordM2AggFunction");
private final byte[] payload;
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
index 3a870b12bb1..79819357250 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
@@ -16,6 +16,8 @@
*/
package org.apache.calcite.rel.rules;
+import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable;
+
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import org.apache.calcite.plan.RelOptCluster;
@@ -61,18 +63,16 @@ import java.util.function.Predicate;
* to simpler forms. This rule is copied to fix the correctness issue in Flink before upgrading to
* the corresponding Calcite version. Flink modifications:
*
- * <p>Lines 561 ~ 571 to fix CALCITE-7192.
+ * <p>Lines 321 ~ 345, 529 ~ 632, adds Welford's online algorithm for stddev calculation.
*
* <p>Rewrites:
*
* <ul>
* <li>AVG(x) → SUM(x) / COUNT(x)
- * <li>STDDEV_POP(x) → SQRT( (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x))
- * <li>STDDEV_SAMP(x) → SQRT( (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / CASE COUNT(x) WHEN
- * 1 THEN NULL ELSE COUNT(x) - 1 END)
- * <li>VAR_POP(x) → (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x)
- * <li>VAR_SAMP(x) → (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / CASE COUNT(x) WHEN 1 THEN
- * NULL ELSE COUNT(x) - 1 END
+ * <li>STDDEV_POP(x) → SQRT(M2(x) / COUNT(x))
+ * <li>STDDEV_SAMP(x) → SQRT(M2(x) / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END)
+ * <li>VAR_POP(x) → M2(x) / COUNT(x)
+ * <li>VAR_SAMP(x) → M2(x) / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END
* <li>COVAR_POP(x, y) → (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) /
* REGR_COUNT(x, y)
* <li>COVAR_SAMP(x, y) → (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / CASE
@@ -81,6 +81,12 @@ import java.util.function.Predicate;
* <li>REGR_SYY(x, y) → REGR_COUNT(x, y) * VAR_POP(x)
* </ul>
*
+ * <p>Helper functions:
+ *
+ * <ul>
+ * <li>M2(x) → sigma((x - mean(x)) * (x - mean(x)))
+ * </ul>
+ *
* <p>Since many of these rewrites introduce multiple occurrences of simpler forms like {@code
* COUNT(x)}, the rule gathers common sub-expressions as it goes.
*
@@ -312,32 +318,31 @@ public class AggregateReduceFunctionsRule extends RelRule<AggregateReduceFunctio
//noinspection SuspiciousNameCombination
return reduceRegrSzz(
oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs, x, x, y);
+ // FLINK MODIFICATION BEGIN
+ // M2(x) = sigma((x - mean(x)) * (x - mean(x)))
case STDDEV_POP:
// replace original STDDEV_POP(x) with
// SQRT(
- // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
+ // M2(x)
// / COUNT(x))
- return reduceStddev(
- oldAggRel, oldCall, true, true, newCalls, aggCallMapping, inputExprs);
+ return reduceStddev(oldAggRel, oldCall, true, true, newCalls, aggCallMapping);
case STDDEV_SAMP:
- // replace original STDDEV_POP(x) with
+ // replace original STDDEV_SAMP(x) with
// SQRT(
- // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
+ // M2(x)
// / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END)
- return reduceStddev(
- oldAggRel, oldCall, false, true, newCalls, aggCallMapping, inputExprs);
+ return reduceStddev(oldAggRel, oldCall, false, true, newCalls, aggCallMapping);
case VAR_POP:
// replace original VAR_POP(x) with
- // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
+ // M2(x)
// / COUNT(x)
- return reduceStddev(
- oldAggRel, oldCall, true, false, newCalls, aggCallMapping, inputExprs);
+ return reduceStddev(oldAggRel, oldCall, true, false, newCalls, aggCallMapping);
case VAR_SAMP:
- // replace original VAR_POP(x) with
- // (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
+ // replace original VAR_SAMP(x) with
+ // M2(x)
// / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END
- return reduceStddev(
- oldAggRel, oldCall, false, false, newCalls, aggCallMapping, inputExprs);
+ return reduceStddev(oldAggRel, oldCall, false, false, newCalls, aggCallMapping);
+ // FLINK MODIFICATION END
default:
throw Util.unexpected(kind);
}
@@ -521,70 +526,41 @@ public class AggregateReduceFunctionsRule extends RelRule<AggregateReduceFunctio
sumZeroRef);
}
+ // FLINK MODIFICATION BEGIN
private static RexNode reduceStddev(
Aggregate oldAggRel,
AggregateCall oldCall,
boolean biased,
boolean sqrt,
List<AggregateCall> newCalls,
- Map<AggregateCall, RexNode> aggCallMapping,
- List<RexNode> inputExprs) {
- // stddev_pop(x) ==>
- // power(
- // (sum(x * x) - sum(x) * sum(x) / count(x))
- // / count(x),
- // .5)
- //
- // stddev_samp(x) ==>
- // power(
- // (sum(x * x) - sum(x) * sum(x) / count(x))
- // / nullif(count(x) - 1, 0),
- // .5)
+ Map<AggregateCall, RexNode> aggCallMapping) {
final int nGroups = oldAggRel.getGroupCount();
final RelOptCluster cluster = oldAggRel.getCluster();
final RexBuilder rexBuilder = cluster.getRexBuilder();
- final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
-
- assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
final int argOrdinal = oldCall.getArgList().get(0);
- final IntPredicate fieldIsNullable = oldAggRel.getInput()::fieldIsNullable;
- final RelDataType oldCallType =
- typeFactory.createTypeWithNullability(
- oldCall.getType(), fieldIsNullable.test(argOrdinal));
-
- final RexNode argRef = rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), true);
- final RexNode argSquared =
- rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
- final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
+ // The Welford's online algorithm is introduced to avoid catastrophic cancellation.
+ // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
- // FLINK MODIFICATION BEGIN
- final AggregateCall sumArgSquaredAggCall =
- createAggregateCallWithBinding(
- typeFactory,
- SqlStdOperatorTable.SUM,
- argSquared.getType(),
- oldAggRel,
- oldCall,
- argSquaredOrdinal,
- oldCall.filterArg);
- // FLINK MODIFICATION END
-
- final RexNode sumArgSquared =
- rexBuilder.addAggCall(
- sumArgSquaredAggCall,
- nGroups,
- newCalls,
- aggCallMapping,
- oldAggRel.getInput()::fieldIsNullable);
+ // Welford m2 formula ==>
+ // n_new = n_old + 1
+ // delta = x - mean_old
+ // mean_new = mean_old + delta / n_new
+ // delta2 = x - mean_new
+ // m2_new = m2_old + delta * delta2
+ //
+ // stddev_pop(x) ==> power(m2(x) / count(x), .5)
+ // stddev_samp(x) ==> power(m2(x) / nullif(count(x) - 1, 0), .5)
+ // var_pop(x) ==> m2(x) / count(x)
+ // var_samp(x) ==> m2(x) / nullif(count(x) - 1, 0)
- final AggregateCall sumArgAggCall =
+ final AggregateCall m2AggCall =
AggregateCall.create(
- SqlStdOperatorTable.SUM,
+ FlinkSqlOperatorTable.INTERNAL_WELFORD_M2,
oldCall.isDistinct(),
oldCall.isApproximate(),
oldCall.ignoreNulls(),
- oldCall.rexList,
+ ImmutableList.of(),
ImmutableIntList.of(argOrdinal),
oldCall.filterArg,
oldCall.distinctKeys,
@@ -594,16 +570,13 @@ public class AggregateReduceFunctionsRule extends RelRule<AggregateReduceFunctio
null,
null);
- final RexNode sumArg =
+ final RexNode m2Ref =
rexBuilder.addAggCall(
- sumArgAggCall,
+ m2AggCall,
nGroups,
newCalls,
aggCallMapping,
oldAggRel.getInput()::fieldIsNullable);
- final RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true);
- final RexNode sumSquaredArg =
- rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast);
final AggregateCall countArgAggCall =
AggregateCall.create(
@@ -611,7 +584,7 @@ public class AggregateReduceFunctionsRule extends RelRule<AggregateReduceFunctio
oldCall.isDistinct(),
oldCall.isApproximate(),
oldCall.ignoreNulls(),
- oldCall.rexList,
+ ImmutableList.of(),
oldCall.getArgList(),
oldCall.filterArg,
oldCall.distinctKeys,
@@ -629,19 +602,35 @@ public class AggregateReduceFunctionsRule extends RelRule<AggregateReduceFunctio
aggCallMapping,
oldAggRel.getInput()::fieldIsNullable);
- final RexNode div = divide(biased, rexBuilder, sumArgSquared, sumSquaredArg, countArg);
+ final RexNode denominator;
+ if (biased) {
+ denominator = countArg;
+ } else {
+ final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
+ final RexNode nul = rexBuilder.makeNullLiteral(countArg.getType());
+ final RexNode countMinusOne =
+ rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
+ final RexNode countEqOne =
+ rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
+ denominator =
+ rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
+ }
+
+ RexNode variance = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, m2Ref, denominator);
final RexNode result;
if (sqrt) {
final RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
- result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, div, half);
+ result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, variance, half);
} else {
- result = div;
+ result = variance;
}
return rexBuilder.makeCast(oldCall.getType(), result);
}
+ // FLINK MODIFICATION END
+
private static RexNode reduceAggCallByGrouping(Aggregate oldAggRel, AggregateCall oldCall) {
final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
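In effect, reduceStddev now emits only two aggregate calls ($WELFORD_M2$1 and COUNT) and builds the variance expression on top of them. A rough, hypothetical Python rendering of the reduced expression (names are illustrative, not planner API):

    def reduce_stddev(m2, count, biased, sqrt):
        # biased=True  -> *_POP: divide by COUNT(x)
        # biased=False -> *_SAMP: CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END
        denominator = count if biased else (None if count == 1 else count - 1)
        if denominator is None:
            return None
        variance = m2 / denominator
        return variance ** 0.5 if sqrt else variance

    # VAR_POP(x)     -> reduce_stddev(m2, n, biased=True,  sqrt=False)
    # VAR_SAMP(x)    -> reduce_stddev(m2, n, biased=False, sqrt=False)
    # STDDEV_POP(x)  -> reduce_stddev(m2, n, biased=True,  sqrt=True)
    # STDDEV_SAMP(x) -> reduce_stddev(m2, n, biased=False, sqrt=True)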
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
index 4469b376420..1328c863196 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java
@@ -19,6 +19,7 @@
package org.apache.flink.table.planner.functions.sql;
import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.functions.BuiltInFunctionDefinitions;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.functions.sql.internal.SqlAuxiliaryGroupAggFunction;
import org.apache.flink.table.planner.functions.sql.ml.SqlMLEvaluateTableFunction;
@@ -1152,6 +1153,13 @@ public class FlinkSqlOperatorTable extends ReflectiveSqlOperatorTable {
public static final SqlAggFunction VARIANCE = SqlStdOperatorTable.VARIANCE;
public static final SqlAggFunction VAR_POP = SqlStdOperatorTable.VAR_POP;
public static final SqlAggFunction VAR_SAMP = SqlStdOperatorTable.VAR_SAMP;
+ public static final SqlAggFunction INTERNAL_WELFORD_M2 =
+ SqlBasicAggFunction.create(
+ BuiltInFunctionDefinitions.INTERNAL_WELFORD_M2.getName(),
+ SqlKind.OTHER_FUNCTION,
+ ReturnTypes.DOUBLE.andThen(SqlTypeTransforms.FORCE_NULLABLE),
+ OperandTypes.NUMERIC)
+ .withFunctionType(SqlFunctionCategory.SYSTEM);
public static final SqlAggFunction SINGLE_VALUE = SqlStdOperatorTable.SINGLE_VALUE;
public static final SqlAggFunction APPROX_COUNT_DISTINCT = SqlStdOperatorTable.APPROX_COUNT_DISTINCT;
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
index 43fc9a9a92e..ae9a375ba64 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/utils/CommonPythonUtil.java
@@ -60,6 +60,7 @@ import org.apache.flink.table.runtime.functions.aggregate.ListAggWithRetractAggF
import org.apache.flink.table.runtime.functions.aggregate.ListAggWsWithRetractAggFunction;
import org.apache.flink.table.runtime.functions.aggregate.MaxWithRetractAggFunction;
import org.apache.flink.table.runtime.functions.aggregate.MinWithRetractAggFunction;
+import org.apache.flink.table.runtime.functions.aggregate.WelfordM2AggFunction;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.FieldsDataType;
import org.apache.flink.table.types.inference.TypeInference;
@@ -535,6 +536,9 @@ public class CommonPythonUtil {
if (javaBuiltInAggregateFunction instanceof SumWithRetractAggFunction) {
return BuiltInPythonAggregateFunction.SUM_RETRACT;
}
+ if (javaBuiltInAggregateFunction instanceof WelfordM2AggFunction) {
+ return BuiltInPythonAggregateFunction.WELFORD_M2;
+ }
throw new TableException(
"Aggregate function "
+ javaBuiltInAggregateFunction
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
index 99560f83478..a157ea591ec 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/AggFunctionFactory.scala
@@ -142,6 +142,10 @@ class AggFunctionFactory(
case _: SqlListAggFunction if call.getArgList.size() == 2 =>
createListAggWsFunction(argTypes, index)
+ case a: SqlBasicAggFunction
+ if a.getName == BuiltInFunctionDefinitions.INTERNAL_WELFORD_M2.getName =>
+ createWelfordM2AggFunction(argTypes)
+
// TODO supports SqlCardinalityCountAggFunction
case a: SqlAggFunction if a.getKind == SqlKind.COLLECT =>
@@ -605,6 +609,19 @@ class AggFunctionFactory(
}
}
+ private def createWelfordM2AggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = {
+ argTypes(0).getTypeRoot match {
+ case TINYINT | SMALLINT | INTEGER | BIGINT | FLOAT | DOUBLE =>
+ new WelfordM2AggFunction.NumberFunction(argTypes(0))
+ case DECIMAL =>
+ new WelfordM2AggFunction.DecimalFunction(argTypes(0))
+ case _ =>
+ throw new TableException(
+ s"${BuiltInFunctionDefinitions.INTERNAL_WELFORD_M2.getName} aggregate function does not support type: '${argTypes(0).getTypeRoot}'.")
+ }
+ }
+
private def createCollectAggFunction(argTypes: Array[LogicalType]): UserDefinedFunction = {
new CollectAggFunction(argTypes(0))
}
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInAggregateFunctionTestBase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInAggregateFunctionTestBase.java
index 56dd35e5149..4c5f85a057c 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInAggregateFunctionTestBase.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/BuiltInAggregateFunctionTestBase.java
@@ -20,6 +20,7 @@ package org.apache.flink.table.planner.functions;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.client.program.MiniClusterClient;
+import org.apache.flink.configuration.ConfigOption;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.StateBackendOptions;
import org.apache.flink.connector.datagen.source.DataGeneratorSource;
@@ -200,6 +201,7 @@ abstract class BuiltInAggregateFunctionTestBase {
private @Nullable String description;
+ private final Configuration configuration = new Configuration();
private DataType sourceRowType;
private List<Row> sourceRows;
@@ -220,6 +222,16 @@ abstract class BuiltInAggregateFunctionTestBase {
return this;
}
+ <T> TestSpec withConfiguration(ConfigOption<T> option, T value) {
+ this.configuration.set(option, value);
+ return this;
+ }
+
+ TestSpec withConfiguration(String key, String value) {
+ this.configuration.setString(key, value);
+ return this;
+ }
+
TestSpec withSource(DataType sourceRowType, List<Row> sourceRows) {
this.sourceRowType = sourceRowType;
this.sourceRows = sourceRows;
@@ -324,13 +336,12 @@ abstract class BuiltInAggregateFunctionTestBase {
private TestCaseWithClusterClient createTestItemExecutable(
TestItem testItem, String stateBackend) {
return (clusterClient) -> {
- Configuration conf = new Configuration();
- conf.set(StateBackendOptions.STATE_BACKEND, stateBackend);
+ configuration.set(StateBackendOptions.STATE_BACKEND, stateBackend);
final TableEnvironment tEnv =
TableEnvironment.create(
EnvironmentSettings.newInstance()
.inStreamingMode()
- .withConfiguration(conf)
+ .withConfiguration(configuration)
.build());
final Table sourceTable = asTable(tEnv, sourceRowType, sourceRows);
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/MathAggFunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/MathAggFunctionITCase.java
new file mode 100644
index 00000000000..0d6e09c98d2
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/MathAggFunctionITCase.java
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.functions;
+
+import org.apache.flink.table.api.config.ExecutionConfigOptions;
+import org.apache.flink.types.Row;
+
+import java.math.BigDecimal;
+import java.time.Duration;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Stream;
+
+import static org.apache.flink.table.api.DataTypes.BIGINT;
+import static org.apache.flink.table.api.DataTypes.BOOLEAN;
+import static org.apache.flink.table.api.DataTypes.DECIMAL;
+import static org.apache.flink.table.api.DataTypes.DOUBLE;
+import static org.apache.flink.table.api.DataTypes.FLOAT;
+import static org.apache.flink.table.api.DataTypes.INT;
+import static org.apache.flink.table.api.DataTypes.ROW;
+import static org.apache.flink.table.api.DataTypes.SMALLINT;
+import static org.apache.flink.table.api.DataTypes.STRING;
+import static org.apache.flink.table.api.DataTypes.TINYINT;
+import static org.apache.flink.types.RowKind.DELETE;
+import static org.apache.flink.types.RowKind.INSERT;
+import static org.apache.flink.types.RowKind.UPDATE_AFTER;
+import static org.apache.flink.types.RowKind.UPDATE_BEFORE;
+
+/** Tests for built-in math aggregation functions. */
+class MathAggFunctionITCase extends BuiltInAggregateFunctionTestBase {
+
+ @Override
+ Stream<TestSpec> getTestCaseSpecs() {
+ return Stream.of(varianceRelatedTestCases()).flatMap(s -> s);
+ }
+
+ private Stream<TestSpec> varianceRelatedTestCases() {
+ List<Row> batchData = new ArrayList<>();
+ for (int i = 0; i < 50; i++) {
+ batchData.add(Row.ofKind(INSERT, "A", (double) i));
+ batchData.add(Row.ofKind(INSERT, "B", null));
+ batchData.add(Row.ofKind(INSERT, "B", null));
+ batchData.add(Row.ofKind(INSERT, "B", (double) i));
+ batchData.add(Row.ofKind(DELETE, "B", (double) i));
+ batchData.add(Row.ofKind(DELETE, "B", (double) i + 50));
+ }
+
+ return Stream.of(
+ TestSpec.forExpression("failed case of FLINK-38740")
+ .withSource(
+ ROW(STRING(), DOUBLE()),
+ Arrays.asList(
+ Row.ofKind(INSERT, "A", 0.27),
+ Row.ofKind(INSERT, "A", 0.27),
+ Row.ofKind(INSERT, "A", 0.27)))
+ .testSqlResult(
+ source ->
+ "SELECT f0, "
+ + "STDDEV_POP(f1),
STDDEV_SAMP(f1), VAR_POP(f1) < 0, VAR_SAMP(f1) < 0 "
+ + "FROM "
+ + source
+ + " GROUP BY f0",
+ ROW(STRING(), DOUBLE(), DOUBLE(), BOOLEAN(),
BOOLEAN()),
+ List.of(Row.of("A", 0.0, 0.0, false, false))),
+ TestSpec.forExpression("integer numeric")
+ .withSource(
+ ROW(STRING(), TINYINT(), SMALLINT(), INT(), BIGINT()),
+ Arrays.asList(
+ Row.of("A", (byte) 0x1, (short) 0x10, 0x100, 0x1000L),
+ Row.of("A", (byte) 0x2, (short) 0x20, 0x200, 0x2000L),
+ Row.of("A", (byte) 0x3, (short) 0x30, 0x300, 0x3000L)))
+ .testSqlResult(
+ source ->
+ "SELECT f0, "
+ + "STDDEV_POP(f1),
STDDEV_SAMP(f1), VAR_POP(f1), VAR_SAMP(f1) "
+ + "FROM "
+ + source
+ + " GROUP BY f0",
+ ROW(STRING(), TINYINT(), TINYINT(), TINYINT(),
TINYINT()),
+ List.of(Row.of("A", (byte) 0, (byte) 1, (byte)
0, (byte) 1)))
+ .testSqlResult(
+ source ->
+ "SELECT f0, "
+ + "STDDEV_POP(f2),
STDDEV_SAMP(f2), VAR_POP(f2), VAR_SAMP(f2) "
+ + "FROM "
+ + source
+ + " GROUP BY f0",
+ ROW(STRING(), SMALLINT(), SMALLINT(),
SMALLINT(), SMALLINT()),
+ List.of(
+ Row.of(
+ "A",
+ (short) 13,
+ (short) 16,
+ (short) 170,
+ (short) 256)))
+ .testSqlResult(
+ source ->
+ "SELECT f0, "
+ + "STDDEV_POP(f3),
STDDEV_SAMP(f3), VAR_POP(f3), VAR_SAMP(f3) "
+ + "FROM "
+ + source
+ + " GROUP BY f0",
+ ROW(STRING(), INT(), INT(), INT(), INT()),
+ List.of(Row.of("A", 209, 256, 43690, 65536)))
+ .testSqlResult(
+ source ->
+ "SELECT f0, "
+ + "STDDEV_POP(f4),
STDDEV_SAMP(f4), VAR_POP(f4), VAR_SAMP(f4) "
+ + "FROM "
+ + source
+ + " GROUP BY f0",
+ ROW(STRING(), BIGINT(), BIGINT(), BIGINT(),
BIGINT()),
+ List.of(Row.of("A", 3344L, 4096L, 11184810L,
16777216L))),
+ TestSpec.forExpression("approximate numeric")
+ .withSource(
+ ROW(STRING(), FLOAT(), DOUBLE()),
+ Arrays.asList(
+ Row.of("A", 0.1f, 1.0d),
+ Row.of("A", 0.2f, 1.5d),
+ Row.of("A", 0.3f, 2.0d)))
+ .testSqlResult(
+ source ->
+ "SELECT f0, "
+ + "ROUND(STDDEV_POP(f1), 6),
ROUND(STDDEV_SAMP(f1), 6), "
+ + "ROUND(VAR_POP(f1), 6),
ROUND(VAR_SAMP(f1), 6) "
+ + "FROM "
+ + source
+ + " GROUP BY f0",
+ ROW(STRING(), FLOAT(), FLOAT(), FLOAT(),
FLOAT()),
+ List.of(Row.of("A", 0.081650f, 0.100000f,
0.006667f, 0.010000f)))
+ .testSqlResult(
+ source ->
+ "SELECT f0, "
+ + "ROUND(STDDEV_POP(f2), 6),
ROUND(STDDEV_SAMP(f2), 6), "
+ + "ROUND(VAR_POP(f2), 6),
ROUND(VAR_SAMP(f2), 6) "
+ + "FROM "
+ + source
+ + " GROUP BY f0",
+ ROW(STRING(), DOUBLE(), DOUBLE(), DOUBLE(),
DOUBLE()),
+ List.of(Row.of("A", 0.408248, 0.500000,
0.166667, 0.250000))),
+ TestSpec.forExpression("decimal")
+ .withSource(
+ ROW(STRING(), DECIMAL(10, 2)),
+ Arrays.asList(
+ Row.of("A", BigDecimal.valueOf(0.27)),
+ Row.of("A", BigDecimal.valueOf(0.28)),
+ Row.of("A", BigDecimal.valueOf(0.29))))
+ .testSqlResult(
+ source ->
+ "SELECT f0, "
+ + "ROUND(STDDEV_POP(f1), 6),
ROUND(STDDEV_SAMP(f1), 6), "
+ + "ROUND(VAR_POP(f1), 6),
ROUND(VAR_SAMP(f1), 6) "
+ + "FROM "
+ + source
+ + " GROUP BY f0",
+ ROW(
+ STRING(),
+ DECIMAL(38, 6),
+ DECIMAL(38, 6),
+ DECIMAL(38, 6),
+ DECIMAL(38, 6)),
+ List.of(
+ Row.of(
+ "A",
+ BigDecimal.valueOf(8165L, 6),
+ BigDecimal.valueOf(10000L, 6),
+ BigDecimal.valueOf(67L, 6),
+ BigDecimal.valueOf(100L, 6)))),
+ TestSpec.forExpression("retract")
+ .withSource(
+ ROW(STRING(), DOUBLE()),
+ Arrays.asList(
+ Row.ofKind(INSERT, "A", null),
+ Row.ofKind(INSERT, "B", 0.27),
+ Row.ofKind(INSERT, "B", null),
+ Row.ofKind(DELETE, "B", 0.27),
+ Row.ofKind(INSERT, "C", 0.27),
+ Row.ofKind(INSERT, "C", 0.28),
+ Row.ofKind(INSERT, "C", null),
+ Row.ofKind(INSERT, "C", null),
+ Row.ofKind(INSERT, "C", null),
+ Row.ofKind(DELETE, "C", null),
+ Row.ofKind(UPDATE_BEFORE, "C", 0.27),
+ Row.ofKind(UPDATE_AFTER, "C", 0.30),
+ Row.ofKind(DELETE, "C", 0.27),
+ Row.ofKind(DELETE, "C", 0.30),
+ Row.ofKind(DELETE, "C", 0.28),
+ Row.ofKind(INSERT, "C", 0.27),
+ Row.ofKind(INSERT, "C", 0.50),
+ Row.ofKind(INSERT, "D", 0.27),
+ Row.ofKind(INSERT, "D", 0.28),
+ Row.ofKind(INSERT, "D", 0.29),
+ Row.ofKind(DELETE, "D", 0.28)))
+ .testSqlResult(
+ source ->
+ "SELECT f0, "
+ + "ROUND(STDDEV_POP(f1), 6),
ROUND(STDDEV_SAMP(f1), 6), "
+ + "ROUND(VAR_POP(f1), 6),
ROUND(VAR_SAMP(f1), 6) "
+ + "FROM "
+ + source
+ + " GROUP BY f0",
+ ROW(STRING(), DOUBLE(), DOUBLE(), DOUBLE(),
DOUBLE()),
+ List.of(
+ Row.of("A", null, null, null, null),
+ Row.of("B", null, null, null, null),
+ Row.of("C", 0.000000, null, 0.000000,
null),
+ Row.of("D", 0.010000, 0.014142,
0.000100, 0.000200))),
+ TestSpec.forExpression("merge")
+ .withConfiguration(
+ ExecutionConfigOptions.TABLE_EXEC_MINIBATCH_ENABLED, Boolean.TRUE)
+ .withConfiguration(ExecutionConfigOptions.TABLE_EXEC_MINIBATCH_SIZE, 10L)
+ .withConfiguration(
+ ExecutionConfigOptions.TABLE_EXEC_MINIBATCH_ALLOW_LATENCY,
+ Duration.ofMillis(100))
+ .withSource(ROW(STRING(), DOUBLE()), batchData)
+ .testSqlResult(
+ source ->
+ "SELECT f0, "
+ + "ROUND(STDDEV_POP(f1), 6),
ROUND(STDDEV_SAMP(f1), 6), "
+ + "ROUND(VAR_POP(f1), 6),
ROUND(VAR_SAMP(f1), 6) "
+ + "FROM "
+ + source
+ + " GROUP BY f0",
+ ROW(STRING(), DOUBLE(), DOUBLE(), DOUBLE(),
DOUBLE()),
+ List.of(
+ Row.of("A", 14.430870, 14.577380,
208.250000, 212.500000),
+ Row.of("B", null, null, null, null))));
+ }
+}
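The expected values in the approximate-numeric case can be cross-checked by hand; a hypothetical verification against Python's statistics module (not part of the test suite):

    import statistics
    data = [1.0, 1.5, 2.0]  # the DOUBLE column f2 above
    assert round(statistics.pvariance(data), 6) == 0.166667  # VAR_POP
    assert round(statistics.variance(data), 6) == 0.25       # VAR_SAMP
    assert round(statistics.pstdev(data), 6) == 0.408248     # STDDEV_POP
    assert round(statistics.stdev(data), 6) == 0.5           # STDDEV_SAMP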
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceFunctionsRuleTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceFunctionsRuleTest.java
index 209e0e43b56..7a0f3721a30 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceFunctionsRuleTest.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceFunctionsRuleTest.java
@@ -34,7 +34,11 @@ import org.apache.calcite.tools.RuleSets;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
-/** Test for {@link AggregateReduceFunctionsRule}. */
+/**
+ * Test for {@link CoreRules#AGGREGATE_REDUCE_FUNCTIONS}.
+ *
+ * <p>For now COVAR_POP/COVAR_SAMP/REGR_SXX/REGR_SYY are not yet supported.
+ */
public class AggregateReduceFunctionsRuleTest extends TableTestBase {
private BatchTableTestUtil util;
@@ -65,6 +69,19 @@ public class AggregateReduceFunctionsRuleTest extends TableTestBase {
+ " 'connector' = 'values',\n"
+ " 'bounded' = 'true'\n"
+ ")");
+
+ util.tableEnv()
+ .executeSql(
+ "CREATE TABLE src2 (\n"
+ + " id INT,\n"
+ + " category VARCHAR,\n"
+ + " x DOUBLE,\n"
+ + " y DOUBLE,\n"
+ + " z DECIMAL(10,2)\n"
+ + ") WITH (\n"
+ + " 'connector' = 'values',\n"
+ + " 'bounded' = 'true'\n"
+ + ")");
}
@Test
@@ -78,4 +95,116 @@ public class AggregateReduceFunctionsRuleTest extends TableTestBase {
+ "AVG(b) FILTER (WHERE b > 50)\n"
+ "FROM src GROUP BY a");
}
+
+ @Test
+ void testVarianceAndStandardDeviation() {
+ // Test variance and standard deviation reductions without filters
+ util.verifyRelPlan(
+ "SELECT category, \n"
+ + "VAR_POP(x), \n"
+ + "VAR_SAMP(x), \n"
+ + "STDDEV_POP(x), \n"
+ + "STDDEV_SAMP(x) \n"
+ + "FROM src2 GROUP BY category");
+ }
+
+ @Test
+ void testMixedAggregatesWithDistinct() {
+ // Test combinations of different aggregate functions with DISTINCT
+ util.verifyRelPlan(
+ "SELECT category, \n"
+ + "AVG(DISTINCT x), \n"
+ + "VAR_POP(DISTINCT y), \n"
+ + "STDDEV_SAMP(DISTINCT z) \n"
+ + "FROM src2 GROUP BY category");
+ }
+
+ @Test
+ void testComplexExpressionsInAggregates() {
+ // Test aggregates with complex expressions as arguments
+ util.verifyRelPlan(
+ "SELECT category, \n"
+ + "AVG(x + y), \n"
+ + "VAR_POP(x * 2), \n"
+ + "STDDEV_SAMP(CASE WHEN x > 0 THEN x ELSE 0 END) \n"
+ + "FROM src2 GROUP BY category");
+ }
+
+ @Test
+ void testMultipleGroupingColumns() {
+ // Test aggregate reductions with multiple grouping columns
+ util.verifyRelPlan(
+ "SELECT category, id, \n"
+ + "AVG(x), \n"
+ + "VAR_SAMP(y), \n"
+ + "VAR_POP(z) \n"
+ + "FROM src2 GROUP BY category, id");
+ }
+
+ @Test
+ void testAggregatesWithOrderBy() {
+ // Test aggregate reductions with ORDER BY clause
+ util.verifyRelPlan(
+ "SELECT category, \n"
+ + "AVG(x) as avg_x, \n"
+ + "STDDEV_POP(y) as stddev_y \n"
+ + "FROM src2 \n"
+ + "GROUP BY category \n"
+ + "ORDER BY avg_x DESC");
+ }
+
+ @Test
+ void testAggregatesWithHaving() {
+ // Test aggregate reductions with HAVING clause
+ util.verifyRelPlan(
+ "SELECT category, \n"
+ + "AVG(x) as avg_x, \n"
+ + "VAR_POP(y) as var_y \n"
+ + "FROM src2 \n"
+ + "GROUP BY category \n"
+ + "HAVING AVG(x) > 50 AND VAR_POP(y) < 100");
+ }
+
+ @Test
+ void testDifferentDataTypes() {
+ // Test aggregate reductions with different data types
+ util.verifyRelPlan(
+ "SELECT category, \n"
+ + "AVG(CAST(id AS DOUBLE)), \n"
+ + "VAR_POP(CAST(x AS DECIMAL(20,4))), \n"
+ + "STDDEV_SAMP(z) \n"
+ + "FROM src2 GROUP BY category");
+ }
+
+ @Test
+ void testEmptyGroupBy() {
+ // Test aggregate reductions without GROUP BY (global aggregates)
+ util.verifyRelPlan(
+ "SELECT \n" + "AVG(x), \n" + "VAR_POP(y), \n" +
"STDDEV_SAMP(z) \n" + "FROM src2");
+ }
+
+ @Test
+ void testComplexFilterConditions() {
+ // Test aggregates with complex filter conditions
+ util.verifyRelPlan(
+ "SELECT category, \n"
+ + "AVG(x) FILTER (WHERE x > y AND category LIKE 'A%'),
\n"
+ + "STDDEV_POP(x) FILTER (WHERE x + y > 100), \n"
+ + "VAR_SAMP(y) FILTER (WHERE MOD(id, 2) = 0) \n"
+ + "FROM src2 GROUP BY category");
+ }
+
+ @Test
+ void testAllSupportedFunctionsInSingleQuery() {
+ // Comprehensive test with all supported aggregate functions
+ util.verifyRelPlan(
+ "SELECT category, \n"
+ + "SUM(x), \n"
+ + "AVG(x), \n"
+ + "VAR_POP(x), \n"
+ + "VAR_SAMP(x), \n"
+ + "STDDEV_POP(x), \n"
+ + "STDDEV_SAMP(x) \n"
+ + "FROM src2 GROUP BY category");
+ }
}
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/AggregateReduceGroupingTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/AggregateReduceGroupingTest.xml
index ffbce7d6679..195d2f7e53f 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/AggregateReduceGroupingTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/AggregateReduceGroupingTest.xml
@@ -427,10 +427,10 @@ Calc(select=[a4, c4, s, EXPR$3])
+- HashAggregate(isMerge=[true], groupBy=[a4, s], auxGrouping=[c4], select=[a4, s, c4, Final_COUNT(count$0) AS EXPR$3])
+- Exchange(distribution=[hash[a4, s]])
+- LocalHashAggregate(groupBy=[a4, s], auxGrouping=[c4], select=[a4, s, c4, Partial_COUNT(b4) AS count$0])
- +- Calc(select=[a4, c4, w$start AS s, CAST((($f2 - (($f3 * $f3) / $f4)) / $f4) AS INTEGER) AS b4])
- +- HashWindowAggregate(groupBy=[a4], auxGrouping=[c4], window=[TumblingGroupWindow('w$, d4, 900000)], properties=[w$start, w$end, w$rowtime], select=[a4, c4, SUM($f4) AS $f2, SUM(b4) AS $f3, COUNT(b4) AS $f4])
- +- Exchange(distribution=[keep_input_as_is[hash[a4]]])
- +- Calc(select=[a4, c4, d4, b4, (b4 * b4) AS $f4])
+ +- Calc(select=[a4, c4, w$start AS s, CAST(($f2 / $f3) AS INTEGER) AS b4])
+ +- SortWindowAggregate(groupBy=[a4], auxGrouping=[c4], window=[TumblingGroupWindow('w$, d4, 900000)], properties=[w$start, w$end, w$rowtime], select=[a4, c4, $WELFORD_M2$1(b4) AS $f2, COUNT(b4) AS $f3])
+ +- Exchange(distribution=[forward])
+ +- Sort(orderBy=[a4 ASC, d4 ASC])
+- Exchange(distribution=[hash[a4]])
+- LegacyTableSourceScan(table=[[default_catalog, default_database, T4, source: [TestTableSource(a4, b4, c4, d4)]]], fields=[a4, b4, c4, d4])
]]>
@@ -456,10 +456,10 @@ Calc(select=[a4, c4, e, EXPR$3])
+- HashAggregate(isMerge=[true], groupBy=[a4, e], auxGrouping=[c4], select=[a4, e, c4, Final_COUNT(count$0) AS EXPR$3])
+- Exchange(distribution=[hash[a4, e]])
+- LocalHashAggregate(groupBy=[a4, e], auxGrouping=[c4], select=[a4, e, c4, Partial_COUNT(b4) AS count$0])
- +- Calc(select=[a4, c4, w$end AS e, CAST((($f2 - (($f3 * $f3) / $f4)) / $f4) AS INTEGER) AS b4])
- +- HashWindowAggregate(groupBy=[a4], auxGrouping=[c4], window=[TumblingGroupWindow('w$, d4, 900000)], properties=[w$start, w$end, w$rowtime], select=[a4, c4, SUM($f4) AS $f2, SUM(b4) AS $f3, COUNT(b4) AS $f4])
- +- Exchange(distribution=[keep_input_as_is[hash[a4]]])
- +- Calc(select=[a4, c4, d4, b4, (b4 * b4) AS $f4])
+ +- Calc(select=[a4, c4, w$end AS e, CAST(($f2 / $f3) AS INTEGER) AS b4])
+ +- SortWindowAggregate(groupBy=[a4], auxGrouping=[c4], window=[TumblingGroupWindow('w$, d4, 900000)], properties=[w$start, w$end, w$rowtime], select=[a4, c4, $WELFORD_M2$1(b4) AS $f2, COUNT(b4) AS $f3])
+ +- Exchange(distribution=[forward])
+ +- Sort(orderBy=[a4 ASC, d4 ASC])
+- Exchange(distribution=[hash[a4]])
+- LegacyTableSourceScan(table=[[default_catalog, default_database, T4, source: [TestTableSource(a4, b4, c4, d4)]]], fields=[a4, b4, c4, d4])
]]>
@@ -484,10 +484,10 @@ LogicalAggregate(group=[{0, 1, 2}], EXPR$3=[COUNT()])
HashAggregate(isMerge=[true], groupBy=[a4, b4], auxGrouping=[c4], select=[a4, b4, c4, Final_COUNT(count1$0) AS EXPR$3])
+- Exchange(distribution=[hash[a4, b4]])
+- LocalHashAggregate(groupBy=[a4, b4], auxGrouping=[c4], select=[a4, b4, c4, Partial_COUNT(*) AS count1$0])
- +- Calc(select=[a4, CAST((($f2 - (($f3 * $f3) / $f4)) / $f4) AS INTEGER) AS b4, c4])
- +- HashWindowAggregate(groupBy=[a4], auxGrouping=[c4], window=[TumblingGroupWindow('w$, d4, 900000)], properties=[w$start, w$end, w$rowtime], select=[a4, c4, SUM($f4) AS $f2, SUM(b4) AS $f3, COUNT(b4) AS $f4])
- +- Exchange(distribution=[keep_input_as_is[hash[a4]]])
- +- Calc(select=[a4, c4, d4, b4, (b4 * b4) AS $f4])
+ +- Calc(select=[a4, CAST(($f2 / $f3) AS INTEGER) AS b4, c4])
+ +- SortWindowAggregate(groupBy=[a4], auxGrouping=[c4], window=[TumblingGroupWindow('w$, d4, 900000)], properties=[w$start, w$end, w$rowtime], select=[a4, c4, $WELFORD_M2$1(b4) AS $f2, COUNT(b4) AS $f3])
+ +- Exchange(distribution=[forward])
+ +- Sort(orderBy=[a4 ASC, d4 ASC])
+- Exchange(distribution=[hash[a4]])
+- LegacyTableSourceScan(table=[[default_catalog, default_database, T4, source: [TestTableSource(a4, b4, c4, d4)]]], fields=[a4, b4, c4, d4])
]]>
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/GroupWindowTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/GroupWindowTest.xml
index cc694a0e4de..22e42ff8e79 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/GroupWindowTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/batch/sql/agg/GroupWindowTest.xml
@@ -39,12 +39,15 @@ LogicalProject(EXPR$0=[$1], EXPR$1=[$2], EXPR$2=[$3], EXPR$3=[$4], EXPR$4=[TUMBL
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
-Calc(select=[CAST((($f0 - (($f1 * $f1) / $f2)) / $f2) AS INTEGER) AS EXPR$0, CAST((($f0 - (($f1 * $f1) / $f2)) / CASE(($f2 = 1), null:BIGINT, ($f2 - 1))) AS INTEGER) AS EXPR$1, CAST(POWER((($f0 - (($f1 * $f1) / $f2)) / $f2), 0.5) AS INTEGER) AS EXPR$2, CAST(POWER((($f0 - (($f1 * $f1) / $f2)) / CASE(($f2 = 1), null:BIGINT, ($f2 - 1))), 0.5) AS INTEGER) AS EXPR$3, w$start AS EXPR$4, w$end AS EXPR$5])
-+- HashWindowAggregate(window=[TumblingGroupWindow('w$, ts, 900000)], properties=[w$start, w$end, w$rowtime], select=[Final_SUM(sum$0) AS $f0, Final_SUM(sum$1) AS $f1, Final_COUNT(count$2) AS $f2])
- +- Exchange(distribution=[single])
- +- LocalHashWindowAggregate(window=[TumblingGroupWindow('w$, ts, 900000)], properties=[w$start, w$end, w$rowtime], select=[Partial_SUM($f2) AS sum$0, Partial_SUM(b) AS sum$1, Partial_COUNT(b) AS count$2])
- +- Calc(select=[ts, b, (b * b) AS $f2])
- +- TableSourceScan(table=[[default_catalog, default_database, MyTable1]], fields=[ts, a, b, c])
+Calc(select=[CAST(($f0 / $f1) AS INTEGER) AS EXPR$0, CAST(($f0 / CASE(($f1 = 1), null:BIGINT, ($f1 - 1))) AS INTEGER) AS EXPR$1, CAST(POWER(($f0 / $f1), 0.5) AS INTEGER) AS EXPR$2, CAST(POWER(($f0 / CASE(($f1 = 1), null:BIGINT, ($f1 - 1))), 0.5) AS INTEGER) AS EXPR$3, w$start AS EXPR$4, w$end AS EXPR$5])
++- SortWindowAggregate(window=[TumblingGroupWindow('w$, ts, 900000)], properties=[w$start, w$end, w$rowtime], select=[Final_$WELFORD_M2$1($f0) AS $f0, Final_COUNT(count$0) AS $f1])
+ +- Sort(orderBy=[assignedWindow$ ASC])
+ +- Exchange(distribution=[single])
+ +- LocalSortWindowAggregate(window=[TumblingGroupWindow('w$, ts, 900000)], properties=[w$start, w$end, w$rowtime], select=[Partial_$WELFORD_M2$1(b) AS $f0, Partial_COUNT(b) AS count$0])
+ +- Exchange(distribution=[forward])
+ +- Sort(orderBy=[ts ASC])
+ +- Calc(select=[ts, b])
+ +- TableSourceScan(table=[[default_catalog, default_database, MyTable1]], fields=[ts, a, b, c])
]]>
</Resource>
</TestCase>
@@ -71,11 +74,12 @@ LogicalProject(EXPR$0=[$1], EXPR$1=[$2], EXPR$2=[$3], EXPR$3=[$4], EXPR$4=[TUMBL
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
-Calc(select=[CAST((($f0 - (($f1 * $f1) / $f2)) / $f2) AS INTEGER) AS EXPR$0, CAST((($f0 - (($f1 * $f1) / $f2)) / CASE(($f2 = 1), null:BIGINT, ($f2 - 1))) AS INTEGER) AS EXPR$1, CAST(POWER((($f0 - (($f1 * $f1) / $f2)) / $f2), 0.5) AS INTEGER) AS EXPR$2, CAST(POWER((($f0 - (($f1 * $f1) / $f2)) / CASE(($f2 = 1), null:BIGINT, ($f2 - 1))), 0.5) AS INTEGER) AS EXPR$3, w$start AS EXPR$4, w$end AS EXPR$5])
-+- HashWindowAggregate(window=[TumblingGroupWindow('w$, ts, 900000)], properties=[w$start, w$end, w$rowtime], select=[SUM($f2) AS $f0, SUM(b) AS $f1, COUNT(b) AS $f2])
- +- Exchange(distribution=[single])
- +- Calc(select=[ts, b, (b * b) AS $f2])
- +- TableSourceScan(table=[[default_catalog, default_database, MyTable1]], fields=[ts, a, b, c])
+Calc(select=[CAST(($f0 / $f1) AS INTEGER) AS EXPR$0, CAST(($f0 / CASE(($f1 = 1), null:BIGINT, ($f1 - 1))) AS INTEGER) AS EXPR$1, CAST(POWER(($f0 / $f1), 0.5) AS INTEGER) AS EXPR$2, CAST(POWER(($f0 / CASE(($f1 = 1), null:BIGINT, ($f1 - 1))), 0.5) AS INTEGER) AS EXPR$3, w$start AS EXPR$4, w$end AS EXPR$5])
++- SortWindowAggregate(window=[TumblingGroupWindow('w$, ts, 900000)], properties=[w$start, w$end, w$rowtime], select=[$WELFORD_M2$1(b) AS $f0, COUNT(b) AS $f1])
+ +- Sort(orderBy=[ts ASC])
+ +- Exchange(distribution=[single])
+ +- Calc(select=[ts, b])
+ +- TableSourceScan(table=[[default_catalog, default_database, MyTable1]], fields=[ts, a, b, c])
]]>
</Resource>
</TestCase>
@@ -102,12 +106,15 @@ LogicalProject(EXPR$0=[$1], EXPR$1=[$2], EXPR$2=[$3], EXPR$3=[$4], EXPR$4=[TUMBL
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
-Calc(select=[CAST((($f0 - (($f1 * $f1) / $f2)) / $f2) AS INTEGER) AS EXPR$0, CAST((($f0 - (($f1 * $f1) / $f2)) / CASE(($f2 = 1), null:BIGINT, ($f2 - 1))) AS INTEGER) AS EXPR$1, CAST(POWER((($f0 - (($f1 * $f1) / $f2)) / $f2), 0.5) AS INTEGER) AS EXPR$2, CAST(POWER((($f0 - (($f1 * $f1) / $f2)) / CASE(($f2 = 1), null:BIGINT, ($f2 - 1))), 0.5) AS INTEGER) AS EXPR$3, w$start AS EXPR$4, w$end AS EXPR$5])
-+- HashWindowAggregate(window=[TumblingGroupWindow('w$, ts, 900000)], properties=[w$start, w$end, w$rowtime], select=[Final_SUM(sum$0) AS $f0, Final_SUM(sum$1) AS $f1, Final_COUNT(count$2) AS $f2])
- +- Exchange(distribution=[single])
- +- LocalHashWindowAggregate(window=[TumblingGroupWindow('w$, ts, 900000)], properties=[w$start, w$end, w$rowtime], select=[Partial_SUM($f2) AS sum$0, Partial_SUM(b) AS sum$1, Partial_COUNT(b) AS count$2])
- +- Calc(select=[ts, b, (b * b) AS $f2])
- +- TableSourceScan(table=[[default_catalog, default_database, MyTable1]], fields=[ts, a, b, c])
+Calc(select=[CAST(($f0 / $f1) AS INTEGER) AS EXPR$0, CAST(($f0 / CASE(($f1 = 1), null:BIGINT, ($f1 - 1))) AS INTEGER) AS EXPR$1, CAST(POWER(($f0 / $f1), 0.5) AS INTEGER) AS EXPR$2, CAST(POWER(($f0 / CASE(($f1 = 1), null:BIGINT, ($f1 - 1))), 0.5) AS INTEGER) AS EXPR$3, w$start AS EXPR$4, w$end AS EXPR$5])
++- SortWindowAggregate(window=[TumblingGroupWindow('w$, ts, 900000)], properties=[w$start, w$end, w$rowtime], select=[Final_$WELFORD_M2$1($f0) AS $f0, Final_COUNT(count$0) AS $f1])
+ +- Sort(orderBy=[assignedWindow$ ASC])
+ +- Exchange(distribution=[single])
+ +- LocalSortWindowAggregate(window=[TumblingGroupWindow('w$, ts, 900000)], properties=[w$start, w$end, w$rowtime], select=[Partial_$WELFORD_M2$1(b) AS $f0, Partial_COUNT(b) AS count$0])
+ +- Exchange(distribution=[forward])
+ +- Sort(orderBy=[ts ASC])
+ +- Calc(select=[ts, b])
+ +- TableSourceScan(table=[[default_catalog, default_database, MyTable1]], fields=[ts, a, b, c])
]]>
</Resource>
</TestCase>
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceFunctionsRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceFunctionsRuleTest.xml
index c2fceebec56..5ef4da779a1 100644
--- a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceFunctionsRuleTest.xml
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceFunctionsRuleTest.xml
@@ -16,6 +16,265 @@ See the License for the specific language governing permissions and
limitations under the License.
-->
<Root>
+ <TestCase name="testAggregatesWithHaving">
+ <Resource name="sql">
+ <![CDATA[SELECT category,
+AVG(x) as avg_x,
+VAR_POP(y) as var_y
+FROM src2
+GROUP BY category
+HAVING AVG(x) > 50 AND VAR_POP(y) < 100]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalFilter(condition=[AND(>($1, 50), <($2, 100))])
++- LogicalAggregate(group=[{0}], avg_x=[AVG($1)], var_y=[VAR_POP($2)])
+ +- LogicalProject(category=[$1], x=[$2], y=[$3])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+LogicalFilter(condition=[AND(>($1, 50), <($2, 100))])
++- LogicalProject(category=[$0], avg_x=[/($1, $2)], var_y=[/($3, $4)])
+ +- LogicalProject(category=[$0], $f1=[CASE(=($2, 0), null:DOUBLE, $1)], $f2=[$2], $f3=[$3], $f4=[$4])
+ +- LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($1)], agg#2=[$WELFORD_M2$1($2)], agg#3=[COUNT($2)])
+ +- LogicalProject(category=[$1], x=[$2], y=[$3])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testAggregatesWithOrderBy">
+ <Resource name="sql">
+ <![CDATA[SELECT category,
+AVG(x) as avg_x,
+STDDEV_POP(y) as stddev_y
+FROM src2
+GROUP BY category
+ORDER BY avg_x DESC]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalSort(sort0=[$1], dir0=[DESC-nulls-last])
++- LogicalAggregate(group=[{0}], avg_x=[AVG($1)], stddev_y=[STDDEV_POP($2)])
+ +- LogicalProject(category=[$1], x=[$2], y=[$3])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+LogicalSort(sort0=[$1], dir0=[DESC-nulls-last])
++- LogicalProject(category=[$0], avg_x=[/($1, $2)], stddev_y=[POWER(/($3, $4), 0.5:DECIMAL(2, 1))])
+ +- LogicalProject(category=[$0], $f1=[CASE(=($2, 0), null:DOUBLE, $1)], $f2=[$2], $f3=[$3], $f4=[$4])
+ +- LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($1)], agg#2=[$WELFORD_M2$1($2)], agg#3=[COUNT($2)])
+ +- LogicalProject(category=[$1], x=[$2], y=[$3])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testAllSupportedFunctionsInSingleQuery">
+ <Resource name="sql">
+ <![CDATA[SELECT category,
+SUM(x),
+AVG(x),
+VAR_POP(x),
+VAR_SAMP(x),
+STDDEV_POP(x),
+STDDEV_SAMP(x)
+FROM src2 GROUP BY category]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[SUM($1)], EXPR$2=[AVG($1)], EXPR$3=[VAR_POP($1)], EXPR$4=[VAR_SAMP($1)], EXPR$5=[STDDEV_POP($1)], EXPR$6=[STDDEV_SAMP($1)])
++- LogicalProject(category=[$1], x=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+LogicalProject(category=[$0], EXPR$1=[CASE(=($2, 0), null:DOUBLE, $1)], EXPR$2=[/($3, $2)], EXPR$3=[/($4, $2)], EXPR$4=[/($4, CASE(=($2, 1), null:BIGINT, -($2, 1)))], EXPR$5=[POWER(/($4, $2), 0.5:DECIMAL(2, 1))], EXPR$6=[POWER(/($4, CASE(=($2, 1), null:BIGINT, -($2, 1))), 0.5:DECIMAL(2, 1))])
++- LogicalProject(category=[$0], EXPR$1=[$1], $f2=[$2], $f3=[CASE(=($2, 0), null:DOUBLE, $1)], $f4=[$3])
+ +- LogicalAggregate(group=[{0}], EXPR$1=[$SUM0($1)], agg#1=[COUNT($1)], agg#2=[$WELFORD_M2$1($1)])
+ +- LogicalProject(category=[$1], x=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testComplexExpressionsInAggregates">
+ <Resource name="sql">
+ <![CDATA[SELECT category,
+AVG(x + y),
+VAR_POP(x * 2),
+STDDEV_SAMP(CASE WHEN x > 0 THEN x ELSE 0 END)
+FROM src2 GROUP BY category]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[AVG($1)], EXPR$2=[VAR_POP($2)], EXPR$3=[STDDEV_SAMP($3)])
++- LogicalProject(category=[$1], $f1=[+($2, $3)], $f2=[*($2, 2)], $f3=[CASE(>($2, 0), $2, 0:DOUBLE)])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+LogicalProject(category=[$0], EXPR$1=[/($1, $2)], EXPR$2=[/($3, $4)], EXPR$3=[POWER(/($5, CASE(=($6, 1), null:BIGINT, -($6, 1))), 0.5:DECIMAL(2, 1))])
++- LogicalProject(category=[$0], $f1=[CASE(=($2, 0), null:DOUBLE, $1)], $f2=[$2], $f3=[$3], $f4=[$4], $f5=[$5], $f6=[$6])
+ +- LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($1)], agg#2=[$WELFORD_M2$1($2)], agg#3=[COUNT($2)], agg#4=[$WELFORD_M2$1($3)], agg#5=[COUNT($3)])
+ +- LogicalProject(category=[$1], $f1=[+($2, $3)], $f2=[*($2, 2)], $f3=[CASE(>($2, 0), $2, 0:DOUBLE)])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testComplexFilterConditions">
+ <Resource name="sql">
+ <![CDATA[SELECT category,
+AVG(x) FILTER (WHERE x > y AND category LIKE 'A%'),
+STDDEV_POP(x) FILTER (WHERE x + y > 100),
+VAR_SAMP(y) FILTER (WHERE MOD(id, 2) = 0)
+FROM src2 GROUP BY category]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[AVG($1) FILTER $2], EXPR$2=[STDDEV_POP($1) FILTER $3], EXPR$3=[VAR_SAMP($4) FILTER $5])
++- LogicalProject(category=[$1], x=[$2], $f2=[IS TRUE(AND(>($2, $3), LIKE($1, _UTF-16LE'A%')))], $f3=[IS TRUE(>(+($2, $3), 100))], y=[$3], $f5=[IS TRUE(=(MOD($0, 2), 0))])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+LogicalProject(category=[$0], EXPR$1=[/($1, $2)], EXPR$2=[POWER(/($3, $4), 0.5:DECIMAL(2, 1))], EXPR$3=[/($5, CASE(=($6, 1), null:BIGINT, -($6, 1)))])
++- LogicalProject(category=[$0], $f1=[CASE(=($2, 0), null:DOUBLE, $1)], $f2=[$2], $f3=[$3], $f4=[$4], $f5=[$5], $f6=[$6])
+ +- LogicalAggregate(group=[{0}], agg#0=[$SUM0($1) FILTER $2], agg#1=[COUNT($1) FILTER $2], agg#2=[$WELFORD_M2$1($1) FILTER $3], agg#3=[COUNT($1) FILTER $3], agg#4=[$WELFORD_M2$1($4) FILTER $5], agg#5=[COUNT($4) FILTER $5])
+ +- LogicalProject(category=[$1], x=[$2], $f2=[IS TRUE(AND(>($2, $3), LIKE($1, _UTF-16LE'A%')))], $f3=[IS TRUE(>(+($2, $3), 100))], y=[$3], $f5=[IS TRUE(=(MOD($0, 2), 0))])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testDifferentDataTypes">
+ <Resource name="sql">
+ <![CDATA[SELECT category,
+AVG(CAST(id AS DOUBLE)),
+VAR_POP(CAST(x AS DECIMAL(20,4))),
+STDDEV_SAMP(z)
+FROM src2 GROUP BY category]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[AVG($1)], EXPR$2=[VAR_POP($2)], EXPR$3=[STDDEV_SAMP($3)])
++- LogicalProject(category=[$1], $f1=[CAST($0):DOUBLE], $f2=[CAST($2):DECIMAL(20, 4)], z=[$4])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+LogicalProject(category=[$0], EXPR$1=[/($1, $2)], EXPR$2=[CAST(/($3, $4)):DECIMAL(38, 6)], EXPR$3=[CAST(POWER(/($5, CASE(=($6, 1), null:BIGINT, -($6, 1))), 0.5:DECIMAL(2, 1))):DECIMAL(38, 6)])
++- LogicalProject(category=[$0], $f1=[CASE(=($2, 0), null:DOUBLE, $1)], $f2=[$2], $f3=[$3], $f4=[$4], $f5=[$5], $f6=[$6])
+ +- LogicalAggregate(group=[{0}], agg#0=[$SUM0($1)], agg#1=[COUNT($1)], agg#2=[$WELFORD_M2$1($2)], agg#3=[COUNT($2)], agg#4=[$WELFORD_M2$1($3)], agg#5=[COUNT($3)])
+ +- LogicalProject(category=[$1], $f1=[CAST($0):DOUBLE], $f2=[CAST($2):DECIMAL(20, 4)], z=[$4])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testEmptyGroupBy">
+ <Resource name="sql">
+ <![CDATA[SELECT
+AVG(x),
+VAR_POP(y),
+STDDEV_SAMP(z)
+FROM src2]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalAggregate(group=[{}], EXPR$0=[AVG($0)], EXPR$1=[VAR_POP($1)], EXPR$2=[STDDEV_SAMP($2)])
++- LogicalProject(x=[$2], y=[$3], z=[$4])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+LogicalProject(EXPR$0=[/($0, $1)], EXPR$1=[/($2, $3)], EXPR$2=[CAST(POWER(/($4, CASE(=($5, 1), null:BIGINT, -($5, 1))), 0.5:DECIMAL(2, 1))):DECIMAL(38, 6)])
++- LogicalProject($f0=[CASE(=($1, 0), null:DOUBLE, $0)], $f1=[$1], $f2=[$2], $f3=[$3], $f4=[$4], $f5=[$5])
+ +- LogicalAggregate(group=[{}], agg#0=[$SUM0($0)], agg#1=[COUNT($0)], agg#2=[$WELFORD_M2$1($1)], agg#3=[COUNT($1)], agg#4=[$WELFORD_M2$1($2)], agg#5=[COUNT($2)])
+ +- LogicalProject(x=[$2], y=[$3], z=[$4])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testMixedAggregatesWithDistinct">
+ <Resource name="sql">
+ <![CDATA[SELECT category,
+AVG(DISTINCT x),
+VAR_POP(DISTINCT y),
+STDDEV_SAMP(DISTINCT z)
+FROM src2 GROUP BY category]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[AVG(DISTINCT $1)], EXPR$2=[VAR_POP(DISTINCT $2)], EXPR$3=[STDDEV_SAMP(DISTINCT $3)])
++- LogicalProject(category=[$1], x=[$2], y=[$3], z=[$4])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+LogicalProject(category=[$0], EXPR$1=[/($1, $2)], EXPR$2=[/($3, $4)], EXPR$3=[CAST(POWER(/($5, CASE(=($6, 1), null:BIGINT, -($6, 1))), 0.5:DECIMAL(2, 1))):DECIMAL(38, 6)])
++- LogicalProject(category=[$0], $f1=[CASE(=($2, 0), null:DOUBLE, $1)], $f2=[$2], $f3=[$3], $f4=[$4], $f5=[$5], $f6=[$6])
+ +- LogicalAggregate(group=[{0}], agg#0=[$SUM0(DISTINCT $1)], agg#1=[COUNT(DISTINCT $1)], agg#2=[$WELFORD_M2$1(DISTINCT $2)], agg#3=[COUNT(DISTINCT $2)], agg#4=[$WELFORD_M2$1(DISTINCT $3)], agg#5=[COUNT(DISTINCT $3)])
+ +- LogicalProject(category=[$1], x=[$2], y=[$3], z=[$4])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testMultipleGroupingColumns">
+ <Resource name="sql">
+ <![CDATA[SELECT category, id,
+AVG(x),
+VAR_SAMP(y),
+VAR_POP(z)
+FROM src2 GROUP BY category, id]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalAggregate(group=[{0, 1}], EXPR$2=[AVG($2)], EXPR$3=[VAR_SAMP($3)],
EXPR$4=[VAR_POP($4)])
++- LogicalProject(category=[$1], id=[$0], x=[$2], y=[$3], z=[$4])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+LogicalProject(category=[$0], id=[$1], EXPR$2=[/($2, $3)], EXPR$3=[/($4,
CASE(=($5, 1), null:BIGINT, -($5, 1)))], EXPR$4=[CAST(/($6, $7)):DECIMAL(38,
6)])
++- LogicalProject(category=[$0], id=[$1], $f2=[CASE(=($3, 0), null:DOUBLE,
$2)], $f3=[$3], $f4=[$4], $f5=[$5], $f6=[$6], $f7=[$7])
+ +- LogicalAggregate(group=[{0, 1}], agg#0=[$SUM0($2)], agg#1=[COUNT($2)],
agg#2=[$WELFORD_M2$1($3)], agg#3=[COUNT($3)], agg#4=[$WELFORD_M2$1($4)],
agg#5=[COUNT($4)])
+ +- LogicalProject(category=[$1], id=[$0], x=[$2], y=[$3], z=[$4])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ </TestCase>
+ <TestCase name="testVarianceAndStandardDeviation">
+ <Resource name="sql">
+ <![CDATA[SELECT category,
+VAR_POP(x),
+VAR_SAMP(x),
+STDDEV_POP(x),
+STDDEV_SAMP(x)
+FROM src2 GROUP BY category]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalAggregate(group=[{0}], EXPR$1=[VAR_POP($1)], EXPR$2=[VAR_SAMP($1)],
EXPR$3=[STDDEV_POP($1)], EXPR$4=[STDDEV_SAMP($1)])
++- LogicalProject(category=[$1], x=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+LogicalProject(category=[$0], EXPR$1=[/($1, $2)], EXPR$2=[/($1, CASE(=($2, 1),
null:BIGINT, -($2, 1)))], EXPR$3=[POWER(/($1, $2), 0.5:DECIMAL(2, 1))],
EXPR$4=[POWER(/($1, CASE(=($2, 1), null:BIGINT, -($2, 1))), 0.5:DECIMAL(2, 1))])
++- LogicalAggregate(group=[{0}], agg#0=[$WELFORD_M2$1($1)], agg#1=[COUNT($1)])
+ +- LogicalProject(category=[$1], x=[$2])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src2]])
+]]>
+ </Resource>
+ </TestCase>
<TestCase name="testVarianceStddevWithFilter">
<Resource name="sql">
<![CDATA[SELECT a,
@@ -35,12 +294,11 @@ LogicalAggregate(group=[{0}], EXPR$1=[STDDEV_POP($1)
FILTER $2], EXPR$2=[STDDEV_
</Resource>
<Resource name="optimized rel plan">
<![CDATA[
-LogicalProject(a=[$0], EXPR$1=[CAST(POWER(/(-($1, /(*($2, $2), $3)), $3),
0.5:DECIMAL(2, 1))):BIGINT], EXPR$2=[CAST(POWER(/(-($4, /(*($5, $5), $6)),
CASE(=($6, 1), null:BIGINT, -($6, 1))), 0.5:DECIMAL(2, 1))):BIGINT],
EXPR$3=[/(-($7, /(*($8, $8), $9)), $9)], EXPR$4=[/(-($10, /(*($11, $11), $12)),
CASE(=($12, 1), null:BIGINT, -($12, 1)))], EXPR$5=[/($13, $14)])
-+- LogicalProject(a=[$0], $f1=[CASE(=($2, 0), null:BIGINT, $1)],
$f2=[CASE(=($4, 0), null:BIGINT, $3)], $f3=[$4], $f4=[CASE(=($6, 0),
null:BIGINT, $5)], $f5=[CASE(=($8, 0), null:BIGINT, $7)], $f6=[$8],
$f7=[CASE(=($10, 0), null:BIGINT, $9)], $f8=[CASE(=($12, 0), null:BIGINT,
$11)], $f9=[$12], $f10=[CASE(=($14, 0), null:BIGINT, $13)], $f11=[CASE(=($16,
0), null:BIGINT, $15)], $f12=[$16], $f13=[CASE(=($18, 0), null:BIGINT, $17)],
$f14=[$18])
- +- LogicalAggregate(group=[{0}], agg#0=[$SUM0($7) FILTER $2],
agg#1=[COUNT($7) FILTER $2], agg#2=[$SUM0($1) FILTER $2], agg#3=[COUNT($1)
FILTER $2], agg#4=[$SUM0($7) FILTER $3], agg#5=[COUNT($7) FILTER $3],
agg#6=[$SUM0($1) FILTER $3], agg#7=[COUNT($1) FILTER $3], agg#8=[$SUM0($7)
FILTER $4], agg#9=[COUNT($7) FILTER $4], agg#10=[$SUM0($1) FILTER $4],
agg#11=[COUNT($1) FILTER $4], agg#12=[$SUM0($7) FILTER $5], agg#13=[COUNT($7)
FILTER $5], agg#14=[$SUM0($1) FILTER $5], agg#15=[COUNT($1 [...]
- +- LogicalProject(a=[$0], b=[$1], $f2=[$2], $f3=[$3], $f4=[$4],
$f5=[$5], $f6=[$6], $f7=[*($1, $1)])
- +- LogicalProject(a=[$0], b=[$1], $f2=[IS TRUE(>($1, 10))], $f3=[IS
TRUE(>($1, 20))], $f4=[IS TRUE(>($1, 30))], $f5=[IS TRUE(>($1, 40))], $f6=[IS
TRUE(>($1, 50))])
- +- LogicalTableScan(table=[[default_catalog, default_database,
src]])
+LogicalProject(a=[$0], EXPR$1=[CAST(POWER(/($1, $2), 0.5:DECIMAL(2,
1))):BIGINT], EXPR$2=[CAST(POWER(/($3, CASE(=($4, 1), null:BIGINT, -($4, 1))),
0.5:DECIMAL(2, 1))):BIGINT], EXPR$3=[CAST(/($5, $6)):BIGINT],
EXPR$4=[CAST(/($7, CASE(=($8, 1), null:BIGINT, -($8, 1)))):BIGINT],
EXPR$5=[/($9, $10)])
++- LogicalProject(a=[$0], $f1=[$1], $f2=[$2], $f3=[$3], $f4=[$4], $f5=[$5],
$f6=[$6], $f7=[$7], $f8=[$8], $f9=[CASE(=($10, 0), null:BIGINT, $9)],
$f10=[$10])
+ +- LogicalAggregate(group=[{0}], agg#0=[$WELFORD_M2$1($1) FILTER $2],
agg#1=[COUNT($1) FILTER $2], agg#2=[$WELFORD_M2$1($1) FILTER $3],
agg#3=[COUNT($1) FILTER $3], agg#4=[$WELFORD_M2$1($1) FILTER $4],
agg#5=[COUNT($1) FILTER $4], agg#6=[$WELFORD_M2$1($1) FILTER $5],
agg#7=[COUNT($1) FILTER $5], agg#8=[$SUM0($1) FILTER $6], agg#9=[COUNT($1)
FILTER $6])
+ +- LogicalProject(a=[$0], b=[$1], $f2=[IS TRUE(>($1, 10))], $f3=[IS
TRUE(>($1, 20))], $f4=[IS TRUE(>($1, 30))], $f5=[IS TRUE(>($1, 40))], $f6=[IS
TRUE(>($1, 50))])
+ +- LogicalTableScan(table=[[default_catalog, default_database, src]])
]]>
</Resource>
</TestCase>
diff --git
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceGroupingRuleTest.xml
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceGroupingRuleTest.xml
index 77f9b7e23f1..39dfc28b5ce 100644
---
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceGroupingRuleTest.xml
+++
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AggregateReduceGroupingRuleTest.xml
@@ -384,10 +384,9 @@ LogicalAggregate(group=[{0, 1, 2}], EXPR$3=[COUNT($3)])
<![CDATA[
FlinkLogicalCalc(select=[a4, c4, s, EXPR$3])
+- FlinkLogicalAggregate(group=[{0, 2}], c4=[AUXILIARY_GROUP($1)],
EXPR$3=[COUNT($3)])
- +- FlinkLogicalCalc(select=[a4, c4, w$start AS s, CAST(/(-($f2, /(*($f3,
$f3), $f4)), $f4) AS INTEGER) AS b4])
- +- FlinkLogicalWindowAggregate(group=[{0}], c4=[AUXILIARY_GROUP($1)],
agg#1=[SUM($4)], agg#2=[SUM($3)], agg#3=[COUNT($3)],
window=[TumblingGroupWindow('w$, d4, 900000)], properties=[w$start, w$end,
w$rowtime])
- +- FlinkLogicalCalc(select=[a4, c4, d4, b4, *(b4, b4) AS $f4])
- +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog,
default_database, T4, source: [TestTableSource(a4, b4, c4, d4)]]], fields=[a4,
b4, c4, d4])
+ +- FlinkLogicalCalc(select=[a4, c4, w$start AS s, CAST(/($f2, $f3) AS
INTEGER) AS b4])
+ +- FlinkLogicalWindowAggregate(group=[{0}], c4=[AUXILIARY_GROUP($2)],
agg#1=[$WELFORD_M2$1($1)], agg#2=[COUNT($1)], window=[TumblingGroupWindow('w$,
d4, 900000)], properties=[w$start, w$end, w$rowtime])
+ +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog,
default_database, T4, source: [TestTableSource(a4, b4, c4, d4)]]], fields=[a4,
b4, c4, d4])
]]>
</Resource>
</TestCase>
@@ -409,10 +408,9 @@ LogicalAggregate(group=[{0, 1, 2}], EXPR$3=[COUNT($3)])
<![CDATA[
FlinkLogicalCalc(select=[a4, c4, e, EXPR$3])
+- FlinkLogicalAggregate(group=[{0, 2}], c4=[AUXILIARY_GROUP($1)],
EXPR$3=[COUNT($3)])
- +- FlinkLogicalCalc(select=[a4, c4, w$end AS e, CAST(/(-($f2, /(*($f3,
$f3), $f4)), $f4) AS INTEGER) AS b4])
- +- FlinkLogicalWindowAggregate(group=[{0}], c4=[AUXILIARY_GROUP($1)],
agg#1=[SUM($4)], agg#2=[SUM($3)], agg#3=[COUNT($3)],
window=[TumblingGroupWindow('w$, d4, 900000)], properties=[w$start, w$end,
w$rowtime])
- +- FlinkLogicalCalc(select=[a4, c4, d4, b4, *(b4, b4) AS $f4])
- +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog,
default_database, T4, source: [TestTableSource(a4, b4, c4, d4)]]], fields=[a4,
b4, c4, d4])
+ +- FlinkLogicalCalc(select=[a4, c4, w$end AS e, CAST(/($f2, $f3) AS
INTEGER) AS b4])
+ +- FlinkLogicalWindowAggregate(group=[{0}], c4=[AUXILIARY_GROUP($2)],
agg#1=[$WELFORD_M2$1($1)], agg#2=[COUNT($1)], window=[TumblingGroupWindow('w$,
d4, 900000)], properties=[w$start, w$end, w$rowtime])
+ +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog,
default_database, T4, source: [TestTableSource(a4, b4, c4, d4)]]], fields=[a4,
b4, c4, d4])
]]>
</Resource>
</TestCase>
@@ -433,10 +431,9 @@ LogicalAggregate(group=[{0, 1, 2}], EXPR$3=[COUNT()])
<Resource name="optimized rel plan">
<![CDATA[
FlinkLogicalAggregate(group=[{0, 1}], c4=[AUXILIARY_GROUP($2)],
EXPR$3=[COUNT()])
-+- FlinkLogicalCalc(select=[a4, CAST(/(-($f2, /(*($f3, $f3), $f4)), $f4) AS
INTEGER) AS b4, c4])
- +- FlinkLogicalWindowAggregate(group=[{0}], c4=[AUXILIARY_GROUP($1)],
agg#1=[SUM($4)], agg#2=[SUM($3)], agg#3=[COUNT($3)],
window=[TumblingGroupWindow('w$, d4, 900000)], properties=[w$start, w$end,
w$rowtime])
- +- FlinkLogicalCalc(select=[a4, c4, d4, b4, *(b4, b4) AS $f4])
- +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog,
default_database, T4, source: [TestTableSource(a4, b4, c4, d4)]]], fields=[a4,
b4, c4, d4])
++- FlinkLogicalCalc(select=[a4, CAST(/($f2, $f3) AS INTEGER) AS b4, c4])
+ +- FlinkLogicalWindowAggregate(group=[{0}], c4=[AUXILIARY_GROUP($2)],
agg#1=[$WELFORD_M2$1($1)], agg#2=[COUNT($1)], window=[TumblingGroupWindow('w$,
d4, 900000)], properties=[w$start, w$end, w$rowtime])
+ +- FlinkLogicalLegacyTableSourceScan(table=[[default_catalog,
default_database, T4, source: [TestTableSource(a4, b4, c4, d4)]]], fields=[a4,
b4, c4, d4])
]]>
</Resource>
</TestCase>
diff --git
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/GroupWindowTest.xml
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/GroupWindowTest.xml
index cdd14d681b1..2ac73a6c021 100644
---
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/GroupWindowTest.xml
+++
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/sql/agg/GroupWindowTest.xml
@@ -40,10 +40,10 @@ LogicalProject(EXPR$0=[$1], EXPR$1=[$2], EXPR$2=[$3],
EXPR$3=[$4], EXPR$4=[TUMBL
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
-Calc(select=[(($f0 - (($f1 * $f1) / $f2)) / $f2) AS EXPR$0, (($f0 - (($f1 *
$f1) / $f2)) / CASE(($f2 = 1), null:BIGINT, ($f2 - 1))) AS EXPR$1,
CAST(POWER((($f0 - (($f1 * $f1) / $f2)) / $f2), 0.5) AS BIGINT) AS EXPR$2,
CAST(POWER((($f0 - (($f1 * $f1) / $f2)) / CASE(($f2 = 1), null:BIGINT, ($f2 -
1))), 0.5) AS BIGINT) AS EXPR$3, w$start AS EXPR$4, w$end AS EXPR$5])
-+- GroupWindowAggregate(window=[TumblingGroupWindow('w$, rowtime, 900000)],
properties=[w$start, w$end, w$rowtime, w$proctime], select=[SUM($f2) AS $f0,
SUM(c) AS $f1, COUNT(c) AS $f2, start('w$) AS w$start, end('w$) AS w$end,
rowtime('w$) AS w$rowtime, proctime('w$) AS w$proctime])
+Calc(select=[CAST(($f0 / $f1) AS BIGINT) AS EXPR$0, CAST(($f0 / CASE(($f1 =
1), null:BIGINT, ($f1 - 1))) AS BIGINT) AS EXPR$1, CAST(POWER(($f0 / $f1), 0.5)
AS BIGINT) AS EXPR$2, CAST(POWER(($f0 / CASE(($f1 = 1), null:BIGINT, ($f1 -
1))), 0.5) AS BIGINT) AS EXPR$3, w$start AS EXPR$4, w$end AS EXPR$5])
++- GroupWindowAggregate(window=[TumblingGroupWindow('w$, rowtime, 900000)],
properties=[w$start, w$end, w$rowtime, w$proctime], select=[$WELFORD_M2$1(c) AS
$f0, COUNT(c) AS $f1, start('w$) AS w$start, end('w$) AS w$end, rowtime('w$) AS
w$rowtime, proctime('w$) AS w$proctime])
+- Exchange(distribution=[single])
- +- Calc(select=[rowtime, c, (c * c) AS $f2])
+ +- Calc(select=[rowtime, c])
+- DataStreamScan(table=[[default_catalog, default_database,
MyTable]], fields=[a, b, c, proctime, rowtime])
]]>
</Resource>
diff --git
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/table/GroupWindowTest.xml
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/table/GroupWindowTest.xml
index 8562c253d03..0fa3c70952f 100644
---
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/table/GroupWindowTest.xml
+++
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/stream/table/GroupWindowTest.xml
@@ -170,11 +170,10 @@ LogicalProject(EXPR$0=[$0], EXPR$1=[$1], EXPR$2=[$2],
EXPR$3=[$3], EXPR$4=[$4],
</Resource>
<Resource name="optimized exec plan">
<![CDATA[
-Calc(select=[(($f0 - (($f1 * $f1) / $f2)) / $f2) AS EXPR$0, (($f0 - (($f1 *
$f1) / $f2)) / CASE(($f2 = 1), null:BIGINT, ($f2 - 1))) AS EXPR$1,
CAST(POWER((($f0 - (($f1 * $f1) / $f2)) / $f2), 0.5) AS BIGINT) AS EXPR$2,
CAST(POWER((($f0 - (($f1 * $f1) / $f2)) / CASE(($f2 = 1), null:BIGINT, ($f2 -
1))), 0.5) AS BIGINT) AS EXPR$3, EXPR$4, EXPR$5])
-+- GroupWindowAggregate(window=[TumblingGroupWindow('w, rowtime, 900000)],
properties=[EXPR$4, EXPR$5], select=[SUM($f4) AS $f0, SUM(c) AS $f1, COUNT(c)
AS $f2, start('w) AS EXPR$4, end('w) AS EXPR$5])
+Calc(select=[CAST(($f0 / $f1) AS BIGINT) AS EXPR$0, CAST(($f0 / CASE(($f1 =
1), null:BIGINT, ($f1 - 1))) AS BIGINT) AS EXPR$1, CAST(POWER(($f0 / $f1), 0.5)
AS BIGINT) AS EXPR$2, CAST(POWER(($f0 / CASE(($f1 = 1), null:BIGINT, ($f1 -
1))), 0.5) AS BIGINT) AS EXPR$3, EXPR$4, EXPR$5])
++- GroupWindowAggregate(window=[TumblingGroupWindow('w, rowtime, 900000)],
properties=[EXPR$4, EXPR$5], select=[$WELFORD_M2$1(c) AS $f0, COUNT(c) AS $f1,
start('w) AS EXPR$4, end('w) AS EXPR$5])
+- Exchange(distribution=[single])
- +- Calc(select=[rowtime, a, b, c, (c * c) AS $f4])
- +- DataStreamScan(table=[[default_catalog, default_database, T1]],
fields=[rowtime, a, b, c])
+ +- DataStreamScan(table=[[default_catalog, default_database, T1]],
fields=[rowtime, a, b, c])
]]>
</Resource>
</TestCase>
diff --git
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala
index 206d25f86e9..ab009185050 100644
---
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala
+++
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/agg/AggregateITCaseBase.scala
@@ -34,12 +34,18 @@ import org.apache.flink.types.Row
import org.assertj.core.api.Assertions.assertThatThrownBy
import org.junit.jupiter.api.{BeforeEach, Test}
+import org.junit.jupiter.api.condition.DisabledIf
/** Aggregate IT case base class. */
abstract class AggregateITCaseBase(testName: String) extends BatchTestBase {
def prepareAggOp(): Unit
+ def isHashAggITCase: Boolean = {
+ // Disable variance-related tests for hash agg, because the accumulator of
+ // WELFORD_M2 is not fixed-length.
+ this.isInstanceOf[HashAggITCase]
+ }
+
@BeforeEach
override def before(): Unit = {
super.before()
@@ -772,6 +778,7 @@ abstract class AggregateITCaseBase(testName: String)
extends BatchTestBase {
)
}
+ @DisabledIf("isHashAggITCase")
@Test
def testStdDev(): Unit = {
// NOTE: if f0 is INT type, our stddev functions return INT.
@@ -782,6 +789,7 @@ abstract class AggregateITCaseBase(testName: String)
extends BatchTestBase {
)
}
+ @DisabledIf("isHashAggITCase")
@Test
def test1RowStdDev(): Unit = {
checkQuery(
@@ -790,6 +798,7 @@ abstract class AggregateITCaseBase(testName: String)
extends BatchTestBase {
Seq((0.0, null, null)))
}
+ @DisabledIf("isHashAggITCase")
@Test
def testVariance(): Unit = {
checkQuery(
@@ -798,6 +807,7 @@ abstract class AggregateITCaseBase(testName: String)
extends BatchTestBase {
Seq((0.25, 0.5, 0.5)))
}
+ @DisabledIf("isHashAggITCase")
@Test
def test1RowVariance(): Unit = {
checkQuery(
@@ -806,6 +816,7 @@ abstract class AggregateITCaseBase(testName: String)
extends BatchTestBase {
Seq((0.0, null, null)))
}
+ @DisabledIf("isHashAggITCase")
@Test
def testZeroStdDev(): Unit = {
val emptyTable = Seq[(Int, Int)]()
@@ -836,6 +847,7 @@ abstract class AggregateITCaseBase(testName: String)
extends BatchTestBase {
)
}
+ @DisabledIf("isHashAggITCase")
@Test
def testMoments(): Unit = {
checkQuery(
@@ -846,6 +858,7 @@ abstract class AggregateITCaseBase(testName: String)
extends BatchTestBase {
// todo: Spark has skewness() and kurtosis()
}
+ @DisabledIf("isHashAggITCase")
@Test
def testZeroMoments(): Unit = {
checkQuery(
@@ -856,6 +869,7 @@ abstract class AggregateITCaseBase(testName: String)
extends BatchTestBase {
// todo: Spark returns Double.NaN instead of null
}
+ @DisabledIf("isHashAggITCase")
@Test
def testNullMoments(): Unit = {
checkQuery(
diff --git
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/table/AggregationITCase.scala
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/table/AggregationITCase.scala
index 6adc96620af..88a26d49f1e 100644
---
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/table/AggregationITCase.scala
+++
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/batch/table/AggregationITCase.scala
@@ -394,12 +394,12 @@ class AggregationITCase extends BatchTestBase {
val expected =
"0,0,0," +
"0,0.5,0.5,0.500000000000000000," +
- "1,1,1," +
- "1,0.70710677,0.7071067811865476,0.707106781186547600," +
+ "0,0,0," +
+ "0,0.70710677,0.7071067811865476,0.707106781186547600," +
"0,0,0," +
"0,0.25,0.25,0.250000000000000000," +
- "1,1,1," +
- "1,0.5,0.5,0.500000000000000000"
+ "0,0,0," +
+ "0,0.5,0.5,0.500000000000000000"
val results = executeQuery(res)
TestBaseUtils.compareResultAsText(results.asJava, expected)
}
diff --git
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
index 2ae87808ce9..3d0e189294e 100644
---
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
+++
b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/runtime/stream/sql/AggregateITCase.scala
@@ -1566,7 +1566,7 @@ class AggregateITCase(
tEnv.sqlQuery(sqlQuery).toRetractStream[Row].addSink(sink)
env.execute()
// TODO: define precise behavior of VAR_POP()
- val expected = List(15602500.toString, 28889.toString)
+ val expected = List(15602500.toString, 28888.toString)
assertThat(sink.getRetractResults.sorted).isEqualTo(expected.sorted)
}
diff --git
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/WelfordM2AggFunction.java
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/WelfordM2AggFunction.java
new file mode 100644
index 00000000000..82e6f324f62
--- /dev/null
+++
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/functions/aggregate/WelfordM2AggFunction.java
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.functions.aggregate;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.data.DecimalData;
+import org.apache.flink.table.data.DecimalDataUtils;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.LogicalType;
+
+import javax.annotation.Nullable;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+
+import static
org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataType;
+
+/**
+ * Internal built-in WELFORD_M2 aggregate function that calculates the M2 term of
+ * Welford's online algorithm. It is a helper function used when rewriting
+ * variance-related functions; the exact recurrences are sketched in the comments
+ * of the methods below.
+ *
+ * @see <a
+ *
href="https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm">Welford's
+ * online algorithm</a>
+ */
+@Internal
+public abstract class WelfordM2AggFunction
+ extends BuiltInAggregateFunction<Double,
WelfordM2AggFunction.WelfordM2Accumulator> {
+
+ protected final transient DataType valueType;
+
+ public WelfordM2AggFunction(LogicalType inputType) {
+ this.valueType = toInternalDataType(inputType);
+ }
+
+ @Override
+ public List<DataType> getArgumentDataTypes() {
+ return Collections.singletonList(valueType);
+ }
+
+ @Override
+ public DataType getAccumulatorDataType() {
+ return DataTypes.STRUCTURED(
+ WelfordM2Accumulator.class,
+ DataTypes.FIELD("n", DataTypes.BIGINT().notNull()),
+ DataTypes.FIELD("mean", DataTypes.DOUBLE().notNull()),
+ DataTypes.FIELD("m2", DataTypes.DOUBLE().notNull()));
+ }
+
+ @Override
+ public DataType getOutputDataType() {
+ return DataTypes.DOUBLE();
+ }
+
+ //
--------------------------------------------------------------------------------------------
+ // Accumulator
+ //
--------------------------------------------------------------------------------------------
+
+ /** Accumulator for WELFORD_M2. */
+ public static class WelfordM2Accumulator {
+
+ public long n = 0L;
+ public double mean = 0.0D;
+ public double m2 = 0.0D;
+
+ @Override
+ public boolean equals(Object obj) {
+ if (this == obj) {
+ return true;
+ }
+ if (obj == null || getClass() != obj.getClass()) {
+ return false;
+ }
+ WelfordM2Accumulator that = (WelfordM2Accumulator) obj;
+ return n == that.n
+ && Double.compare(mean, that.mean) == 0
+ && Double.compare(m2, that.m2) == 0;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(n, mean, m2);
+ }
+ }
+
+ @Override
+ public WelfordM2Accumulator createAccumulator() {
+ return new WelfordM2Accumulator();
+ }
+
+ public void resetAccumulator(WelfordM2Accumulator acc) {
+ acc.n = 0L;
+ acc.mean = 0.0D;
+ acc.m2 = 0.0D;
+ }
+
+ @Override
+ public Double getValue(WelfordM2Accumulator acc) {
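+ // Return null for empty or invalid groups (n <= 0), or if floating-point
+ // rounding during retraction drove m2 negative; the planner then derives
+ // VAR_POP/VAR_SAMP from this result by dividing by the matching COUNT.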
+ return acc.n <= 0 || acc.m2 < 0 ? null : acc.m2;
+ }
+
+ //
--------------------------------------------------------------------------------------------
+ // Runtime
+ //
--------------------------------------------------------------------------------------------
+
+ protected abstract double doubleValue(Object value);
+
+ public void accumulate(WelfordM2Accumulator acc, @Nullable Object value) {
+ if (value == null) {
+ return;
+ }
+
+ acc.n += 1;
+ // Skip the mean/m2 update while acc.n <= 0 (the count can go negative after
+ // excess retractions or merges), but keep counting to align with the outer COUNT.
+ if (acc.n <= 0) {
+ return;
+ }
+
+ double val = doubleValue(value);
+ double delta = val - acc.mean;
+ acc.mean += delta / acc.n;
+ double delta2 = val - acc.mean;
+ acc.m2 += delta * delta2;
+ }
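+
+ // The update above is Welford's recurrence: for the k-th non-null value x,
+ //   delta  = x - mean_{k-1}
+ //   mean_k = mean_{k-1} + delta / k
+ //   m2_k   = m2_{k-1} + delta * (x - mean_k)
+ // keeping m2 = sum((x_i - mean)^2) in a single pass, so the planner can derive
+ // VAR_POP = m2 / n and VAR_SAMP = m2 / (n - 1) without the catastrophic
+ // cancellation of the naive sum-of-squares formula.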
+
+ public void retract(WelfordM2Accumulator acc, @Nullable Object value) {
+ if (value == null) {
+ return;
+ }
+
+ acc.n -= 1;
+ // Skip the mean/m2 update when acc.n <= 0, but keep counting to align with
+ // the outer COUNT; once the count reaches exactly zero, reset the statistics.
+ if (acc.n <= 0) {
+ if (acc.n == 0) {
+ acc.mean = 0.0D;
+ acc.m2 = 0.0D;
+ }
+ return;
+ }
+
+ double val = doubleValue(value);
+ double delta2 = val - acc.mean;
+ acc.mean -= delta2 / acc.n;
+ double delta = val - acc.mean;
+ acc.m2 -= delta * delta2;
+ }
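+
+ // Retraction inverts the recurrence: with acc.n already decremented to n,
+ //   delta2   = x - mean_old
+ //   mean_new = mean_old - delta2 / n
+ //   m2_new   = m2_old - (x - mean_new) * delta2
+ // so accumulate(x) followed by retract(x) restores the accumulator exactly
+ // (up to floating-point rounding).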
+
+ public void merge(WelfordM2Accumulator acc, Iterable<WelfordM2Accumulator>
its) {
+ // Ignore accumulators with non-positive counts because they carry no valid
+ // statistics, but keep counting them to stay aligned with the outer COUNT.
+ // Fold negativeSum into acc.n only at the end, so that an intermediate
+ // negative total count cannot discard valid partial results.
+ long negativeSum = 0;
+ for (WelfordM2Accumulator other : its) {
+ if (other.n <= 0) {
+ negativeSum += other.n;
+ continue;
+ }
+
+ if (acc.n == 0) {
+ acc.n = other.n;
+ acc.mean = other.mean;
+ acc.m2 = other.m2;
+ continue;
+ }
+
+ long newCount = acc.n + other.n;
+ double deltaMean = other.mean - acc.mean;
+ double newMean = acc.mean + (double) other.n / newCount *
deltaMean;
+ double newM2 =
+ acc.m2 + other.m2 + (double) acc.n * other.n / newCount *
deltaMean * deltaMean;
+
+ acc.n = newCount;
+ acc.mean = newMean;
+ acc.m2 = newM2;
+ }
+
+ acc.n += negativeSum;
+ if (acc.n <= 0) {
+ acc.mean = 0.0D;
+ acc.m2 = 0.0D;
+ }
+ }
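+
+ // The pairwise combination above is the parallel variant of the algorithm:
+ // for partitions (n_a, mean_a, m2_a) and (n_b, mean_b, m2_b),
+ //   n    = n_a + n_b
+ //   mean = mean_a + (n_b / n) * (mean_b - mean_a)
+ //   m2   = m2_a + m2_b + (n_a * n_b / n) * (mean_b - mean_a)^2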
+
+ //
--------------------------------------------------------------------------------------------
+ // Sub-classes
+ //
--------------------------------------------------------------------------------------------
+
+ /** Implementation for numeric types excluding DECIMAL. */
+ public static class NumberFunction extends WelfordM2AggFunction {
+
+ public NumberFunction(LogicalType inputType) {
+ super(inputType);
+ }
+
+ @Override
+ protected double doubleValue(Object value) {
+ return ((Number) value).doubleValue();
+ }
+ }
+
+ /** Implementation for DECIMAL. */
+ public static class DecimalFunction extends WelfordM2AggFunction {
+
+ public DecimalFunction(LogicalType inputType) {
+ super(inputType);
+ }
+
+ @Override
+ protected double doubleValue(Object value) {
+ return DecimalDataUtils.doubleValue((DecimalData) value);
+ }
+ }
+}