rymarm commented on code in PR #3026:
URL: https://github.com/apache/drill/pull/3026#discussion_r2435695203


##########
exec/java-exec/src/main/java/org/apache/drill/exec/physical/impl/union/UnionAllRecordBatch.java:
##########
@@ -301,14 +305,32 @@ private void inferOutputFieldsBothSide(final BatchSchema leftSchema, final Batch
           builder.setMinorType(leftField.getType().getMinorType());
           builder = Types.calculateTypePrecisionAndScale(leftField.getType(), rightField.getType(), builder);
         } else {
-          TypeProtos.MinorType outputMinorType = TypeCastRules.getLeastRestrictiveType(
-            leftField.getType().getMinorType(),
-            rightField.getType().getMinorType()
-          );
-          if (outputMinorType == null) {
-            throw new DrillRuntimeException("Type mismatch between " + leftField.getType().getMinorType().toString() +
-                " on the left side and " + rightField.getType().getMinorType().toString() +
-                " on the right side in column " + index + " of UNION ALL");
+          TypeProtos.MinorType leftType = leftField.getType().getMinorType();
+          TypeProtos.MinorType rightType = rightField.getType().getMinorType();
+          TypeProtos.MinorType outputMinorType;
+
+          // Special handling for GROUPING SETS expansion:
+          // When unioning different grouping sets, NULL columns are represented as INT (Drill's default).
+          // If one side is INT and the other is not, prefer the non-INT type since INT is likely a NULL placeholder.
+          if (popConfig.isGroupingSetsExpansion() &&
+              leftType == TypeProtos.MinorType.INT && rightType != TypeProtos.MinorType.INT) {
+            // Left is INT (likely NULL placeholder), right is actual data - prefer right type
+            outputMinorType = rightType;
+            logger.debug("GROUPING SETS: Preferring {} over INT for column {}", rightType, index);
+          } else if (popConfig.isGroupingSetsExpansion() &&
+                     rightType == TypeProtos.MinorType.INT && leftType != TypeProtos.MinorType.INT) {
+            // Right is INT (likely NULL placeholder), left is actual data - prefer left type
+            outputMinorType = leftType;
+            logger.debug("GROUPING SETS: Preferring {} over INT for column {}", leftType, index);
+          } else {
+            // Normal case: use standard type cast rules
+            outputMinorType = TypeCastRules.getLeastRestrictiveType(leftType, 
rightType);
+            if (outputMinorType == null) {
+              throw new DrillRuntimeException("Type mismatch between " + leftType.toString() +
+                  " on the left side and " + rightType.toString() +
+                  " on the right side in column " + index + " of UNION ALL");
+            }
+            logger.debug("Using standard type rules: {} + {} -> {}", leftType, rightType, outputMinorType);

Review Comment:
   Thanks for this enhancement!
   
   I think the following logic can be moved into a separate helper method for clarity. Something like this:
   ```java
   /**
    * Determines the output type for a UNION ALL column when combining two types.
    * <p>
    * Special handling is applied during GROUPING SETS expansion
    * ({@code popConfig.isGroupingSetsExpansion()}):
    * <ul>
    *   <li>Drill represents NULL columns as INT during grouping sets expansion.</li>
    *   <li>If one side is INT (likely a NULL placeholder) and the other is not,
    *       the non-INT type is preferred.</li>
    * </ul>
    * For all other cases, the least restrictive type according to Drill's type
    * cast rules is returned.
    *
    * @param leftField  the column from the left input
    * @param rightField the column from the right input
    * @param index      the column index (used for logging and error messages)
    * @return the resolved output type
    * @throws DrillRuntimeException if the types are incompatible
    */
   private TypeProtos.MinorType resolveUnionColumnType(MaterializedField leftField,
                                                       MaterializedField rightField,
                                                       int index) {
       TypeProtos.MinorType leftType = leftField.getType().getMinorType();
       TypeProtos.MinorType rightType = rightField.getType().getMinorType();
   
       boolean isGroupingSets = popConfig.isGroupingSetsExpansion();
       boolean leftIsPlaceholder = leftType == TypeProtos.MinorType.INT && rightType != TypeProtos.MinorType.INT;
       boolean rightIsPlaceholder = rightType == TypeProtos.MinorType.INT && leftType != TypeProtos.MinorType.INT;
   
       if (isGroupingSets && (leftIsPlaceholder || rightIsPlaceholder)) {
           TypeProtos.MinorType outputType = leftIsPlaceholder ? rightType : leftType;
           logger.debug("GROUPING SETS: Preferring {} over INT for column {}", outputType, index);
           return outputType;
       }
   
       TypeProtos.MinorType outputType = TypeCastRules.getLeastRestrictiveType(leftType, rightType);
       if (outputType == null) {
           throw new DrillRuntimeException("Type mismatch between " + leftType +
               " on the left side and " + rightType +
               " on the right side in column " + index + " of UNION ALL");
       }
       logger.debug("Using standard type rules: {} + {} -> {}", leftType, rightType, outputType);
       return outputType;
   }
   ```
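A possible refinement of the suggestion above, sketched as standalone Java: passing the grouping-sets flag in as a parameter turns the rule into a pure function that can be unit tested without constructing a UNION ALL operator. `MinorType`, `leastRestrictive()` and its numeric ranking are simplified, hypothetical stand-ins, not Drill's `TypeProtos.MinorType` or `TypeCastRules`, and `IllegalStateException` replaces `DrillRuntimeException` so the example compiles on its own.

```java
import java.util.EnumMap;
import java.util.Map;

/**
 * Standalone sketch of the UNION ALL column-type resolution discussed above.
 * MinorType and leastRestrictive() are simplified stand-ins, NOT Drill's
 * TypeProtos.MinorType or TypeCastRules.
 */
final class UnionTypeResolutionSketch {

  enum MinorType { INT, BIGINT, FLOAT8, VARCHAR, DATE }

  // Hypothetical widening rank for numeric types; non-numeric types have no rank.
  private static final Map<MinorType, Integer> NUMERIC_RANK = new EnumMap<>(MinorType.class);
  static {
    NUMERIC_RANK.put(MinorType.INT, 0);
    NUMERIC_RANK.put(MinorType.BIGINT, 1);
    NUMERIC_RANK.put(MinorType.FLOAT8, 2);
  }

  /** Stand-in for TypeCastRules.getLeastRestrictiveType: null means "incompatible". */
  static MinorType leastRestrictive(MinorType left, MinorType right) {
    if (left == right) {
      return left;
    }
    Integer l = NUMERIC_RANK.get(left);
    Integer r = NUMERIC_RANK.get(right);
    if (l != null && r != null) {
      return l >= r ? left : right;
    }
    return null;
  }

  /**
   * Pure version of resolveUnionColumnType: the grouping-sets flag is passed in
   * instead of being read from popConfig.
   */
  static MinorType resolve(MinorType left, MinorType right, int index,
                           boolean isGroupingSetsExpansion) {
    boolean leftIsPlaceholder = left == MinorType.INT && right != MinorType.INT;
    boolean rightIsPlaceholder = right == MinorType.INT && left != MinorType.INT;
    if (isGroupingSetsExpansion && (leftIsPlaceholder || rightIsPlaceholder)) {
      // INT is treated as a NULL placeholder, so the other side's type wins.
      return leftIsPlaceholder ? right : left;
    }
    MinorType out = leastRestrictive(left, right);
    if (out == null) {
      throw new IllegalStateException("Type mismatch between " + left
          + " on the left side and " + right
          + " on the right side in column " + index + " of UNION ALL");
    }
    return out;
  }

  public static void main(String[] args) {
    System.out.println(resolve(MinorType.INT, MinorType.VARCHAR, 1, true));  // VARCHAR
    System.out.println(resolve(MinorType.INT, MinorType.BIGINT, 2, false));  // BIGINT
    System.out.println(resolve(MinorType.INT, MinorType.INT, 3, true));      // INT
  }
}
```

With the flag off, INT combined with VARCHAR falls through to the stubbed cast rules and throws in this sketch, mirroring the existing error path.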



##########
exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateExpandGroupingSetsRule.java:
##########
@@ -0,0 +1,426 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.drill.exec.planner.logical;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.logical.LogicalAggregate;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.util.ImmutableBitSet;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Planner rule that expands GROUPING SETS, ROLLUP, and CUBE into a UNION ALL
+ * of multiple aggregates, each with a single grouping set.
+ *
+ * This rule converts:
+ *   SELECT a, b, SUM(c) FROM t GROUP BY GROUPING SETS ((a, b), (a), ())
+ *
+ * Into:
+ *   SELECT a, b, SUM(c), 0 AS $g FROM t GROUP BY a, b
+ *   UNION ALL
+ *   SELECT a, null, SUM(c), 1 AS $g FROM t GROUP BY a
+ *   UNION ALL
+ *   SELECT null, null, SUM(c), 3 AS $g FROM t GROUP BY ()
+ *
+ * The $g column is the grouping ID that can be used by GROUPING() and GROUPING_ID() functions.
+ * Currently, the $g column is generated internally but stripped from the final output.
+ *
+ * TODO: Implement GROUPING() and GROUPING_ID() functions by:
+ * 1. Detecting these functions in the SELECT list during expansion
+ * 2. Rewriting them to reference the $g column (e.g., GROUPING(a) becomes bit extraction from $g)
+ * 3. Preserving the $g column in the output when these functions are used
+ */
+public class DrillAggregateExpandGroupingSetsRule extends RelOptRule {
+
+  public static final DrillAggregateExpandGroupingSetsRule INSTANCE =
+      new DrillAggregateExpandGroupingSetsRule();
+
+  private DrillAggregateExpandGroupingSetsRule() {
+    super(operand(Aggregate.class, any()), DrillRelFactories.LOGICAL_BUILDER,
+        "DrillAggregateExpandGroupingSetsRule");
+  }
+
+  @Override
+  public boolean matches(RelOptRuleCall call) {
+    final Aggregate aggregate = call.rel(0);
+
+    // Only match aggregates with multiple grouping sets
+    // Also only match logical aggregates (not physical ones)
+    return aggregate.getGroupSets().size() > 1
+        && (aggregate instanceof DrillAggregateRel || aggregate instanceof LogicalAggregate);
+  }
+
+  @Override
+  public void onMatch(RelOptRuleCall call) {

Review Comment:
   Consider breaking this 346-line method into smaller helper methods to improve readability and make it easier to maintain.
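One possible way to start the split, sketched as standalone Java: the grouping-set ordering and the GROUP_ID bookkeeping from `onMatch()` become small static helpers that can be unit tested without a planner. The helper names are hypothetical, and `java.util.BitSet` stands in for Calcite's `ImmutableBitSet` (the real rule would keep `ImmutableBitSet`, which is also the safer map key because it cannot be mutated after insertion).

```java
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * Sketch of helpers that could be extracted from onMatch().
 * BitSet stands in for ImmutableBitSet; all names here are hypothetical.
 */
final class GroupingSetsHelpersSketch {

  /** Returns a copy of the grouping sets ordered by descending cardinality (stable sort). */
  static List<BitSet> sortByCardinalityDesc(List<BitSet> groupSets) {
    List<BitSet> sorted = new ArrayList<>(groupSets);
    sorted.sort((a, b) -> Integer.compare(b.cardinality(), a.cardinality()));
    return sorted;
  }

  /**
   * GROUP_ID value for each position: 0 for the first occurrence of a grouping set,
   * 1 for its second occurrence, and so on.
   */
  static List<Integer> computeGroupIds(List<BitSet> groupSets) {
    Map<BitSet, Integer> seen = new HashMap<>();
    List<Integer> groupIds = new ArrayList<>();
    for (BitSet groupSet : groupSets) {
      int groupId = seen.getOrDefault(groupSet, 0);
      groupIds.add(groupId);
      seen.put(groupSet, groupId + 1);
    }
    return groupIds;
  }

  /** Convenience factory for a bit set containing the given column indexes. */
  static BitSet bits(int... cols) {
    BitSet bs = new BitSet();
    for (int c : cols) {
      bs.set(c);
    }
    return bs;
  }

  public static void main(String[] args) {
    // GROUPING SETS ((a), (a, b), (), (a)) with a = column 0, b = column 1
    List<BitSet> sets = new ArrayList<>();
    sets.add(bits(0));
    sets.add(bits(0, 1));
    sets.add(bits());
    sets.add(bits(0));
    List<BitSet> sorted = sortByCardinalityDesc(sets);
    System.out.println(sorted);                  // [{0, 1}, {0}, {0}, {}]
    System.out.println(computeGroupIds(sorted)); // [0, 0, 1, 0]
  }
}
```

Because `List.sort` is stable, duplicate grouping sets keep their relative order, so the GROUP_ID values stay deterministic for a given query.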



##########
exec/java-exec/src/main/java/org/apache/drill/exec/planner/logical/DrillAggregateExpandGroupingSetsRule.java:
##########
@@ -0,0 +1,426 @@
+  @Override
+  public void onMatch(RelOptRuleCall call) {
+    final Aggregate aggregate = call.rel(0);
+    final RelOptCluster cluster = aggregate.getCluster();
+    final RexBuilder rexBuilder = cluster.getRexBuilder();
+    final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
+
+    // Get the input
+    final RelNode input = aggregate.getInput();
+    final List<ImmutableBitSet> groupSets = aggregate.getGroupSets();
+    final ImmutableBitSet fullGroupSet = aggregate.getGroupSet();
+    final List<AggregateCall> aggCalls = aggregate.getAggCallList();
+
+    // Check if we have GROUPING, GROUPING_ID, or GROUP_ID functions
+    // These functions need the $g column to be preserved in the output
+    // We need to separate them from regular aggregate functions but preserve their original positions
+    List<AggregateCall> regularAggCalls = new ArrayList<>();
+    List<Integer> groupingFunctionPositions = new ArrayList<>();  // Original positions in aggCalls
+    List<AggregateCall> groupingFunctionCalls = new ArrayList<>();
+    boolean hasGroupingFunctions = false;
+
+    for (int i = 0; i < aggCalls.size(); i++) {
+      AggregateCall aggCall = aggCalls.get(i);
+      org.apache.calcite.sql.SqlKind kind = aggCall.getAggregation().getKind();
+      if (kind == org.apache.calcite.sql.SqlKind.GROUPING ||
+          kind == org.apache.calcite.sql.SqlKind.GROUPING_ID ||
+          kind == org.apache.calcite.sql.SqlKind.GROUP_ID) {
+        hasGroupingFunctions = true;
+        groupingFunctionPositions.add(i);
+        groupingFunctionCalls.add(aggCall);
+      } else {
+        regularAggCalls.add(aggCall);
+      }
+    }
+
+    // Create a separate aggregate for each grouping set
+    // Process grouping sets in order of decreasing cardinality (more columns first)
+    // This ensures that for UNION ALL, branches with actual data types come before
+    // branches with NULL placeholders, helping with type inference
+    //
+    // For GROUP_ID support, we need to track duplicate grouping sets and assign sequence numbers
+    List<RelNode> aggregates = new ArrayList<>();
+    List<ImmutableBitSet> sortedGroupSets = new ArrayList<>(groupSets);
+    // Sort by cardinality descending (more grouping columns first)
+    sortedGroupSets.sort((a, b) -> Integer.compare(b.cardinality(), a.cardinality()));
+
+    // Track GROUP_ID for duplicate grouping sets
+    // Map from grouping set to the count of times we've seen it so far
+    java.util.Map<ImmutableBitSet, Integer> groupSetOccurrences = new java.util.HashMap<>();
+    List<Integer> groupIds = new ArrayList<>();  // GROUP_ID value for each position in sortedGroupSets
+
+    for (int i = 0; i < sortedGroupSets.size(); i++) {
+      ImmutableBitSet groupSet = sortedGroupSets.get(i);
+
+      // Track GROUP_ID: how many times have we seen this grouping set before?
+      int groupId = groupSetOccurrences.getOrDefault(groupSet, 0);
+      groupIds.add(groupId);
+      groupSetOccurrences.put(groupSet, groupId + 1);
+
+      // Create the aggregate for this grouping set
+      // Use regularAggCalls (without GROUPING functions) because GROUPING functions
+      // will be evaluated later using the $g column
+      Aggregate newAggregate;
+      if (aggregate instanceof DrillAggregateRel) {
+        newAggregate = new DrillAggregateRel(
+            cluster,
+            aggregate.getTraitSet(),
+            input,
+            groupSet,
+            ImmutableList.of(groupSet),
+            regularAggCalls);
+      } else {
+        newAggregate = aggregate.copy(
+            aggregate.getTraitSet(),
+            input,
+            groupSet,
+            ImmutableList.of(groupSet),
+            regularAggCalls);
+      }
+
+      // Create a project to add NULLs for missing grouping columns
+      List<RexNode> projects = new ArrayList<>();
+      List<String> fieldNames = new ArrayList<>();
+
+      // Add grouping columns (with NULLs for columns not in this grouping set)
+      int aggOutputIdx = 0;
+      int outputColIdx = 0; // Index in the final output row type
+      for (int col : fullGroupSet) {
+        if (groupSet.get(col)) {
+          // Column is in this grouping set - project it directly from the aggregate output
+          RexNode inputRef = rexBuilder.makeInputRef(newAggregate, aggOutputIdx);
+          projects.add(inputRef);
+          aggOutputIdx++;
+        } else {
+          // Column is NOT in this grouping set - project a typed NULL literal
+          // Use the expected output type from the original aggregate to create a properly typed NULL
+          // This prevents type inference issues in the UNION ALL
+          org.apache.calcite.rel.type.RelDataType nullType =
+              aggregate.getRowType().getFieldList().get(outputColIdx).getType();
+          // Use makeLiteral with null value and explicit type to create a typed NULL
+          projects.add(rexBuilder.makeNullLiteral(nullType));
+        }
+        fieldNames.add(aggregate.getRowType().getFieldList().get(outputColIdx).getName());
+        outputColIdx++;
+      }
+
+      // Add aggregate result columns (only regular aggregates, not GROUPING functions)
+      // We'll use the alias from the original aggregate call
+      for (int j = 0; j < regularAggCalls.size(); j++) {
+        projects.add(rexBuilder.makeInputRef(newAggregate, aggOutputIdx));
+        AggregateCall regCall = regularAggCalls.get(j);
+        String fieldName = regCall.getName() != null ? regCall.getName() : ("$f" + (fullGroupSet.cardinality() + j));
+        fieldNames.add(fieldName);
+        aggOutputIdx++;
+      }
+
+      // Add grouping ID column ($g)
+      // The grouping ID is a bitmap where bit i is 1 if column i is NOT in the grouping set
+      int groupingId = 0;
+      int bitPosition = 0;
+      for (int col : fullGroupSet) {
+        if (!groupSet.get(col)) {
+          groupingId |= (1 << bitPosition);
+        }
+        bitPosition++;
+      }
+      projects.add(rexBuilder.makeLiteral(groupingId, typeFactory.createSqlType(org.apache.calcite.sql.type.SqlTypeName.INTEGER), true));
+      fieldNames.add("$g");

Review Comment:
   I would prefer extracting `$g` into a named constant (e.g., `GROUPING_ID_COLUMN_NAME`) to improve readability and eliminate magic literals in the codebase.
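A minimal sketch of the suggested constant, shown next to the `$g` bitmap loop from the diff pulled into a helper. `GROUPING_ID_COLUMN_NAME` follows the reviewer's example name; `groupingId()` is a hypothetical helper, and `java.util.BitSet` stands in for `ImmutableBitSet`.

```java
import java.util.BitSet;

/**
 * Sketch: the suggested constant plus the $g bitmap computation from the diff,
 * extracted into a helper. Names are hypothetical; BitSet stands in for ImmutableBitSet.
 */
final class GroupingIdSketch {

  /** Name of the internal grouping-ID column added to each UNION ALL branch. */
  static final String GROUPING_ID_COLUMN_NAME = "$g";

  /**
   * Bit i of the result is 1 when the i-th column of fullGroupSet (in ascending
   * column order) is NOT in groupSet, mirroring the loop in the diff.
   */
  static int groupingId(BitSet fullGroupSet, BitSet groupSet) {
    int groupingId = 0;
    int bitPosition = 0;
    for (int col = fullGroupSet.nextSetBit(0); col >= 0; col = fullGroupSet.nextSetBit(col + 1)) {
      if (!groupSet.get(col)) {
        groupingId |= 1 << bitPosition;
      }
      bitPosition++;
    }
    return groupingId;
  }

  public static void main(String[] args) {
    BitSet ab = new BitSet();
    ab.set(0); // a
    ab.set(1); // b
    BitSet a = new BitSet();
    a.set(0);
    BitSet none = new BitSet();
    System.out.println(GROUPING_ID_COLUMN_NAME + " for (a, b): " + groupingId(ab, ab)); // $g for (a, b): 0
    System.out.println(GROUPING_ID_COLUMN_NAME + " for (a): " + groupingId(ab, a));     // $g for (a): 2
    System.out.println(GROUPING_ID_COLUMN_NAME + " for (): " + groupingId(ab, none));   // $g for (): 3
  }
}
```

Following the loop as written, the `(a)` grouping set yields 2 (bit 1 for the missing `b`), whereas the class Javadoc example shows `1 AS $g` for that branch. The usual SQL convention for `GROUPING_ID(a, b)` treats `a` as the most significant bit, which would give 1, so the Javadoc and the bit order may be worth reconciling before GROUPING_ID() is wired to `$g`.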



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
