This is an automated email from the ASF dual-hosted git repository.
snuyanzin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 61f9ffece97 [FLINK-37002][table] Migrate `DecomposeGroupingSetsRule` to java
61f9ffece97 is described below
commit 61f9ffece974134c8d7da8f87e3092b54af9a700
Author: Jacky Lau <[email protected]>
AuthorDate: Sun Jan 18 06:25:04 2026 +0800
[FLINK-37002][table] Migrate `DecomposeGroupingSetsRule` to java
---------
Co-authored-by: yongliu <[email protected]>
---
.../rules/logical/DecomposeGroupingSetsRule.java | 428 +++++++++++++++++++++
.../rules/logical/DecomposeGroupingSetsRule.scala | 338 ----------------
.../planner/utils/JavaScalaConversionUtil.scala | 8 +-
3 files changed, 435 insertions(+), 339 deletions(-)
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/DecomposeGroupingSetsRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/DecomposeGroupingSetsRule.java
new file mode 100644
index 00000000000..30cea62f71f
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/DecomposeGroupingSetsRule.java
@@ -0,0 +1,428 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.table.api.TableException;
+import org.apache.flink.table.planner.calcite.FlinkRelBuilder;
+import org.apache.flink.table.planner.calcite.FlinkRelFactories;
+import org.apache.flink.table.planner.plan.utils.AggregateUtil;
+import org.apache.flink.table.planner.plan.utils.ExpandUtil;
+import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.Stream;
+
+/**
+ * This rule rewrites an aggregation query with grouping sets into a regular aggregation query
+ * with expand.
+ *
+ * <p>This rule duplicates the input data two or more times (the number of groupSets plus an
+ * optional non-distinct group), which puts quite a bit of memory pressure on the aggregate and
+ * exchange operators.
+ *
+ * <p>This rule is applied to plans with grouping sets, and to plans with distinct aggregations
+ * after {@link FlinkAggregateExpandDistinctAggregatesRule} has been applied.
+ *
+ * <p>{@code FlinkAggregateExpandDistinctAggregatesRule} rewrites an aggregate query with distinct
+ * aggregations into an expanded double aggregation. The first aggregate has grouping sets in
+ * which the regular aggregation expressions and every distinct clause are aggregated in a
+ * separate group. The results are then combined in a second aggregate.
+ *
+ * <pre>Examples:
+ *
+ * MyTable: a: INT, b: BIGINT, c: VARCHAR(32), d: VARCHAR(32)
+ *
+ * Original records:
+ * | a | b | c | d |
+ * |:-:|:-:|:--:|:--:|
+ * | 1 | 1 | c1 | d1 |
+ * | 1 | 2 | c1 | d2 |
+ * | 2 | 1 | c1 | d1 |
+ *
+ * Example1 (expand for DISTINCT aggregates):
+ *
+ * SQL: SELECT a, SUM(DISTINCT b) as t1, COUNT(DISTINCT c) as t2, COUNT(d) as t3 FROM MyTable
+ * GROUP BY a
+ *
+ * Logical plan:
+ * {@code
+ * LogicalAggregate(group=[{0}], t1=[SUM(DISTINCT $1)], t2=[COUNT(DISTINCT $2)], t3=[COUNT($3)])
+ *   LogicalTableScan(table=[[builtin, default, MyTable]])
+ * }
+ *
+ * Logical plan after `FlinkAggregateExpandDistinctAggregatesRule` applied:
+ * {@code
+ * LogicalProject(a=[$0], t1=[$1], t2=[$2], t3=[CAST($3):BIGINT NOT NULL])
+ *   LogicalProject(a=[$0], t1=[$1], t2=[$2], $f3=[CASE(IS NOT NULL($3), $3, 0)])
+ *     LogicalAggregate(group=[{0}], t1=[SUM($1) FILTER $4], t2=[COUNT($2) FILTER $5],
+ *         t3=[MIN($3) FILTER $6])
+ *       LogicalProject(a=[$0], b=[$1], c=[$2], t3=[$3], $g_1=[=($4, 1)], $g_2=[=($4, 2)],
+ *           $g_3=[=($4, 3)])
+ *         LogicalAggregate(group=[{0, 1, 2}], groups=[[{0, 1}, {0, 2}, {0}]], t3=[COUNT($3)],
+ *             $g=[GROUPING($0, $1, $2)])
+ *           LogicalTableScan(table=[[builtin, default, MyTable]])
+ * }
+ *
+ * Logical plan after this rule applied:
+ * {@code
+ * LogicalCalc(expr#0..3=[{inputs}], expr#4=[IS NOT NULL($t3)], ...)
+ *   LogicalAggregate(group=[{0}], t1=[SUM($1) FILTER $4], t2=[COUNT($2) FILTER $5],
+ *       t3=[MIN($3) FILTER $6])
+ *     LogicalCalc(expr#0..4=[{inputs}], ... expr#10=[CASE($t6, $t5, $t8, $t7, $t9)],
+ *         expr#11=[1], expr#12=[=($t10, $t11)], ... $g_1=[$t12], ...)
+ *       LogicalAggregate(group=[{0, 1, 2, 4}], groups=[[]], t3=[COUNT($3)])
+ *         LogicalExpand(projects=[{a=[$0], b=[$1], c=[null], d=[$3], $e=[1]},
+ *             {a=[$0], b=[null], c=[$2], d=[$3], $e=[2]}, {a=[$0], b=[null], c=[null], d=[$3], $e=[3]}])
+ *           LogicalTableSourceScan(table=[[builtin, default, MyTable]], fields=[a, b, c, d])
+ * }
+ *
+ * '$e = 1' is equivalent to 'group by a, b'
+ * '$e = 2' is equivalent to 'group by a, c'
+ * '$e = 3' is equivalent to 'group by a'
+ *
+ * Expanded records:
+ * +-----+-----+-----+-----+-----+
+ * |  a  |  b  |  c  |  d  | $e  |
+ * +-----+-----+-----+-----+-----+                 ---+---
+ * |  1  |  1  | null| d1  |  1  |                    |
+ * +-----+-----+-----+-----+-----+                    |
+ * |  1  | null| c1  | d1  |  2  |     records expanded by record1
+ * +-----+-----+-----+-----+-----+                    |
+ * |  1  | null| null| d1  |  3  |                    |
+ * +-----+-----+-----+-----+-----+                 ---+---
+ * |  1  |  2  | null| d2  |  1  |                    |
+ * +-----+-----+-----+-----+-----+                    |
+ * |  1  | null| c1  | d2  |  2  |     records expanded by record2
+ * +-----+-----+-----+-----+-----+                    |
+ * |  1  | null| null| d2  |  3  |                    |
+ * +-----+-----+-----+-----+-----+                 ---+---
+ * |  2  |  1  | null| d1  |  1  |                    |
+ * +-----+-----+-----+-----+-----+                    |
+ * |  2  | null| c1  | d1  |  2  |     records expanded by record3
+ * +-----+-----+-----+-----+-----+                    |
+ * |  2  | null| null| d1  |  3  |                    |
+ * +-----+-----+-----+-----+-----+                 ---+---
+ *
+ * Example2 (Some fields are in both DISTINCT aggregates and non-DISTINCT aggregates):
+ *
+ * SQL: SELECT MAX(a) as t1, COUNT(DISTINCT a) as t2, count(DISTINCT d) as t3 FROM MyTable
+ *
+ * Field `a` is in both a DISTINCT aggregate and the `MAX` aggregate, so `a` should be output as
+ * two individual fields: one for the `MAX` aggregate, another for the DISTINCT aggregate.
+ *
+ * Expanded records:
+ * +-----+-----+-----+-----+
+ * |  a  |  d  | $e  | a_0 |
+ * +-----+-----+-----+-----+           ---+---
+ * |  1  | null|  1  |  1  |              |
+ * +-----+-----+-----+-----+              |
+ * | null| d1  |  2  |  1  |   records expanded by record1
+ * +-----+-----+-----+-----+              |
+ * | null| null|  3  |  1  |              |
+ * +-----+-----+-----+-----+           ---+---
+ * |  1  | null|  1  |  1  |              |
+ * +-----+-----+-----+-----+              |
+ * | null| d2  |  2  |  1  |   records expanded by record2
+ * +-----+-----+-----+-----+              |
+ * | null| null|  3  |  1  |              |
+ * +-----+-----+-----+-----+           ---+---
+ * |  2  | null|  1  |  2  |              |
+ * +-----+-----+-----+-----+              |
+ * | null| d1  |  2  |  2  |   records expanded by record3
+ * +-----+-----+-----+-----+              |
+ * | null| null|  3  |  2  |              |
+ * +-----+-----+-----+-----+           ---+---
+ *
+ * Example3 (expand for CUBE/ROLLUP/GROUPING SETS):
+ *
+ * SQL: SELECT a, c, SUM(b) as b FROM MyTable GROUP BY GROUPING SETS (a, c)
+ *
+ * Logical plan:
+ * {@code
+ * LogicalAggregate(group=[{0, 1}], groups=[[{0}, {1}]], b=[SUM($2)])
+ *   LogicalProject(a=[$0], c=[$2], b=[$1])
+ *     LogicalTableScan(table=[[builtin, default, MyTable]])
+ * }
+ *
+ * Logical plan after this rule applied:
+ * {@code
+ * LogicalCalc(expr#0..3=[{inputs}], proj#0..1=[{exprs}], b=[$t3])
+ *   LogicalAggregate(group=[{0, 2, 3}], groups=[[]], b=[SUM($1)])
+ *     LogicalExpand(projects=[{a=[$0], b=[$1], c=[null], $e=[1]},
+ *         {a=[null], b=[$1], c=[$2], $e=[2]}])
+ *       LogicalNativeTableScan(table=[[builtin, default, MyTable]])
+ * }
+ *
+ * '$e = 1' is equivalent to 'group by a'
+ * '$e = 2' is equivalent to 'group by c'
+ *
+ * Expanded records:
+ * +-----+-----+-----+-----+
+ * |  a  |  b  |  c  | $e  |
+ * +-----+-----+-----+-----+           ---+---
+ * |  1  |  1  | null|  1  |              |
+ * +-----+-----+-----+-----+   records expanded by record1
+ * | null|  1  | c1  |  2  |              |
+ * +-----+-----+-----+-----+           ---+---
+ * |  1  |  2  | null|  1  |              |
+ * +-----+-----+-----+-----+   records expanded by record2
+ * | null|  2  | c1  |  2  |              |
+ * +-----+-----+-----+-----+           ---+---
+ * |  2  |  1  | null|  1  |              |
+ * +-----+-----+-----+-----+   records expanded by record3
+ * | null|  1  | c1  |  2  |              |
+ * +-----+-----+-----+-----+           ---+---
+ * </pre>
+ */
+@Value.Enclosing
+public class DecomposeGroupingSetsRule
+        extends RelRule<DecomposeGroupingSetsRule.DecomposeGroupingSetsRuleConfig> {
+ public static final DecomposeGroupingSetsRule INSTANCE =
+            DecomposeGroupingSetsRule.DecomposeGroupingSetsRuleConfig.DEFAULT.toRule();
+
+    protected DecomposeGroupingSetsRule(DecomposeGroupingSetsRuleConfig config) {
+ super(config);
+ }
+
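+    // The rule fires only when there is decomposition work to do: either the aggregate has
+    // more than one grouping set, or its agg calls contain GROUP_ID/GROUPING(_ID) expressions.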
+ @Override
+ public boolean matches(RelOptRuleCall call) {
+ LogicalAggregate agg = call.rel(0);
+ List<Object> groupIdExprs =
+ JavaScalaConversionUtil.toJava(
+ AggregateUtil.getGroupIdExprIndexes(
+                                JavaScalaConversionUtil.toScala(agg.getAggCallList())));
+ return agg.getGroupSets().size() > 1 || !groupIdExprs.isEmpty();
+ }
+
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+ LogicalAggregate agg = call.rel(0);
+        // Long data type is used to store groupValue in
+        // FlinkAggregateExpandDistinctAggregatesRule,
+ // and the result of grouping function is a positive value,
+ // so the max groupCount must be less than 64.
+ if (agg.getGroupCount() >= 64) {
+ throw new TableException("group count must be less than 64.");
+ }
+
+ RelNode aggInput = agg.getInput();
+ List<Object> groupIdExprs =
+ JavaScalaConversionUtil.toJava(
+ AggregateUtil.getGroupIdExprIndexes(
+                                JavaScalaConversionUtil.toScala(agg.getAggCallList())));
+ List<Tuple2<AggregateCall, Integer>> aggCallsWithIndexes =
+ IntStream.range(0, agg.getAggCallList().size())
+                        .mapToObj(i -> Tuple2.of(agg.getAggCallList().get(i), i))
+ .collect(Collectors.toList());
+
+ RelOptCluster cluster = agg.getCluster();
+ RexBuilder rexBuilder = cluster.getRexBuilder();
+ boolean needExpand = agg.getGroupSets().size() > 1;
+
+ FlinkRelBuilder relBuilder = (FlinkRelBuilder) call.builder();
+ relBuilder.push(aggInput);
+
+ ImmutableBitSet newGroupSet;
+ Map<Integer, Integer> duplicateFieldMap;
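+        // With multiple grouping sets, an Expand node is inserted below the aggregate: it emits
+        // each input row once per grouping set and tags every copy with the expand_id ('$e')
+        // value of that grouping set.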
+ if (needExpand) {
+            Tuple2<scala.collection.immutable.Map<Integer, Integer>, Integer> expandResult =
+                    JavaScalaConversionUtil.toJava(
+                            ExpandUtil.buildExpandNode(
+                                    relBuilder,
+                                    JavaScalaConversionUtil.toScala(agg.getAggCallList()),
+                                    agg.getGroupSet(),
+                                    agg.getGroupSets()));
+
+            // the new groupSet contains the original groupSet and the expand_id ('$e') field
+            newGroupSet = agg.getGroupSet().union(ImmutableBitSet.of(expandResult.f1));
+            duplicateFieldMap = JavaScalaConversionUtil.toJava(expandResult.f0);
+        } else {
+            // no need to add an expand node; we only need to take care of the group functions
+            newGroupSet = agg.getGroupSet();
+            duplicateFieldMap = new HashMap<>();
+ }
+
+ int newGroupCount = newGroupSet.cardinality();
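+        // Rewrite the remaining agg calls: GROUP_ID/GROUPING(_ID) calls are filtered out here
+        // (they become literals in the final projection), and argument/filter indexes are
+        // remapped to the fields duplicated by the expand node.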
+        List<AggregateCall> newAggCalls =
+                aggCallsWithIndexes.stream()
+                        .filter(p -> !groupIdExprs.contains(p.f1))
+                        .map(
+                                p -> {
+                                    AggregateCall aggCall = p.f0;
+                                    List<Integer> newArgList =
+                                            aggCall.getArgList().stream()
+                                                    .map(a -> duplicateFieldMap.getOrDefault(a, a))
+                                                    .collect(Collectors.toList());
+                                    int newFilterArg =
+                                            duplicateFieldMap.getOrDefault(
+                                                    aggCall.filterArg, aggCall.filterArg);
+                                    return aggCall.adaptTo(
+                                            relBuilder.peek(),
+                                            newArgList,
+                                            newFilterArg,
+                                            agg.getGroupCount(),
+                                            newGroupCount);
+                                })
+                        .collect(Collectors.toList());
+
+ // create simple aggregate
+        relBuilder.aggregate(relBuilder.groupKey(newGroupSet, List.of(newGroupSet)), newAggCalls);
+ RelNode newAgg = relBuilder.peek();
+
+        // create a project that maps to the original aggregate's output
+        // get the names of the original grouping fields
+ List<String> groupingFieldsName =
+ IntStream.range(0, agg.getGroupCount())
+ .mapToObj(x -> agg.getRowType().getFieldNames().get(x))
+ .collect(Collectors.toList());
+
+ // create field access for all original grouping fields
+ List<RexNode> groupingFields =
+ IntStream.range(0, agg.getGroupSet().cardinality())
+ .mapToObj(idx -> rexBuilder.makeInputRef(newAgg, idx))
+ .collect(Collectors.toList());
+
+ List<Tuple2<ImmutableBitSet, Integer>> groupSetsWithIndexes =
+ IntStream.range(0, agg.getGroupSets().size())
+ .mapToObj(i -> Tuple2.of(agg.getGroupSets().get(i), i))
+ .collect(Collectors.toList());
+        // output aggregate calls, including `normal` agg calls and grouping agg calls
+ int aggCnt = 0;
+ List<RexNode> aggFields = new ArrayList<>();
+        for (Tuple2<AggregateCall, Integer> aggCallWithIndex : aggCallsWithIndexes) {
+ AggregateCall aggCall = aggCallWithIndex.f0;
+ int idx = aggCallWithIndex.f1;
+ if (groupIdExprs.contains(idx)) {
+ if (needExpand) {
+ // reference to expand_id('$e') field in new aggregate
+ int expandIdIdxInNewAgg = newGroupCount - 1;
+                    RexInputRef expandIdField =
+                            rexBuilder.makeInputRef(newAgg, expandIdIdxInNewAgg);
+ // create case when for group expression
+ List<RexNode> whenThenElse = new ArrayList<>();
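+                    // the generated expression has the shape:
+                    //   CASE WHEN $e = <id of set 1> THEN <literal for set 1>
+                    //        WHEN $e = <id of set 2> THEN <literal for set 2>
+                    //        ... ELSE <literal for the last set> END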
+ for (int i = 0; i < groupSetsWithIndexes.size(); i++) {
+                        Tuple2<ImmutableBitSet, Integer> tuple = groupSetsWithIndexes.get(i);
+                        RexNode groupExpr =
+                                lowerGroupExpr(rexBuilder, aggCall, groupSetsWithIndexes, i);
+ if (i < groupSetsWithIndexes.size() - 1) {
+ // WHEN/THEN
+                            long expandIdVal =
+                                    ExpandUtil.genExpandId(agg.getGroupSet(), tuple.f0);
+                            RelDataType expandIdType =
+                                    newAgg.getRowType()
+                                            .getFieldList()
+                                            .get(expandIdIdxInNewAgg)
+                                            .getType();
+                            RexNode expandIdLit =
+                                    rexBuilder.makeLiteral(expandIdVal, expandIdType, false);
+                            RexNode condition =
+                                    rexBuilder.makeCall(
+                                            SqlStdOperatorTable.EQUALS, expandIdField, expandIdLit);
+ // when $e = $e_value
+ whenThenElse.add(condition);
+ // then return group expression literal value
+ whenThenElse.add(groupExpr);
+ } else {
+ // else return group expression literal value
+ whenThenElse.add(groupExpr);
+ }
+ }
+                    aggFields.add(rexBuilder.makeCall(SqlStdOperatorTable.CASE, whenThenElse));
+ } else {
+ // create literal for group expression
+                aggFields.add(lowerGroupExpr(rexBuilder, aggCall, groupSetsWithIndexes, 0));
+ }
+ } else {
+ // create access to aggregation result
+                RexInputRef aggResult =
+                        rexBuilder.makeInputRef(newAgg, newGroupCount + aggCnt);
+ aggCnt++;
+ aggFields.add(aggResult);
+ }
+ }
+
+        // add a projection to establish the result schema and set the values of the group
+        // expressions.
+        RelNode project =
+                relBuilder
+                        .project(
+                                Stream.concat(groupingFields.stream(), aggFields.stream())
+                                        .collect(Collectors.toList()),
+                                Stream.concat(
+                                                groupingFieldsName.stream(),
+                                                agg.getAggCallList().stream()
+                                                        .map(AggregateCall::getName))
+                                        .collect(Collectors.toList()))
+                        .convert(agg.getRowType(), true)
+                        .build();
+
+ call.transformTo(project);
+ }
+
+ /** Returns a literal for a given group expression. */
+ private RexNode lowerGroupExpr(
+ RexBuilder builder,
+ AggregateCall call,
+ List<Tuple2<ImmutableBitSet, Integer>> groupSetsWithIndexes,
+ int indexInGroupSets) {
+
+        ImmutableBitSet groupSet = groupSetsWithIndexes.get(indexInGroupSets).f0;
+ Set<Integer> groups = groupSet.asSet();
+
+ switch (call.getAggregation().getKind()) {
+            case GROUP_ID:
+                // https://issues.apache.org/jira/browse/CALCITE-1824
+                // GROUP_ID is not in the SQL standard. It is implemented only by Oracle.
+                // GROUP_ID is useful only if you have duplicate grouping sets.
+                // If the grouping sets are distinct, GROUP_ID() always returns zero;
+                // otherwise it returns the index in the duplicate grouping sets.
+                // e.g. SELECT deptno, GROUP_ID() AS g FROM Emp
+                //      GROUP BY GROUPING SETS (deptno, (), ())
+                // As you can see, the grouping set () occurs twice.
+                // So there is one row in the result for each occurrence:
+                // the first occurrence has g = 0; the second has g = 1.
+
+ List<Integer> duplicateGroupSetsIndices =
+ groupSetsWithIndexes.stream()
+ .filter(p -> p.f0.compareTo(groupSet) == 0)
+ .map(tuple2 -> tuple2.f1)
+ .collect(Collectors.toList());
+                Preconditions.checkArgument(
+                        !duplicateGroupSetsIndices.isEmpty(), "requirement failed");
+ long id = duplicateGroupSetsIndices.indexOf(indexInGroupSets);
+ return builder.makeLiteral(id, call.getType(), false);
+ case GROUPING:
+ case GROUPING_ID:
+                // The GROUPING function is defined in the SQL standard, but its definition in
+                // Oracle differs from the SQL standard:
+                // https://docs.oracle.com/cd/B28359_01/server.111/b28286/functions064.htm#SQLRF00647
+                //
+                // The GROUPING_ID function is not defined in the SQL standard, and has the same
+                // functionality as the GROUPING function in Calcite.
+                // Our implementation is consistent with Oracle regarding GROUPING_ID.
+                //
+                // NOTES:
+                // In Calcite, the javadoc of SqlGroupingFunction is not consistent with agg.iq.
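+                // e.g. GROUPING($0, $1, $2) with the current grouping set {0, 2} yields binary
+                // 010 = 2: each argument contributes one bit, 0 if grouped, 1 if not.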
+ long res = 0L;
+ for (Integer arg : call.getArgList()) {
+ res = (res << 1L) + (groups.contains(arg) ? 0L : 1L);
+ }
+ return builder.makeLiteral(res, call.getType(), false);
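+            // lowerGroupExpr is only invoked for GROUP_ID/GROUPING(_ID) calls; any other
+            // kind falls through to a NULL literal as a defensive default.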
+ default:
+ return builder.makeNullLiteral(call.getType());
+ }
+ }
+
+ /** Rule configuration. */
+ @Value.Immutable(singleton = false)
+ public interface DecomposeGroupingSetsRuleConfig extends RelRule.Config {
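+        // `ImmutableDecomposeGroupingSetsRule` is generated at compile time by the Immutables
+        // annotation processor from the @Value annotations on this rule.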
+        DecomposeGroupingSetsRule.DecomposeGroupingSetsRuleConfig DEFAULT =
+                ImmutableDecomposeGroupingSetsRule.DecomposeGroupingSetsRuleConfig.builder()
+                        .operandSupplier(b0 -> b0.operand(LogicalAggregate.class).anyInputs())
+                        .relBuilderFactory(FlinkRelFactories.FLINK_REL_BUILDER())
+                        .description("DecomposeGroupingSetsRule")
+                        .build();
+
+ @Override
+ default DecomposeGroupingSetsRule toRule() {
+ return new DecomposeGroupingSetsRule(this);
+ }
+ }
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/DecomposeGroupingSetsRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/DecomposeGroupingSetsRule.scala
deleted file mode 100644
index 140c1d7c92c..00000000000
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/DecomposeGroupingSetsRule.scala
+++ /dev/null
@@ -1,338 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.planner.plan.rules.logical
-
-import org.apache.flink.table.api.TableException
-import org.apache.flink.table.planner.calcite.{FlinkRelBuilder, FlinkRelFactories}
-import org.apache.flink.table.planner.plan.utils.{AggregateUtil, ExpandUtil}
-
-import com.google.common.collect.ImmutableList
-import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall}
-import org.apache.calcite.plan.RelOptRule._
-import org.apache.calcite.rel.core.AggregateCall
-import org.apache.calcite.rel.logical.LogicalAggregate
-import org.apache.calcite.rex.{RexBuilder, RexNode}
-import org.apache.calcite.sql.SqlKind
-import org.apache.calcite.sql.fun.SqlStdOperatorTable
-import org.apache.calcite.util.ImmutableBitSet
-
-import scala.collection.JavaConversions._
-
-/**
- * This rule rewrites an aggregation query with grouping sets into an regular aggregation query
- * with expand.
- *
- * This rule duplicates the input data by two or more times (# number of groupSets + an optional
- * non-distinct group). This will put quite a bit of memory pressure of the used aggregate and
- * exchange operators.
- *
- * This rule will be used for the plan with grouping sets or the plan with distinct aggregations
- * after [[FlinkAggregateExpandDistinctAggregatesRule]] applied.
- *
- * `FlinkAggregateExpandDistinctAggregatesRule` rewrites an aggregate query with distinct
- * aggregations into an expanded double aggregation. The first aggregate has grouping sets in
- * which the regular aggregation expressions and every distinct clause are aggregated in a
- * separate group. The results are then combined in a second aggregate.
- *
- * Examples:
- *
- * MyTable: a: INT, b: BIGINT, c: VARCHAR(32), d: VARCHAR(32)
- *
- * Original records:
- * | a | b | c | d |
- * |:-:|:-:|:--:|:--:|
- * | 1 | 1 | c1 | d1 |
- * | 1 | 2 | c1 | d2 |
- * | 2 | 1 | c1 | d1 |
- *
- * Example1 (expand for DISTINCT aggregates):
- *
- * SQL: SELECT a, SUM(DISTINCT b) as t1, COUNT(DISTINCT c) as t2, COUNT(d) as t3 FROM MyTable
- * GROUP BY a
- *
- * Logical plan:
- * {{{
- * LogicalAggregate(group=[{0}], t1=[SUM(DISTINCT $1)], t2=[COUNT(DISTINCT $2)], t3=[COUNT($3)])
- *   LogicalTableScan(table=[[builtin, default, MyTable]])
- * }}}
- *
- * Logical plan after `FlinkAggregateExpandDistinctAggregatesRule` applied:
- * {{{
- * LogicalProject(a=[$0], t1=[$1], t2=[$2], t3=[CAST($3):BIGINT NOT NULL])
- *   LogicalProject(a=[$0], t1=[$1], t2=[$2], $f3=[CASE(IS NOT NULL($3), $3, 0)])
- *     LogicalAggregate(group=[{0}], t1=[SUM($1) FILTER $4], t2=[COUNT($2) FILTER $5],
- *         t3=[MIN($3) FILTER $6])
- *       LogicalProject(a=[$0], b=[$1], c=[$2], t3=[$3], $g_1=[=($4, 1)], $g_2=[=($4, 2)],
- *           $g_3=[=($4, 3)])
- *         LogicalAggregate(group=[{0, 1, 2}], groups=[[{0, 1}, {0, 2}, {0}]], t3=[COUNT($3)],
- *             $g=[GROUPING($0, $1, $2)])
- *           LogicalTableScan(table=[[builtin, default, MyTable]])
- * }}}
- *
- * Logical plan after this rule applied:
- * {{{
- * LogicalCalc(expr#0..3=[{inputs}], expr#4=[IS NOT NULL($t3)], ...)
- *   LogicalAggregate(group=[{0}], t1=[SUM($1) FILTER $4], t2=[COUNT($2) FILTER $5],
- *       t3=[MIN($3) FILTER $6])
- *     LogicalCalc(expr#0..4=[{inputs}], ... expr#10=[CASE($t6, $t5, $t8, $t7, $t9)],
- *         expr#11=[1], expr#12=[=($t10, $t11)], ... $g_1=[$t12], ...)
- *       LogicalAggregate(group=[{0, 1, 2, 4}], groups=[[]], t3=[COUNT($3)])
- *         LogicalExpand(projects=[{a=[$0], b=[$1], c=[null], d=[$3], $e=[1]},
- *             {a=[$0], b=[null], c=[$2], d=[$3], $e=[2]}, {a=[$0], b=[null], c=[null], d=[$3], $e=[3]}])
- *           LogicalTableSourceScan(table=[[builtin, default, MyTable]], fields=[a, b, c, d])
- * }}}
- *
- * '$e = 1' is equivalent to 'group by a, b'
- * '$e = 2' is equivalent to 'group by a, c'
- * '$e = 3' is equivalent to 'group by a'
- *
- * Expanded records:
- * +-----+-----+-----+-----+-----+
- * |  a  |  b  |  c  |  d  | $e  |
- * +-----+-----+-----+-----+-----+                 ---+---
- * |  1  |  1  | null| d1  |  1  |                    |
- * +-----+-----+-----+-----+-----+                    |
- * |  1  | null| c1  | d1  |  2  |     records expanded by record1
- * +-----+-----+-----+-----+-----+                    |
- * |  1  | null| null| d1  |  3  |                    |
- * +-----+-----+-----+-----+-----+                 ---+---
- * |  1  |  2  | null| d2  |  1  |                    |
- * +-----+-----+-----+-----+-----+                    |
- * |  1  | null| c1  | d2  |  2  |     records expanded by record2
- * +-----+-----+-----+-----+-----+                    |
- * |  1  | null| null| d2  |  3  |                    |
- * +-----+-----+-----+-----+-----+                 ---+---
- * |  2  |  1  | null| d1  |  1  |                    |
- * +-----+-----+-----+-----+-----+                    |
- * |  2  | null| c1  | d1  |  2  |     records expanded by record3
- * +-----+-----+-----+-----+-----+                    |
- * |  2  | null| null| d1  |  3  |                    |
- * +-----+-----+-----+-----+-----+                 ---+---
- *
- * Example2 (Some fields are both in DISTINCT aggregates and non-DISTINCT aggregates):
- *
- * SQL: SELECT MAX(a) as t1, COUNT(DISTINCT a) as t2, count(DISTINCT d) as t3 FROM MyTable
- *
- * Field `a` is both in DISTINCT aggregate and `MAX` aggregate, so, `a` should be outputted as two
- * individual fields, one is for `MAX` aggregate, another is for DISTINCT aggregate.
- *
- * Expanded records:
- * +-----+-----+-----+-----+
- * |  a  |  d  | $e  | a_0 |
- * +-----+-----+-----+-----+           ---+---
- * |  1  | null|  1  |  1  |              |
- * +-----+-----+-----+-----+              |
- * | null| d1  |  2  |  1  |   records expanded by record1
- * +-----+-----+-----+-----+              |
- * | null| null|  3  |  1  |              |
- * +-----+-----+-----+-----+           ---+---
- * |  1  | null|  1  |  1  |              |
- * +-----+-----+-----+-----+              |
- * | null| d2  |  2  |  1  |   records expanded by record2
- * +-----+-----+-----+-----+              |
- * | null| null|  3  |  1  |              |
- * +-----+-----+-----+-----+           ---+---
- * |  2  | null|  1  |  2  |              |
- * +-----+-----+-----+-----+              |
- * | null| d1  |  2  |  2  |   records expanded by record3
- * +-----+-----+-----+-----+              |
- * | null| null|  3  |  2  |              |
- * +-----+-----+-----+-----+           ---+---
- *
- * Example3 (expand for CUBE/ROLLUP/GROUPING SETS):
- *
- * SQL: SELECT a, c, SUM(b) as b FROM MyTable GROUP BY GROUPING SETS (a, c)
- *
- * Logical plan:
- * {{{
- * LogicalAggregate(group=[{0, 1}], groups=[[{0}, {1}]], b=[SUM($2)])
- *   LogicalProject(a=[$0], c=[$2], b=[$1])
- *     LogicalTableScan(table=[[builtin, default, MyTable]])
- * }}}
- *
- * Logical plan after this rule applied:
- * {{{
- * LogicalCalc(expr#0..3=[{inputs}], proj#0..1=[{exprs}], b=[$t3])
- *   LogicalAggregate(group=[{0, 2, 3}], groups=[[]], b=[SUM($1)])
- *     LogicalExpand(projects=[{a=[$0], b=[$1], c=[null], $e=[1]},
- *         {a=[null], b=[$1], c=[$2], $e=[2]}])
- *       LogicalNativeTableScan(table=[[builtin, default, MyTable]])
- * }}}
- *
- * '$e = 1' is equivalent to 'group by a'
- * '$e = 2' is equivalent to 'group by c'
- *
- * Expanded records:
- * +-----+-----+-----+-----+
- * |  a  |  b  |  c  | $e  |
- * +-----+-----+-----+-----+           ---+---
- * |  1  |  1  | null|  1  |              |
- * +-----+-----+-----+-----+   records expanded by record1
- * | null|  1  | c1  |  2  |              |
- * +-----+-----+-----+-----+           ---+---
- * |  1  |  2  | null|  1  |              |
- * +-----+-----+-----+-----+   records expanded by record2
- * | null|  2  | c1  |  2  |              |
- * +-----+-----+-----+-----+           ---+---
- * |  2  |  1  | null|  1  |              |
- * +-----+-----+-----+-----+   records expanded by record3
- * | null|  1  | c1  |  2  |              |
- * +-----+-----+-----+-----+           ---+---
- */
-class DecomposeGroupingSetsRule
- extends RelOptRule(
- operand(classOf[LogicalAggregate], any),
- FlinkRelFactories.FLINK_REL_BUILDER,
- "DecomposeGroupingSetsRule") {
-
- override def matches(call: RelOptRuleCall): Boolean = {
- val agg: LogicalAggregate = call.rel(0)
- val groupIdExprs = AggregateUtil.getGroupIdExprIndexes(agg.getAggCallList)
- agg.getGroupSets.size() > 1 || groupIdExprs.nonEmpty
- }
-
- override def onMatch(call: RelOptRuleCall): Unit = {
- val agg: LogicalAggregate = call.rel(0)
-    // Long data type is used to store groupValue in
-    // FlinkAggregateExpandDistinctAggregatesRule,
- // and the result of grouping function is a positive value,
- // so the max groupCount must be less than 64.
- if (agg.getGroupCount >= 64) {
- throw new TableException("group count must be less than 64.")
- }
-
- val aggInput = agg.getInput
- val groupIdExprs = AggregateUtil.getGroupIdExprIndexes(agg.getAggCallList)
- val aggCallsWithIndexes = agg.getAggCallList.zipWithIndex
-
- val cluster = agg.getCluster
- val rexBuilder = cluster.getRexBuilder
- val needExpand = agg.getGroupSets.size() > 1
-
- val relBuilder = call.builder().asInstanceOf[FlinkRelBuilder]
- relBuilder.push(aggInput)
-
-    val (newGroupSet, duplicateFieldMap) = if (needExpand) {
-      val (duplicateFieldMap, expandIdIdxInExpand) = ExpandUtil.buildExpandNode(
-        relBuilder,
-        agg.getAggCallList,
-        agg.getGroupSet,
-        agg.getGroupSets)
-
-      // new groupSet contains original groupSet and expand_id('$e') field
-      val newGroupSet = agg.getGroupSet.union(ImmutableBitSet.of(expandIdIdxInExpand))
-
-      (newGroupSet, duplicateFieldMap)
-    } else {
-      // no need add expand node, only need care about group functions
-      (agg.getGroupSet, Map.empty[Integer, Integer])
-    }
-
- val newGroupCount = newGroupSet.cardinality()
- val newAggCalls = aggCallsWithIndexes.collect {
- case (aggCall, idx) if !groupIdExprs.contains(idx) =>
-        val newArgList = aggCall.getArgList.map(a => duplicateFieldMap.getOrElse(a, a)).toList
-        val newFilterArg = duplicateFieldMap.getOrDefault(aggCall.filterArg, aggCall.filterArg)
- aggCall.adaptTo(
- relBuilder.peek(),
- newArgList,
- newFilterArg,
- agg.getGroupCount,
- newGroupCount)
- }
-
- // create simple aggregate
- relBuilder.aggregate(
-      relBuilder.groupKey(newGroupSet, ImmutableList.of[ImmutableBitSet](newGroupSet)),
- newAggCalls)
- val newAgg = relBuilder.peek()
-
- // create a project to mapping original aggregate's output
- // get names of original grouping fields
- val groupingFieldsName = Seq
- .range(0, agg.getGroupCount)
- .map(x => agg.getRowType.getFieldNames.get(x))
-
- // create field access for all original grouping fields
- val groupingFields = agg.getGroupSet.toList.zipWithIndex
- .map { case (_, idx) => rexBuilder.makeInputRef(newAgg, idx) }
- .toArray[RexNode]
-
- val groupSetsWithIndexes = agg.getGroupSets.zipWithIndex
- // output aggregate calls including `normal` agg call and grouping agg call
- var aggCnt = 0
- val aggFields = aggCallsWithIndexes.map {
- case (aggCall, idx) if groupIdExprs.contains(idx) =>
- if (needExpand) {
- // reference to expand_id('$e') field in new aggregate
- val expandIdIdxInNewAgg = newGroupCount - 1
-          val expandIdField = rexBuilder.makeInputRef(newAgg, expandIdIdxInNewAgg)
-          // create case when for group expression
-          val whenThenElse = groupSetsWithIndexes.flatMap {
-            case (subGroupSet, i) =>
-              val groupExpr = lowerGroupExpr(rexBuilder, aggCall, groupSetsWithIndexes, i)
-              if (i < agg.getGroupSets.size() - 1) {
-                // WHEN/THEN
-                val expandIdVal = ExpandUtil.genExpandId(agg.getGroupSet, subGroupSet)
-                val expandIdType = newAgg.getRowType.getFieldList.get(expandIdIdxInNewAgg).getType
-                val expandIdLit = rexBuilder.makeLiteral(expandIdVal, expandIdType, false)
-                Seq(
-                  // when $e = $e_value
-                  rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, expandIdField, expandIdLit),
- // then return group expression literal value
- groupExpr
- )
- } else {
- // ELSE
- Seq(
- // else return group expression literal value
- groupExpr
- )
- }
- }
- rexBuilder.makeCall(SqlStdOperatorTable.CASE, whenThenElse)
- } else {
- // create literal for group expression
- lowerGroupExpr(rexBuilder, aggCall, groupSetsWithIndexes, 0)
- }
- case _ =>
- // create access to aggregation result
- val aggResult = rexBuilder.makeInputRef(newAgg, newGroupCount + aggCnt)
- aggCnt += 1
- aggResult
- }
-
-    // add a projection to establish the result schema and set the values of the group expressions.
- relBuilder.project(
- groupingFields.toSeq ++ aggFields,
- groupingFieldsName ++ agg.getAggCallList.map(_.name))
- relBuilder.convert(agg.getRowType, true)
-
- call.transformTo(relBuilder.build())
- }
-
- /** Returns a literal for a given group expression. */
- private def lowerGroupExpr(
- builder: RexBuilder,
- call: AggregateCall,
- groupSetsWithIndexes: Seq[(ImmutableBitSet, Int)],
- indexInGroupSets: Int): RexNode = {
-
- val groupSet = groupSetsWithIndexes(indexInGroupSets)._1
- val groups = groupSet.asSet()
- call.getAggregation.getKind match {
- case SqlKind.GROUP_ID =>
- // https://issues.apache.org/jira/browse/CALCITE-1824
-        // GROUP_ID is not in the SQL standard. It is implemented only by Oracle.
- // GROUP_ID is useful only if you have duplicate grouping sets,
- // If grouping sets are distinct, GROUP_ID() will always return zero;
- // Else return the index in the duplicate grouping sets.
-        // e.g. SELECT deptno, GROUP_ID() AS g FROM Emp GROUP BY GROUPING SETS (deptno, (), ())
- // As you can see, the grouping set () occurs twice.
- // So there is one row in the result for each occurrence:
- // the first occurrence has g = 0; the second has g = 1.
- val duplicateGroupSetsIndices = groupSetsWithIndexes
- .filter { case (gs, _) => gs.compareTo(groupSet) == 0 }
- .map(_._2)
- .toArray[Int]
- require(duplicateGroupSetsIndices.nonEmpty)
- val id: Long = duplicateGroupSetsIndices.indexOf(indexInGroupSets)
- builder.makeLiteral(id, call.getType, false)
- case SqlKind.GROUPING | SqlKind.GROUPING_ID =>
- // GROUPING function is defined in the SQL standard,
-        // but the definition of GROUPING is different from in Oracle and in SQL standard:
-        // https://docs.oracle.com/cd/B28359_01/server.111/b28286/functions064.htm#SQLRF00647
-        //
-        // GROUPING_ID function is not defined in the SQL standard, and has the same
-        // functionality with GROUPING function in Calcite.
-        // our implementation is consistent with Oracle about GROUPING_ID function.
-        //
-        // NOTES:
-        // In Calcite, the java-document of SqlGroupingFunction is not consistent with agg.iq.
- val res: Long = call.getArgList.foldLeft(0L)(
- (res, arg) => (res << 1L) + (if (groups.contains(arg)) 0L else 1L))
- builder.makeLiteral(res, call.getType, false)
- case _ => builder.makeNullLiteral(call.getType)
- }
- }
-}
-
-object DecomposeGroupingSetsRule {
- val INSTANCE: RelOptRule = new DecomposeGroupingSetsRule
-}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/utils/JavaScalaConversionUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/utils/JavaScalaConversionUtil.scala
index 84ea3d15d8d..8e89af81f5e 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/utils/JavaScalaConversionUtil.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/utils/JavaScalaConversionUtil.scala
@@ -19,7 +19,7 @@ package org.apache.flink.table.planner.utils
import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
-import java.util.{List => JList, Optional, Set => JSet}
+import java.util.{List => JList, Map => JMap, Optional, Set => JSet}
import java.util.function.{BiConsumer, Consumer, Function}
import scala.collection.JavaConverters._
@@ -59,4 +59,10 @@ object JavaScalaConversionUtil {
def toJava(set: Set[Int]): JSet[Integer] = set.map(_.asInstanceOf[Integer]).asJava
+
+ def toJava[K, V](map: scala.collection.Map[K, V]): JMap[K, V] =
+ map.asJava
+
+ def toScala[K, V](map: JMap[K, V]): scala.collection.Map[K, V] =
+ map.asScala
}