This is an automated email from the ASF dual-hosted git repository.
snuyanzin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 83ce27871b3 [FLINK-37009][table] Migrate `PruneAggregateCallRule` to java
83ce27871b3 is described below
commit 83ce27871b30e44a469792cd29755afe1c37ba09
Author: Jacky Lau <[email protected]>
AuthorDate: Sat Jan 10 19:09:01 2026 +0800
[FLINK-37009][table] Migrate `PruneAggregateCallRule` to java
---------
Co-authored-by: yongliu <[email protected]>
---
.../plan/rules/logical/PruneAggregateCallRule.java | 292 +++++++++++++++++++++
.../rules/logical/PruneAggregateCallRule.scala | 200 --------------
2 files changed, 292 insertions(+), 200 deletions(-)
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.java
new file mode 100644
index 00000000000..05e84012c99
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.java
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.util.Preconditions;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.Aggregate.Group;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Calc;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.runtime.Utilities;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.mapping.Mappings;
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/** Planner rule that removes unreferenced AggregateCall from Aggregate. */
+public abstract class PruneAggregateCallRule<T extends RelNode>
+ extends RelRule<PruneAggregateCallRule.PruneAggregateCallRuleConfig> {
+
+ public static final ProjectPruneAggregateCallRule PROJECT_ON_AGGREGATE =
+
ProjectPruneAggregateCallRule.ProjectPruneAggregateCallRuleConfig.DEFAULT.toRule();
+ public static final CalcPruneAggregateCallRule CALC_ON_AGGREGATE =
+
CalcPruneAggregateCallRule.CalcPruneAggregateCallRuleConfig.DEFAULT.toRule();
+
+ protected
PruneAggregateCallRule(PruneAggregateCallRule.PruneAggregateCallRuleConfig
config) {
+ super(config);
+ }
+
+ protected abstract ImmutableBitSet getInputRefs(T relOnAgg);
+
+ @Override
+ public boolean matches(RelOptRuleCall call) {
+ T relOnAgg = call.rel(0);
+ Aggregate agg = call.rel(1);
+ if (agg.getGroupType() != Group.SIMPLE
+ || agg.getAggCallList().isEmpty()
+ ||
+ // at least output one column
+ (agg.getGroupCount() == 0 && agg.getAggCallList().size() ==
1)) {
+ return false;
+ }
+ ImmutableBitSet inputRefs = getInputRefs(relOnAgg);
+ int[] unrefAggCallIndices = getUnrefAggCallIndices(inputRefs, agg);
+ return unrefAggCallIndices.length > 0;
+ }
+
+ private int[] getUnrefAggCallIndices(ImmutableBitSet inputRefs, Aggregate
agg) {
+ int groupCount = agg.getGroupCount();
+ return IntStream.range(0, agg.getAggCallList().size())
+ .filter(index -> !inputRefs.get(groupCount + index))
+ .toArray();
+ }
+
+ @Override
+ public void onMatch(RelOptRuleCall call) {
+ T relOnAgg = call.rel(0);
+ Aggregate agg = call.rel(1);
+ ImmutableBitSet inputRefs = getInputRefs(relOnAgg);
+ int[] unrefAggCallIndices = getUnrefAggCallIndices(inputRefs, agg);
+ Preconditions.checkArgument(unrefAggCallIndices.length > 0,
"requirement failed");
+
+ List<AggregateCall> newAggCalls = new
ArrayList<>(agg.getAggCallList());
+ // remove unreferenced AggCall from original aggCalls
+ Arrays.stream(unrefAggCallIndices)
+ .boxed()
+ .sorted(Comparator.reverseOrder())
+ // we need this int cast here. because it will remove the
value instead of removing
+ // index when doesn't have int cast
+ .forEach(index -> newAggCalls.remove((int) index));
+
+ if (newAggCalls.isEmpty() && agg.getGroupCount() == 0) {
+ // at least output one column
+ newAggCalls.add(agg.getAggCallList().get(0));
+ unrefAggCallIndices =
+ Arrays.copyOfRange(unrefAggCallIndices, 1,
unrefAggCallIndices.length);
+ }
+
+ Aggregate newAgg =
+ agg.copy(
+ agg.getTraitSet(),
+ agg.getInput(),
+ agg.getGroupSet(),
+ List.of(agg.getGroupSet()),
+ newAggCalls);
+
+ int newFieldIndex = 0;
+ // map old agg output index to new agg output index
+ Map<Integer, Integer> mapOldToNew = new HashMap<>();
+ int fieldCountOfOldAgg = agg.getRowType().getFieldCount();
+ List<Integer> unrefAggCallOutputIndices =
+ Arrays.stream(unrefAggCallIndices)
+ .mapToObj(i -> i + agg.getGroupCount())
+ .collect(Collectors.toList());
+ for (int i = 0; i < fieldCountOfOldAgg; i++) {
+ if (!unrefAggCallOutputIndices.contains(i)) {
+ mapOldToNew.put(i, newFieldIndex);
+ newFieldIndex++;
+ }
+ }
+ Preconditions.checkArgument(
+ mapOldToNew.size() == newAgg.getRowType().getFieldCount(),
"requirement failed");
+
+ Mappings.TargetMapping mapping =
+ Mappings.target(
+ mapOldToNew, fieldCountOfOldAgg,
newAgg.getRowType().getFieldCount());
+ RelNode newRelOnAgg = createNewRel(mapping, relOnAgg, newAgg);
+ call.transformTo(newRelOnAgg);
+ }
+
+ protected abstract RelNode createNewRel(
+ Mappings.TargetMapping mapping, T project, RelNode newAgg);
+
+ public static class ProjectPruneAggregateCallRule extends
PruneAggregateCallRule<Project> {
+
+ protected
ProjectPruneAggregateCallRule(ProjectPruneAggregateCallRuleConfig config) {
+ super(config);
+ }
+
+ @Override
+ protected ImmutableBitSet getInputRefs(Project relOnAgg) {
+ return RelOptUtil.InputFinder.bits(relOnAgg.getProjects(), null);
+ }
+
+ @Override
+ protected RelNode createNewRel(
+ Mappings.TargetMapping mapping, Project project, RelNode
newAgg) {
+ List<RexNode> newProjects = RexUtil.apply(mapping,
project.getProjects());
+ if (projectsOnlyIdentity(newProjects,
newAgg.getRowType().getFieldCount())
+ && Utilities.compare(
+ project.getRowType().getFieldNames(),
+ newAgg.getRowType().getFieldNames())
+ == 0) {
+ return newAgg;
+ } else {
+ return project.copy(
+ project.getTraitSet(), newAgg, newProjects,
project.getRowType());
+ }
+ }
+
+ private boolean projectsOnlyIdentity(List<RexNode> projects, int
inputFieldCount) {
+ if (projects.size() != inputFieldCount) {
+ return false;
+ }
+ return IntStream.range(0, projects.size())
+ .allMatch(
+ index -> {
+ RexNode project = projects.get(index);
+ if (project instanceof RexInputRef) {
+ RexInputRef r = (RexInputRef) project;
+ return r.getIndex() == index;
+ }
+ return false;
+ });
+ }
+
+ /** Rule configuration. */
+ @Value.Immutable(singleton = false)
+ public interface ProjectPruneAggregateCallRuleConfig
+ extends PruneAggregateCallRule.PruneAggregateCallRuleConfig {
+ ProjectPruneAggregateCallRuleConfig DEFAULT =
+ ImmutableProjectPruneAggregateCallRuleConfig.builder()
+ .build()
+ .withOperandSupplier(
+ b0 ->
+ b0.operand(Project.class)
+ .oneInput(
+ b1 ->
+
b1.operand(Aggregate.class)
+
.anyInputs()))
+ .withDescription(
+ "PruneAggregateCallRule_" +
Project.class.getCanonicalName());
+
+ @Override
+ default ProjectPruneAggregateCallRule toRule() {
+ return new ProjectPruneAggregateCallRule(this);
+ }
+ }
+ }
+
+ public static class CalcPruneAggregateCallRule extends
PruneAggregateCallRule<Calc> {
+
+ protected CalcPruneAggregateCallRule(CalcPruneAggregateCallRuleConfig
config) {
+ super(config);
+ }
+
+ @Override
+ protected ImmutableBitSet getInputRefs(Calc calc) {
+ RexProgram program = calc.getProgram();
+ RexNode condition =
+ program.getCondition() != null
+ ? program.expandLocalRef(program.getCondition())
+ : null;
+ List<RexNode> projects =
+ program.getProjectList().stream()
+ .map(program::expandLocalRef)
+ .collect(Collectors.toList());
+ return RelOptUtil.InputFinder.bits(projects, condition);
+ }
+
+ @Override
+ protected RelNode createNewRel(Mappings.TargetMapping mapping, Calc
calc, RelNode newAgg) {
+ RexProgram program = calc.getProgram();
+ RexNode newCondition =
+ program.getCondition() != null
+ ? RexUtil.apply(mapping,
program.expandLocalRef(program.getCondition()))
+ : null;
+ List<RexNode> projects =
+ program.getProjectList().stream()
+ .map(program::expandLocalRef)
+ .collect(Collectors.toList());
+ List<RexNode> newProjects = RexUtil.apply(mapping, projects);
+ RexProgram newProgram =
+ RexProgram.create(
+ newAgg.getRowType(),
+ newProjects,
+ newCondition,
+ program.getOutputRowType().getFieldNames(),
+ calc.getCluster().getRexBuilder());
+ if (newProgram.isTrivial()
+ && Utilities.compare(
+ calc.getRowType().getFieldNames(),
+ newAgg.getRowType().getFieldNames())
+ == 0) {
+ return newAgg;
+ } else {
+ return calc.copy(calc.getTraitSet(), newAgg, newProgram);
+ }
+ }
+
+ /** Rule configuration. */
+ @Value.Immutable(singleton = false)
+ public interface CalcPruneAggregateCallRuleConfig extends
PruneAggregateCallRuleConfig {
+ CalcPruneAggregateCallRuleConfig DEFAULT =
+ ImmutableCalcPruneAggregateCallRuleConfig.builder()
+ .build()
+ .withOperandSupplier(
+ b0 ->
+ b0.operand(Calc.class)
+ .oneInput(
+ b1 ->
+
b1.operand(Aggregate.class)
+
.anyInputs()))
+ .withDescription(
+ "PruneAggregateCallRule_" +
Calc.class.getCanonicalName());
+
+ @Override
+ default CalcPruneAggregateCallRule toRule() {
+ return new CalcPruneAggregateCallRule(this);
+ }
+ }
+ }
+
+ /** Rule configuration. */
+ public interface PruneAggregateCallRuleConfig extends RelRule.Config {
+ @Override
+ PruneAggregateCallRule toRule();
+ }
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.scala
deleted file mode 100644
index bd7c479fea3..00000000000
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.scala
+++ /dev/null
@@ -1,200 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.planner.plan.rules.logical
-
-import com.google.common.collect.{ImmutableList, Maps}
-import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil}
-import org.apache.calcite.plan.RelOptRule.{any, operand}
-import org.apache.calcite.rel.RelNode
-import org.apache.calcite.rel.core.{Aggregate, AggregateCall, Calc, Project,
RelFactories}
-import org.apache.calcite.rel.core.Aggregate.Group
-import org.apache.calcite.rex.{RexInputRef, RexNode, RexProgram, RexUtil}
-import org.apache.calcite.runtime.Utilities
-import org.apache.calcite.util.ImmutableBitSet
-import org.apache.calcite.util.mapping.Mappings
-
-import java.util
-
-import scala.collection.JavaConversions._
-
-/** Planner rule that removes unreferenced AggregateCall from Aggregate */
-abstract class PruneAggregateCallRule[T <: RelNode](topClass: Class[T])
- extends RelOptRule(
- operand(topClass, operand(classOf[Aggregate], any)),
- RelFactories.LOGICAL_BUILDER,
- s"PruneAggregateCallRule_${topClass.getCanonicalName}") {
-
- protected def getInputRefs(relOnAgg: T): ImmutableBitSet
-
- override def matches(call: RelOptRuleCall): Boolean = {
- val relOnAgg: T = call.rel(0)
- val agg: Aggregate = call.rel(1)
- if (
- agg.getGroupType != Group.SIMPLE || agg.getAggCallList.isEmpty ||
- // at least output one column
- (agg.getGroupCount == 0 && agg.getAggCallList.size() == 1)
- ) {
- return false
- }
- val inputRefs = getInputRefs(relOnAgg)
- val unrefAggCallIndices = getUnrefAggCallIndices(inputRefs, agg)
- unrefAggCallIndices.nonEmpty
- }
-
- private def getUnrefAggCallIndices(inputRefs: ImmutableBitSet, agg:
Aggregate): Array[Int] = {
- val groupCount = agg.getGroupCount
- agg.getAggCallList.indices
- .flatMap {
- index =>
- val aggCallOutputIndex = groupCount + index
- if (inputRefs.get(aggCallOutputIndex)) {
- Array.empty[Int]
- } else {
- Array(index)
- }
- }
- .toArray[Int]
- }
-
- override def onMatch(call: RelOptRuleCall): Unit = {
- val relOnAgg: T = call.rel(0)
- val agg: Aggregate = call.rel(1)
- val inputRefs = getInputRefs(relOnAgg)
- var unrefAggCallIndices = getUnrefAggCallIndices(inputRefs, agg)
- require(unrefAggCallIndices.nonEmpty)
-
- val newAggCalls: util.List[AggregateCall] = new
util.ArrayList(agg.getAggCallList)
- // remove unreferenced AggCall from original aggCalls
- unrefAggCallIndices.sorted.reverse.foreach(i => newAggCalls.remove(i))
-
- if (newAggCalls.isEmpty && agg.getGroupCount == 0) {
- // at least output one column
- newAggCalls.add(agg.getAggCallList.get(0))
- unrefAggCallIndices = unrefAggCallIndices.slice(1,
unrefAggCallIndices.length)
- }
-
- val newAgg = agg.copy(
- agg.getTraitSet,
- agg.getInput,
- agg.getGroupSet,
- ImmutableList.of(agg.getGroupSet),
- newAggCalls
- )
-
- var newFieldIndex = 0
- // map old agg output index to new agg output index
- val mapOldToNew = Maps.newHashMap[Integer, Integer]()
- val fieldCountOfOldAgg = agg.getRowType.getFieldCount
- val unrefAggCallOutputIndices = unrefAggCallIndices.map(_ +
agg.getGroupCount)
- (0 until fieldCountOfOldAgg).foreach {
- i =>
- if (!unrefAggCallOutputIndices.contains(i)) {
- mapOldToNew.put(i, newFieldIndex)
- newFieldIndex += 1
- }
- }
- require(mapOldToNew.size() == newAgg.getRowType.getFieldCount)
-
- val mapping = Mappings.target(mapOldToNew, fieldCountOfOldAgg,
newAgg.getRowType.getFieldCount)
- val newRelOnAgg = createNewRel(mapping, relOnAgg, newAgg)
- call.transformTo(newRelOnAgg)
- }
-
- protected def createNewRel(mapping: Mappings.TargetMapping, project: T,
newAgg: RelNode): RelNode
-}
-
-class ProjectPruneAggregateCallRule extends
PruneAggregateCallRule(classOf[Project]) {
- override protected def getInputRefs(relOnAgg: Project): ImmutableBitSet = {
- RelOptUtil.InputFinder.bits(relOnAgg.getProjects, null)
- }
-
- override protected def createNewRel(
- mapping: Mappings.TargetMapping,
- project: Project,
- newAgg: RelNode): RelNode = {
- val newProjects = RexUtil.apply(mapping, project.getProjects).toList
- if (
- projectsOnlyIdentity(newProjects, newAgg.getRowType.getFieldCount) &&
- Utilities.compare(project.getRowType.getFieldNames,
newAgg.getRowType.getFieldNames) == 0
- ) {
- newAgg
- } else {
- project.copy(project.getTraitSet, newAgg, newProjects,
project.getRowType)
- }
- }
-
- private def projectsOnlyIdentity(projects: util.List[RexNode],
inputFieldCount: Int): Boolean = {
- if (projects.size != inputFieldCount) {
- return false
- }
- projects.zipWithIndex.forall {
- case (project, index) =>
- project match {
- case r: RexInputRef => r.getIndex == index
- case _ => false
- }
- }
- }
-}
-
-class CalcPruneAggregateCallRule extends PruneAggregateCallRule(classOf[Calc])
{
- override protected def getInputRefs(relOnAgg: Calc): ImmutableBitSet = {
- val program = relOnAgg.getProgram
- val condition = if (program.getCondition != null) {
- program.expandLocalRef(program.getCondition)
- } else {
- null
- }
- val projects = program.getProjectList.map(program.expandLocalRef)
- RelOptUtil.InputFinder.bits(projects, condition)
- }
-
- override protected def createNewRel(
- mapping: Mappings.TargetMapping,
- calc: Calc,
- newAgg: RelNode): RelNode = {
- val program = calc.getProgram
- val newCondition = if (program.getCondition != null) {
- RexUtil.apply(mapping, program.expandLocalRef(program.getCondition))
- } else {
- null
- }
- val projects = program.getProjectList.map(program.expandLocalRef)
- val newProjects = RexUtil.apply(mapping, projects).toList
- val newProgram = RexProgram.create(
- newAgg.getRowType,
- newProjects,
- newCondition,
- program.getOutputRowType.getFieldNames,
- calc.getCluster.getRexBuilder
- )
- if (
- newProgram.isTrivial &&
- Utilities.compare(calc.getRowType.getFieldNames,
newAgg.getRowType.getFieldNames) == 0
- ) {
- newAgg
- } else {
- calc.copy(calc.getTraitSet, newAgg, newProgram)
- }
- }
-}
-
-object PruneAggregateCallRule {
- val PROJECT_ON_AGGREGATE = new ProjectPruneAggregateCallRule
- val CALC_ON_AGGREGATE = new CalcPruneAggregateCallRule
-}