rymarm commented on code in PR #3026:
URL: https://github.com/apache/drill/pull/3026#discussion_r2444421971


##########
exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateExpandGroupingSetsRule.java:
##########


Review Comment:
   How about dividing the `onMatch` method into even more methods?
   
   Something like this:
   ```java
   /**
    * Planner rule that expands GROUPING SETS, ROLLUP, and CUBE into a UNION ALL
    * of multiple aggregates, each with a single grouping set.
    *
    * This rule converts:
    *   SELECT a, b, SUM(c) FROM t GROUP BY GROUPING SETS ((a, b), (a), ())
    *
    * Into:
    *   SELECT a, b, SUM(c), 0 AS $g FROM t GROUP BY a, b
    *   UNION ALL
    *   SELECT a, null, SUM(c), 1 AS $g FROM t GROUP BY a
    *   UNION ALL
    *   SELECT null, null, SUM(c), 3 AS $g FROM t GROUP BY ()
    *
    * The $g column is the grouping ID that can be used by GROUPING() and 
GROUPING_ID() functions.
    * Currently, the $g column is generated internally but stripped from the 
final output.
    */
   public class DrillAggregateExpandGroupingSetsRule extends RelOptRule {
   
     public static final DrillAggregateExpandGroupingSetsRule INSTANCE =
         new DrillAggregateExpandGroupingSetsRule();
     public static final String GROUPING_ID_COLUMN_NAME = "$g";
     public static final String GROUP_ID_COLUMN_NAME = "$group_id";
     public static final String EXPRESSION_COLUMN_PLACEHOLDER = "EXPR$";
   
     private DrillAggregateExpandGroupingSetsRule() {
       super(operand(Aggregate.class, any()), DrillRelFactories.LOGICAL_BUILDER,
           "DrillAggregateExpandGroupingSetsRule");
     }
   
     @Override
     public boolean matches(RelOptRuleCall call) {
       final Aggregate aggregate = call.rel(0);
       return aggregate.getGroupSets().size() > 1
           && (aggregate instanceof DrillAggregateRel || aggregate instanceof 
LogicalAggregate);
     }
   
     @Override
     public void onMatch(RelOptRuleCall call) {
       final Aggregate aggregate = call.rel(0);
       final RelOptCluster cluster = aggregate.getCluster();
   
       GroupingFunctionAnalysis analysis = 
analyzeGroupingFunctions(aggregate.getAggCallList());
       GroupingSetOrderingResult ordering = 
sortAndAssignGroupIds(aggregate.getGroupSets());
   
       List<RelNode> perGroupAggregates = new ArrayList<>();
       for (int i = 0; i < ordering.sortedGroupSets.size(); i++) {
         perGroupAggregates.add(
             createAggregateForGroupingSet(call, aggregate, 
ordering.sortedGroupSets.get(i),
                 ordering.groupIds.get(i), analysis.regularAggCalls));
       }
   
       RelNode unionResult = buildUnion(cluster, perGroupAggregates);
       RelNode result = buildFinalProject(call, unionResult, aggregate, 
analysis);
   
       call.transformTo(result);
     }
     
     /**
      * Encapsulates analysis results of aggregate calls to determine
      * which are regular aggregates and which are grouping-related
      * functions (GROUPING, GROUPING_ID, GROUP_ID).
      */
     private static class GroupingFunctionAnalysis {
       final boolean hasGroupingFunctions;
       final List<AggregateCall> regularAggCalls;
       final List<AggregateCall> groupingFunctionCalls;
       final List<Integer> groupingFunctionPositions;
   
       GroupingFunctionAnalysis(List<AggregateCall> regularAggCalls,
           List<AggregateCall> groupingFunctionCalls,
           List<Integer> groupingFunctionPositions) {
         this.hasGroupingFunctions = !groupingFunctionPositions.isEmpty();
         this.regularAggCalls = regularAggCalls;
         this.groupingFunctionCalls = groupingFunctionCalls;
         this.groupingFunctionPositions = groupingFunctionPositions;
       }
     }
   
     /**
      * Holds the sorted grouping sets (largest first) and their assigned group 
IDs.
      */
     private static class GroupingSetOrderingResult {
       final List<ImmutableBitSet> sortedGroupSets;
       final List<Integer> groupIds;
       GroupingSetOrderingResult(List<ImmutableBitSet> sortedGroupSets, 
List<Integer> groupIds) {
         this.sortedGroupSets = sortedGroupSets;
         this.groupIds = groupIds;
       }
     }
   
     /**
      * Analyzes aggregate calls to identify which ones are GROUPING-related 
functions.
      *
      * @param aggCalls list of aggregate calls in the original aggregate
      * @return structure classifying grouping and non-grouping calls
      */
     private GroupingFunctionAnalysis 
analyzeGroupingFunctions(List<AggregateCall> aggCalls) {
       List<AggregateCall> regularAggCalls = new ArrayList<>();
       List<AggregateCall> groupingFunctionCalls = new ArrayList<>();
       List<Integer> groupingFunctionPositions = new ArrayList<>();
   
       for (int i = 0; i < aggCalls.size(); i++) {
         AggregateCall aggCall = aggCalls.get(i);
         SqlKind kind = aggCall.getAggregation().getKind();
         switch (kind) {
         case SqlKind.GROUPING:
         case SqlKind.GROUPING_ID:
         case SqlKind.GROUP_ID:
           groupingFunctionPositions.add(i);
           groupingFunctionCalls.add(aggCall);
           break;
         default:
           regularAggCalls.add(aggCall);
         }
       }
   
       return new GroupingFunctionAnalysis(regularAggCalls,
           groupingFunctionCalls, groupingFunctionPositions);
     }
   
     /**
      * Sorts grouping sets by decreasing cardinality and assigns a unique 
group ID
      * for each occurrence. Group IDs are used to distinguish identical sets 
when needed.
      */
     private GroupingSetOrderingResult 
sortAndAssignGroupIds(List<ImmutableBitSet> groupSets) {
       List<ImmutableBitSet> sortedGroupSets = new ArrayList<>(groupSets);
       sortedGroupSets.sort((a, b) -> Integer.compare(b.cardinality(), 
a.cardinality()));
   
       Map<ImmutableBitSet, Integer> groupSetOccurrences = new HashMap<>();
       List<Integer> groupIds = new ArrayList<>();
   
       for (ImmutableBitSet groupSet : sortedGroupSets) {
         int groupId = groupSetOccurrences.getOrDefault(groupSet, 0);
         groupIds.add(groupId);
         groupSetOccurrences.put(groupSet, groupId + 1);
       }
   
       return new GroupingSetOrderingResult(sortedGroupSets, groupIds);
     }
   
     /**
      * Creates a single-grouping-set aggregate and adds a projection
      * with null-padding and grouping ID columns ($g and $group_id).
      */
     private RelNode createAggregateForGroupingSet(
         RelOptRuleCall call,
         Aggregate originalAgg,
         ImmutableBitSet groupSet,
         int groupId,
         List<AggregateCall> regularAggCalls) {
   
       ImmutableBitSet fullGroupSet = aggregate.getGroupSet();
       RelOptCluster cluster = originalAgg.getCluster();
       RexBuilder rexBuilder = cluster.getRexBuilder();
       RelDataTypeFactory typeFactory = cluster.getTypeFactory();
       RelNode input = originalAgg.getInput();
   
       Aggregate newAggregate;
       if (originalAgg instanceof DrillAggregateRel) {
         newAggregate = new DrillAggregateRel(cluster, 
originalAgg.getTraitSet(), input,
             groupSet, ImmutableList.of(groupSet), regularAggCalls);
       } else {
         newAggregate = originalAgg.copy(originalAgg.getTraitSet(), input, 
groupSet,
             ImmutableList.of(groupSet), regularAggCalls);
       }
   
       List<RexNode> projects = new ArrayList<>();
       List<String> fieldNames = new ArrayList<>();
       int aggOutputIdx = 0;
       int outputColIdx = 0;
   
       // Populate grouping columns (nulls for omitted columns)
       for (int col : fullGroupSet) {
         if (groupSet.get(col)) {
           projects.add(rexBuilder.makeInputRef(newAggregate, aggOutputIdx++));
         } else {
           RelDataType nullType = 
originalAgg.getRowType().getFieldList().get(outputColIdx).getType();
           projects.add(rexBuilder.makeNullLiteral(nullType));
         }
         
fieldNames.add(originalAgg.getRowType().getFieldList().get(outputColIdx++).getName());
       }
   
       // Add regular aggregates
       for (AggregateCall regCall : regularAggCalls) {
         projects.add(rexBuilder.makeInputRef(newAggregate, aggOutputIdx++));
         fieldNames.add(regCall.getName() != null ? regCall.getName() : "agg$" 
+ aggOutputIdx);
       }
   
       // Add grouping ID ($g)
       int groupingId = computeGroupingId(fullGroupSet, groupSet);
       projects.add(rexBuilder.makeLiteral(groupingId,
           typeFactory.createSqlType(SqlTypeName.INTEGER), true));
       fieldNames.add(GROUPING_ID_COLUMN_NAME);
   
       // Add group ID ($group_id)
       projects.add(rexBuilder.makeLiteral(groupId,
           typeFactory.createSqlType(SqlTypeName.INTEGER), true));
       fieldNames.add(GROUP_ID_COLUMN_NAME);
   
       return call.builder().push(newAggregate).project(projects, fieldNames, 
false).build();
     }
   
     private int computeGroupingId(ImmutableBitSet fullGroupSet, 
ImmutableBitSet groupSet) {
       int id = 0;
       int bit = 0;
       for (int col : fullGroupSet) {
         if (!groupSet.get(col)) id |= (1 << bit);
         bit++;
       }
       return id;
     }
   
     /**
      * Combines all per-grouping-set aggregates into a single {@link 
DrillUnionRel}.
      */
     private RelNode buildUnion(RelOptCluster cluster, List<RelNode> 
aggregates) {
       if (aggregates.size() == 1) {
         return aggregates.get(0);
       }
       try {
         List<RelNode> convertedInputs = new ArrayList<>();
         for (RelNode agg : aggregates) {
           convertedInputs.add(convert(agg, 
agg.getTraitSet().plus(DrillRel.DRILL_LOGICAL).simplify()));
         }
         return new DrillUnionRel(cluster,
             cluster.traitSet().plus(DrillRel.DRILL_LOGICAL),
             convertedInputs,
             true,
             true,
             true);
       } catch (InvalidRelException e) {
         throw new RuntimeException("Failed to create DrillUnionRel", e);
       }
     }
   
     /**
      * Constructs the final projection after the UNION, restoring
      * the original output order and evaluating GROUPING(), GROUPING_ID(), and 
GROUP_ID().
      */
     private RelNode buildFinalProject(
         RelOptRuleCall call,
         RelNode unionResult,
         Aggregate aggregate,
         GroupingFunctionAnalysis analysis) {
   
       RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
       RelDataTypeFactory typeFactory = aggregate.getCluster().getTypeFactory();
       ImmutableBitSet fullGroupSet = aggregate.getGroupSet();
       List<RexNode> finalProjects = new ArrayList<>();
       List<String> finalFieldNames = new ArrayList<>();
       int numFields = unionResult.getRowType().getFieldCount();
   
       for (int i = 0; i < fullGroupSet.cardinality(); i++) {
         finalProjects.add(rexBuilder.makeInputRef(unionResult, i));
         
finalFieldNames.add(unionResult.getRowType().getFieldList().get(i).getName());
       }
   
       if (analysis.hasGroupingFunctions) {
         RexNode gColumnRef = rexBuilder.makeInputRef(unionResult, numFields - 
2);
         RexNode groupIdColumnRef = rexBuilder.makeInputRef(unionResult, 
numFields - 1);
         Map<Integer, AggregateCall> groupingFuncMap = new HashMap<>();
         for (int i = 0; i < analysis.groupingFunctionPositions.size(); i++) {
           groupingFuncMap.put(analysis.groupingFunctionPositions.get(i),
               analysis.groupingFunctionCalls.get(i));
         }
   
         int regularAggIndex = fullGroupSet.cardinality();
         for (int origPos = 0; origPos < aggregate.getAggCallList().size(); 
origPos++) {
           if (groupingFuncMap.containsKey(origPos)) {
             AggregateCall groupingCall = groupingFuncMap.get(origPos);
             String funcName = groupingCall.getAggregation().getName();
             if ("GROUPING".equals(funcName)) {
               processGrouping(groupingCall, fullGroupSet, rexBuilder, 
typeFactory,
                   gColumnRef, finalProjects, finalFieldNames);
             } else if ("GROUPING_ID".equals(funcName)) {
               processGroupingId(groupingCall, fullGroupSet, rexBuilder, 
typeFactory,
                   gColumnRef, finalProjects, finalFieldNames);
             } else if ("GROUP_ID".equals(funcName)) {
               finalProjects.add(groupIdColumnRef);
               String fieldName = groupingCall.getName() != null
                   ? groupingCall.getName()
                   : EXPRESSION_COLUMN_PLACEHOLDER + finalFieldNames.size();
               finalFieldNames.add(fieldName);
             }
           } else {
             finalProjects.add(rexBuilder.makeInputRef(unionResult, 
regularAggIndex));
             
finalFieldNames.add(unionResult.getRowType().getFieldList().get(regularAggIndex).getName());
             regularAggIndex++;
           }
         }
       } else {
         for (int i = fullGroupSet.cardinality(); i < numFields - 2; i++) {
           finalProjects.add(rexBuilder.makeInputRef(unionResult, i));
           
finalFieldNames.add(unionResult.getRowType().getFieldList().get(i).getName());
         }
       }
   
       return call.builder().push(unionResult).project(finalProjects, 
finalFieldNames, false).build();
     }
   
     /**
      * Builds the Rex expression that implements {@code GROUPING(column)}.
      */
     private void processGrouping(AggregateCall groupingCall,
         ImmutableBitSet fullGroupSet,
         RexBuilder rexBuilder,
         RelDataTypeFactory typeFactory,
         RexNode gColumnRef,
         List<RexNode> finalProjects,
         List<String> finalFieldNames) {
   
       if (groupingCall.getArgList().size() != 1) {
         throw new RuntimeException("GROUPING() expects exactly 1 argument");
       }
   
       int columnIndex = groupingCall.getArgList().get(0);
       int bitPosition = 0;
       for (int col : fullGroupSet) {
         if (col == columnIndex) break;
         bitPosition++;
       }
   
       RexNode divisor = rexBuilder.makeLiteral(
           1 << bitPosition, typeFactory.createSqlType(SqlTypeName.INTEGER), 
true);
   
       RexNode divided = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, 
gColumnRef, divisor);
       RexNode extractBit = rexBuilder.makeCall(SqlStdOperatorTable.MOD, 
divided,
           rexBuilder.makeLiteral(2, 
typeFactory.createSqlType(SqlTypeName.INTEGER), true));
   
       finalProjects.add(extractBit);
       String fieldName = groupingCall.getName() != null
           ? groupingCall.getName()
           : "EXPR$" + finalFieldNames.size();
       finalFieldNames.add(fieldName);
     }
   
     /**
      * Builds the Rex expression that implements {@code GROUPING_ID(column, 
...)}.
      */
     private void processGroupingId(AggregateCall groupingCall,
         ImmutableBitSet fullGroupSet,
         RexBuilder rexBuilder,
         RelDataTypeFactory typeFactory,
         RexNode gColumnRef,
         List<RexNode> finalProjects,
         List<String> finalFieldNames) {
   
       if (groupingCall.getArgList().isEmpty()) {
         throw new RuntimeException("GROUPING_ID() expects at least one 
argument");
       }
   
       RexNode result = null;
       for (int i = 0; i < groupingCall.getArgList().size(); i++) {
         int columnIndex = groupingCall.getArgList().get(i);
         int bitPosition = 0;
         for (int col : fullGroupSet) {
           if (col == columnIndex) break;
           bitPosition++;
         }
   
         RexNode divisor = rexBuilder.makeLiteral(1 << bitPosition,
             typeFactory.createSqlType(SqlTypeName.INTEGER), true);
   
         RexNode divided = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, 
gColumnRef, divisor);
         RexNode extractBit = rexBuilder.makeCall(SqlStdOperatorTable.MOD, 
divided,
             rexBuilder.makeLiteral(2, 
typeFactory.createSqlType(SqlTypeName.INTEGER), true));
   
         int resultBitPos = groupingCall.getArgList().size() - 1 - i;
         RexNode bitInPosition = (resultBitPos > 0)
             ? rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, extractBit,
             rexBuilder.makeLiteral(1 << resultBitPos,
                 typeFactory.createSqlType(SqlTypeName.INTEGER), true))
             : extractBit;
   
         result = (result == null)
             ? bitInPosition
             : rexBuilder.makeCall(SqlStdOperatorTable.PLUS, result, 
bitInPosition);
       }
   
       finalProjects.add(result);
       String fieldName = groupingCall.getName() != null
           ? groupingCall.getName()
           : "EXPR$" + finalFieldNames.size();
       finalFieldNames.add(fieldName);
     }
   }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to