This is an automated email from the ASF dual-hosted git repository.

snuyanzin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 83ce27871b3 [FLINK-37009][table] Migrate `PruneAggregateCallRule` to java
83ce27871b3 is described below

commit 83ce27871b30e44a469792cd29755afe1c37ba09
Author: Jacky Lau <[email protected]>
AuthorDate: Sat Jan 10 19:09:01 2026 +0800

    [FLINK-37009][table] Migrate `PruneAggregateCallRule` to java
    
    
    
    ---------
    
    Co-authored-by: yongliu <[email protected]>
---
 .../plan/rules/logical/PruneAggregateCallRule.java | 292 +++++++++++++++++++++
 .../rules/logical/PruneAggregateCallRule.scala     | 200 --------------
 2 files changed, 292 insertions(+), 200 deletions(-)

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.java
new file mode 100644
index 00000000000..05e84012c99
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.java
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.util.Preconditions;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.Aggregate.Group;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Calc;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.runtime.Utilities;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.mapping.Mappings;
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/** Planner rule that removes unreferenced AggregateCall from Aggregate. */
+public abstract class PruneAggregateCallRule<T extends RelNode>
+        extends RelRule<PruneAggregateCallRule.PruneAggregateCallRuleConfig> {
+
+    public static final ProjectPruneAggregateCallRule PROJECT_ON_AGGREGATE =
+            
ProjectPruneAggregateCallRule.ProjectPruneAggregateCallRuleConfig.DEFAULT.toRule();
+    public static final CalcPruneAggregateCallRule CALC_ON_AGGREGATE =
+            
CalcPruneAggregateCallRule.CalcPruneAggregateCallRuleConfig.DEFAULT.toRule();
+
+    protected 
PruneAggregateCallRule(PruneAggregateCallRule.PruneAggregateCallRuleConfig 
config) {
+        super(config);
+    }
+
+    protected abstract ImmutableBitSet getInputRefs(T relOnAgg);
+
+    @Override
+    public boolean matches(RelOptRuleCall call) {
+        T relOnAgg = call.rel(0);
+        Aggregate agg = call.rel(1);
+        if (agg.getGroupType() != Group.SIMPLE
+                || agg.getAggCallList().isEmpty()
+                ||
+                // at least output one column
+                (agg.getGroupCount() == 0 && agg.getAggCallList().size() == 
1)) {
+            return false;
+        }
+        ImmutableBitSet inputRefs = getInputRefs(relOnAgg);
+        int[] unrefAggCallIndices = getUnrefAggCallIndices(inputRefs, agg);
+        return unrefAggCallIndices.length > 0;
+    }
+
+    private int[] getUnrefAggCallIndices(ImmutableBitSet inputRefs, Aggregate 
agg) {
+        int groupCount = agg.getGroupCount();
+        return IntStream.range(0, agg.getAggCallList().size())
+                .filter(index -> !inputRefs.get(groupCount + index))
+                .toArray();
+    }
+
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        T relOnAgg = call.rel(0);
+        Aggregate agg = call.rel(1);
+        ImmutableBitSet inputRefs = getInputRefs(relOnAgg);
+        int[] unrefAggCallIndices = getUnrefAggCallIndices(inputRefs, agg);
+        Preconditions.checkArgument(unrefAggCallIndices.length > 0, 
"requirement failed");
+
+        List<AggregateCall> newAggCalls = new 
ArrayList<>(agg.getAggCallList());
+        // remove unreferenced AggCall from original aggCalls
+        Arrays.stream(unrefAggCallIndices)
+                .boxed()
+                .sorted(Comparator.reverseOrder())
+                // we need this int cast here. because it will remove the 
value instead of removing
+                // index when doesn't have int cast
+                .forEach(index -> newAggCalls.remove((int) index));
+
+        if (newAggCalls.isEmpty() && agg.getGroupCount() == 0) {
+            // at least output one column
+            newAggCalls.add(agg.getAggCallList().get(0));
+            unrefAggCallIndices =
+                    Arrays.copyOfRange(unrefAggCallIndices, 1, 
unrefAggCallIndices.length);
+        }
+
+        Aggregate newAgg =
+                agg.copy(
+                        agg.getTraitSet(),
+                        agg.getInput(),
+                        agg.getGroupSet(),
+                        List.of(agg.getGroupSet()),
+                        newAggCalls);
+
+        int newFieldIndex = 0;
+        // map old agg output index to new agg output index
+        Map<Integer, Integer> mapOldToNew = new HashMap<>();
+        int fieldCountOfOldAgg = agg.getRowType().getFieldCount();
+        List<Integer> unrefAggCallOutputIndices =
+                Arrays.stream(unrefAggCallIndices)
+                        .mapToObj(i -> i + agg.getGroupCount())
+                        .collect(Collectors.toList());
+        for (int i = 0; i < fieldCountOfOldAgg; i++) {
+            if (!unrefAggCallOutputIndices.contains(i)) {
+                mapOldToNew.put(i, newFieldIndex);
+                newFieldIndex++;
+            }
+        }
+        Preconditions.checkArgument(
+                mapOldToNew.size() == newAgg.getRowType().getFieldCount(), 
"requirement failed");
+
+        Mappings.TargetMapping mapping =
+                Mappings.target(
+                        mapOldToNew, fieldCountOfOldAgg, 
newAgg.getRowType().getFieldCount());
+        RelNode newRelOnAgg = createNewRel(mapping, relOnAgg, newAgg);
+        call.transformTo(newRelOnAgg);
+    }
+
+    protected abstract RelNode createNewRel(
+            Mappings.TargetMapping mapping, T project, RelNode newAgg);
+
+    public static class ProjectPruneAggregateCallRule extends 
PruneAggregateCallRule<Project> {
+
+        protected 
ProjectPruneAggregateCallRule(ProjectPruneAggregateCallRuleConfig config) {
+            super(config);
+        }
+
+        @Override
+        protected ImmutableBitSet getInputRefs(Project relOnAgg) {
+            return RelOptUtil.InputFinder.bits(relOnAgg.getProjects(), null);
+        }
+
+        @Override
+        protected RelNode createNewRel(
+                Mappings.TargetMapping mapping, Project project, RelNode 
newAgg) {
+            List<RexNode> newProjects = RexUtil.apply(mapping, 
project.getProjects());
+            if (projectsOnlyIdentity(newProjects, 
newAgg.getRowType().getFieldCount())
+                    && Utilities.compare(
+                                    project.getRowType().getFieldNames(),
+                                    newAgg.getRowType().getFieldNames())
+                            == 0) {
+                return newAgg;
+            } else {
+                return project.copy(
+                        project.getTraitSet(), newAgg, newProjects, 
project.getRowType());
+            }
+        }
+
+        private boolean projectsOnlyIdentity(List<RexNode> projects, int 
inputFieldCount) {
+            if (projects.size() != inputFieldCount) {
+                return false;
+            }
+            return IntStream.range(0, projects.size())
+                    .allMatch(
+                            index -> {
+                                RexNode project = projects.get(index);
+                                if (project instanceof RexInputRef) {
+                                    RexInputRef r = (RexInputRef) project;
+                                    return r.getIndex() == index;
+                                }
+                                return false;
+                            });
+        }
+
+        /** Rule configuration. */
+        @Value.Immutable(singleton = false)
+        public interface ProjectPruneAggregateCallRuleConfig
+                extends PruneAggregateCallRule.PruneAggregateCallRuleConfig {
+            ProjectPruneAggregateCallRuleConfig DEFAULT =
+                    ImmutableProjectPruneAggregateCallRuleConfig.builder()
+                            .build()
+                            .withOperandSupplier(
+                                    b0 ->
+                                            b0.operand(Project.class)
+                                                    .oneInput(
+                                                            b1 ->
+                                                                    
b1.operand(Aggregate.class)
+                                                                            
.anyInputs()))
+                            .withDescription(
+                                    "PruneAggregateCallRule_" + 
Project.class.getCanonicalName());
+
+            @Override
+            default ProjectPruneAggregateCallRule toRule() {
+                return new ProjectPruneAggregateCallRule(this);
+            }
+        }
+    }
+
+    public static class CalcPruneAggregateCallRule extends 
PruneAggregateCallRule<Calc> {
+
+        protected CalcPruneAggregateCallRule(CalcPruneAggregateCallRuleConfig 
config) {
+            super(config);
+        }
+
+        @Override
+        protected ImmutableBitSet getInputRefs(Calc calc) {
+            RexProgram program = calc.getProgram();
+            RexNode condition =
+                    program.getCondition() != null
+                            ? program.expandLocalRef(program.getCondition())
+                            : null;
+            List<RexNode> projects =
+                    program.getProjectList().stream()
+                            .map(program::expandLocalRef)
+                            .collect(Collectors.toList());
+            return RelOptUtil.InputFinder.bits(projects, condition);
+        }
+
+        @Override
+        protected RelNode createNewRel(Mappings.TargetMapping mapping, Calc 
calc, RelNode newAgg) {
+            RexProgram program = calc.getProgram();
+            RexNode newCondition =
+                    program.getCondition() != null
+                            ? RexUtil.apply(mapping, 
program.expandLocalRef(program.getCondition()))
+                            : null;
+            List<RexNode> projects =
+                    program.getProjectList().stream()
+                            .map(program::expandLocalRef)
+                            .collect(Collectors.toList());
+            List<RexNode> newProjects = RexUtil.apply(mapping, projects);
+            RexProgram newProgram =
+                    RexProgram.create(
+                            newAgg.getRowType(),
+                            newProjects,
+                            newCondition,
+                            program.getOutputRowType().getFieldNames(),
+                            calc.getCluster().getRexBuilder());
+            if (newProgram.isTrivial()
+                    && Utilities.compare(
+                                    calc.getRowType().getFieldNames(),
+                                    newAgg.getRowType().getFieldNames())
+                            == 0) {
+                return newAgg;
+            } else {
+                return calc.copy(calc.getTraitSet(), newAgg, newProgram);
+            }
+        }
+
+        /** Rule configuration. */
+        @Value.Immutable(singleton = false)
+        public interface CalcPruneAggregateCallRuleConfig extends 
PruneAggregateCallRuleConfig {
+            CalcPruneAggregateCallRuleConfig DEFAULT =
+                    ImmutableCalcPruneAggregateCallRuleConfig.builder()
+                            .build()
+                            .withOperandSupplier(
+                                    b0 ->
+                                            b0.operand(Calc.class)
+                                                    .oneInput(
+                                                            b1 ->
+                                                                    
b1.operand(Aggregate.class)
+                                                                            
.anyInputs()))
+                            .withDescription(
+                                    "PruneAggregateCallRule_" + 
Calc.class.getCanonicalName());
+
+            @Override
+            default CalcPruneAggregateCallRule toRule() {
+                return new CalcPruneAggregateCallRule(this);
+            }
+        }
+    }
+
+    /** Rule configuration. */
+    public interface PruneAggregateCallRuleConfig extends RelRule.Config {
+        @Override
+        PruneAggregateCallRule toRule();
+    }
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.scala
deleted file mode 100644
index bd7c479fea3..00000000000
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PruneAggregateCallRule.scala
+++ /dev/null
@@ -1,200 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.planner.plan.rules.logical
-
-import com.google.common.collect.{ImmutableList, Maps}
-import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil}
-import org.apache.calcite.plan.RelOptRule.{any, operand}
-import org.apache.calcite.rel.RelNode
-import org.apache.calcite.rel.core.{Aggregate, AggregateCall, Calc, Project, 
RelFactories}
-import org.apache.calcite.rel.core.Aggregate.Group
-import org.apache.calcite.rex.{RexInputRef, RexNode, RexProgram, RexUtil}
-import org.apache.calcite.runtime.Utilities
-import org.apache.calcite.util.ImmutableBitSet
-import org.apache.calcite.util.mapping.Mappings
-
-import java.util
-
-import scala.collection.JavaConversions._
-
-/** Planner rule that removes unreferenced AggregateCall from Aggregate */
-abstract class PruneAggregateCallRule[T <: RelNode](topClass: Class[T])
-  extends RelOptRule(
-    operand(topClass, operand(classOf[Aggregate], any)),
-    RelFactories.LOGICAL_BUILDER,
-    s"PruneAggregateCallRule_${topClass.getCanonicalName}") {
-
-  protected def getInputRefs(relOnAgg: T): ImmutableBitSet
-
-  override def matches(call: RelOptRuleCall): Boolean = {
-    val relOnAgg: T = call.rel(0)
-    val agg: Aggregate = call.rel(1)
-    if (
-      agg.getGroupType != Group.SIMPLE || agg.getAggCallList.isEmpty ||
-      // at least output one column
-      (agg.getGroupCount == 0 && agg.getAggCallList.size() == 1)
-    ) {
-      return false
-    }
-    val inputRefs = getInputRefs(relOnAgg)
-    val unrefAggCallIndices = getUnrefAggCallIndices(inputRefs, agg)
-    unrefAggCallIndices.nonEmpty
-  }
-
-  private def getUnrefAggCallIndices(inputRefs: ImmutableBitSet, agg: 
Aggregate): Array[Int] = {
-    val groupCount = agg.getGroupCount
-    agg.getAggCallList.indices
-      .flatMap {
-        index =>
-          val aggCallOutputIndex = groupCount + index
-          if (inputRefs.get(aggCallOutputIndex)) {
-            Array.empty[Int]
-          } else {
-            Array(index)
-          }
-      }
-      .toArray[Int]
-  }
-
-  override def onMatch(call: RelOptRuleCall): Unit = {
-    val relOnAgg: T = call.rel(0)
-    val agg: Aggregate = call.rel(1)
-    val inputRefs = getInputRefs(relOnAgg)
-    var unrefAggCallIndices = getUnrefAggCallIndices(inputRefs, agg)
-    require(unrefAggCallIndices.nonEmpty)
-
-    val newAggCalls: util.List[AggregateCall] = new 
util.ArrayList(agg.getAggCallList)
-    // remove unreferenced AggCall from original aggCalls
-    unrefAggCallIndices.sorted.reverse.foreach(i => newAggCalls.remove(i))
-
-    if (newAggCalls.isEmpty && agg.getGroupCount == 0) {
-      // at least output one column
-      newAggCalls.add(agg.getAggCallList.get(0))
-      unrefAggCallIndices = unrefAggCallIndices.slice(1, 
unrefAggCallIndices.length)
-    }
-
-    val newAgg = agg.copy(
-      agg.getTraitSet,
-      agg.getInput,
-      agg.getGroupSet,
-      ImmutableList.of(agg.getGroupSet),
-      newAggCalls
-    )
-
-    var newFieldIndex = 0
-    // map old agg output index to new agg output index
-    val mapOldToNew = Maps.newHashMap[Integer, Integer]()
-    val fieldCountOfOldAgg = agg.getRowType.getFieldCount
-    val unrefAggCallOutputIndices = unrefAggCallIndices.map(_ + 
agg.getGroupCount)
-    (0 until fieldCountOfOldAgg).foreach {
-      i =>
-        if (!unrefAggCallOutputIndices.contains(i)) {
-          mapOldToNew.put(i, newFieldIndex)
-          newFieldIndex += 1
-        }
-    }
-    require(mapOldToNew.size() == newAgg.getRowType.getFieldCount)
-
-    val mapping = Mappings.target(mapOldToNew, fieldCountOfOldAgg, 
newAgg.getRowType.getFieldCount)
-    val newRelOnAgg = createNewRel(mapping, relOnAgg, newAgg)
-    call.transformTo(newRelOnAgg)
-  }
-
-  protected def createNewRel(mapping: Mappings.TargetMapping, project: T, 
newAgg: RelNode): RelNode
-}
-
-class ProjectPruneAggregateCallRule extends 
PruneAggregateCallRule(classOf[Project]) {
-  override protected def getInputRefs(relOnAgg: Project): ImmutableBitSet = {
-    RelOptUtil.InputFinder.bits(relOnAgg.getProjects, null)
-  }
-
-  override protected def createNewRel(
-      mapping: Mappings.TargetMapping,
-      project: Project,
-      newAgg: RelNode): RelNode = {
-    val newProjects = RexUtil.apply(mapping, project.getProjects).toList
-    if (
-      projectsOnlyIdentity(newProjects, newAgg.getRowType.getFieldCount) &&
-      Utilities.compare(project.getRowType.getFieldNames, 
newAgg.getRowType.getFieldNames) == 0
-    ) {
-      newAgg
-    } else {
-      project.copy(project.getTraitSet, newAgg, newProjects, 
project.getRowType)
-    }
-  }
-
-  private def projectsOnlyIdentity(projects: util.List[RexNode], 
inputFieldCount: Int): Boolean = {
-    if (projects.size != inputFieldCount) {
-      return false
-    }
-    projects.zipWithIndex.forall {
-      case (project, index) =>
-        project match {
-          case r: RexInputRef => r.getIndex == index
-          case _ => false
-        }
-    }
-  }
-}
-
-class CalcPruneAggregateCallRule extends PruneAggregateCallRule(classOf[Calc]) 
{
-  override protected def getInputRefs(relOnAgg: Calc): ImmutableBitSet = {
-    val program = relOnAgg.getProgram
-    val condition = if (program.getCondition != null) {
-      program.expandLocalRef(program.getCondition)
-    } else {
-      null
-    }
-    val projects = program.getProjectList.map(program.expandLocalRef)
-    RelOptUtil.InputFinder.bits(projects, condition)
-  }
-
-  override protected def createNewRel(
-      mapping: Mappings.TargetMapping,
-      calc: Calc,
-      newAgg: RelNode): RelNode = {
-    val program = calc.getProgram
-    val newCondition = if (program.getCondition != null) {
-      RexUtil.apply(mapping, program.expandLocalRef(program.getCondition))
-    } else {
-      null
-    }
-    val projects = program.getProjectList.map(program.expandLocalRef)
-    val newProjects = RexUtil.apply(mapping, projects).toList
-    val newProgram = RexProgram.create(
-      newAgg.getRowType,
-      newProjects,
-      newCondition,
-      program.getOutputRowType.getFieldNames,
-      calc.getCluster.getRexBuilder
-    )
-    if (
-      newProgram.isTrivial &&
-      Utilities.compare(calc.getRowType.getFieldNames, 
newAgg.getRowType.getFieldNames) == 0
-    ) {
-      newAgg
-    } else {
-      calc.copy(calc.getTraitSet, newAgg, newProgram)
-    }
-  }
-}
-
-object PruneAggregateCallRule {
-  val PROJECT_ON_AGGREGATE = new ProjectPruneAggregateCallRule
-  val CALC_ON_AGGREGATE = new CalcPruneAggregateCallRule
-}

Reply via email to