This is an automated email from the ASF dual-hosted git repository.
philo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 27d71c8abd [VL] Refactor `getAggRelInternal` in
`HashAggregateExecTransformer` (#11040)
27d71c8abd is described below
commit 27d71c8abdf231a1ff3de1358f2a32ecf89da4dd
Author: Zouxxyy <[email protected]>
AuthorDate: Tue Nov 11 08:53:20 2025 +0800
[VL] Refactor `getAggRelInternal` in `HashAggregateExecTransformer` (#11040)
---
.../execution/HashAggregateExecTransformer.scala | 163 +++++++--------------
1 file changed, 53 insertions(+), 110 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
index e46d5340d0..ad9528246c 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/execution/HashAggregateExecTransformer.scala
@@ -76,15 +76,10 @@ abstract class HashAggregateExecTransformer(
// Return whether the outputs partial aggregation should be combined for
Velox computing.
// When the partial outputs are multiple-column, row construct is needed.
- private def rowConstructNeeded(aggregateExpressions:
Seq[AggregateExpression]): Boolean = {
- aggregateExpressions.exists {
- aggExpr =>
- aggExpr.mode match {
- case PartialMerge | Final =>
- aggExpr.aggregateFunction.inputAggBufferAttributes.size > 1
- case _ => false
- }
- }
+ private def rowConstructNeeded(): Boolean = aggregateExpressions.exists {
+ case AggregateExpression(aggFunc, PartialMerge | Final, _, _, _) =>
+ aggFunc.inputAggBufferAttributes.size > 1
+ case _ => false
}
/**
@@ -186,13 +181,12 @@ abstract class HashAggregateExecTransformer(
s"isStreaming=$isStreamingStr\nallowFlush=$allowFlushStr\n"
}
- // Create aggregate function node and add to list.
- private def addFunctionNode(
+ // Create aggregate function node.
+ private def makeFunctionNode(
context: SubstraitContext,
aggregateFunction: AggregateFunction,
childrenNodeList: JList[ExpressionNode],
- aggregateMode: AggregateMode,
- aggregateNodeList: JList[AggregateFunctionNode]): Unit = {
+ aggregateMode: AggregateMode): AggregateFunctionNode = {
val outputTypeNode = aggregateMode match {
case Partial | PartialMerge if
aggregateFunction.aggBufferAttributes.size > 1 =>
@@ -204,13 +198,12 @@ abstract class HashAggregateExecTransformer(
case Final | Complete =>
ConverterUtils.getTypeNode(aggregateFunction.dataType,
aggregateFunction.nullable)
}
- val aggFunctionNode = ExpressionBuilder.makeAggregateFunction(
+ ExpressionBuilder.makeAggregateFunction(
VeloxAggregateFunctionsBuilder.create(context, aggregateFunction,
aggregateMode),
childrenNodeList,
modeToKeyWord(aggregateMode),
outputTypeNode
)
- aggregateNodeList.add(aggFunctionNode)
}
/**
@@ -271,14 +264,13 @@ abstract class HashAggregateExecTransformer(
// Add a projection node before aggregation for row constructing.
// Mainly used for aggregation whose intermediate type is a compound type in
Velox.
- // Pre-projection is always not required for final stage.
- private def getAggRelWithRowConstruct(
+ // Pre-projection is never required for final stages.
+ private def applyRowConstruct(
context: SubstraitContext,
originalInputAttributes: Seq[Attribute],
operatorId: Long,
inputRel: RelNode,
validation: Boolean): RelNode = {
- // Create a projection for row construct.
val exprNodes = new JArrayList[ExpressionNode]()
groupingExpressions.foreach(
expr => {
@@ -373,69 +365,17 @@ abstract class HashAggregateExecTransformer(
}
}
- // Create a project rel.
- val projectRel = RelBuilder.makeProjectRel(
+ RelBuilder.makeProjectRel(
originalInputAttributes.asJava,
inputRel,
exprNodes,
context,
operatorId,
validation)
-
- // Create aggregation rel.
- val groupingList = new JArrayList[ExpressionNode]()
- var colIdx = 0
- groupingExpressions.foreach {
- _ =>
- groupingList.add(ExpressionBuilder.makeSelection(colIdx))
- colIdx += 1
- }
-
- val aggFilterList = new JArrayList[ExpressionNode]()
- val aggregateFunctionList = new JArrayList[AggregateFunctionNode]()
- aggregateExpressions.foreach(
- aggExpr => {
- if (aggExpr.filter.isDefined) {
- throw new GlutenNotSupportException("Filter in final aggregation is
not supported.")
- } else {
- // The number of filters should be aligned with that of aggregate
functions.
- aggFilterList.add(null)
- }
-
- val aggFunc = aggExpr.aggregateFunction
- val childrenNodes = new JArrayList[ExpressionNode]()
- aggExpr.mode match {
- case PartialMerge | Final =>
- // Only occupies one column due to intermediate results are
combined
- // by previous projection.
- childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
- colIdx += 1
- case Partial | Complete =>
- aggFunc.children.foreach {
- _ =>
- childrenNodes.add(ExpressionBuilder.makeSelection(colIdx))
- colIdx += 1
- }
- case _ =>
- throw new GlutenNotSupportException(
- s"$aggFunc of ${aggExpr.mode.toString} is not supported.")
- }
- addFunctionNode(context, aggFunc, childrenNodes, aggExpr.mode,
aggregateFunctionList)
- })
-
- val extensionNode = getAdvancedExtension()
- RelBuilder.makeAggregateRel(
- projectRel,
- groupingList,
- aggregateFunctionList,
- aggFilterList,
- extensionNode,
- context,
- operatorId)
}
/**
- * Create and return the Rel for the this aggregation.
+ * Create and return the Rel for the aggregation.
* @param context
* the Substrait context
* @param operatorId
@@ -457,13 +397,21 @@ abstract class HashAggregateExecTransformer(
validation: Boolean = false): RelNode = {
val originalInputAttributes = child.output
- var aggRel = if (rowConstructNeeded(aggregateExpressions)) {
+ val finalInput = if (rowConstructNeeded()) {
aggParams.rowConstructionNeeded = true
- getAggRelWithRowConstruct(context, originalInputAttributes, operatorId,
input, validation)
+ applyRowConstruct(context, originalInputAttributes, operatorId, input,
validation)
} else {
- getAggRelInternal(context, originalInputAttributes, operatorId, input,
validation)
+ input
}
+ var aggRel = getAggRelInternal(
+ context,
+ originalInputAttributes,
+ operatorId,
+ finalInput,
+ validation,
+ aggParams.rowConstructionNeeded)
+
if (extractStructNeeded()) {
aggParams.extractionNeeded = true
aggRel = applyExtractStruct(context, aggRel, operatorId, validation)
@@ -515,27 +463,30 @@ abstract class HashAggregateExecTransformer(
context: SubstraitContext,
originalInputAttributes: Seq[Attribute],
operatorId: Long,
- input: RelNode = null,
- validation: Boolean): RelNode = {
- // Get the grouping nodes.
- // Use 'child.output' as based Seq[Attribute], the originalInputAttributes
- // may be different for each backend.
- val groupingList = groupingExpressions
- .map(
+ input: RelNode,
+ validation: Boolean,
+ rowConstructed: Boolean): RelNode = {
+ var colIdx = -1
+ val toExpressionNode: Expression => ExpressionNode = if (rowConstructed) {
+ // If the input is row constructed, use selection to get the column.
+ (_: Expression) =>
+ colIdx += 1
+ ExpressionBuilder.makeSelection(colIdx)
+ } else {
+ (expr: Expression) =>
ExpressionConverter
- .replaceWithExpressionTransformer(_, child.output)
- .doTransform(context))
- .asJava
+ .replaceWithExpressionTransformer(expr, originalInputAttributes)
+ .doTransform(context)
+ }
+
+ val groupingList = groupingExpressions.map(toExpressionNode).asJava
// Get the aggregate function nodes.
val aggFilterList = new JArrayList[ExpressionNode]()
val aggregateFunctionList = new JArrayList[AggregateFunctionNode]()
aggregateExpressions.foreach(
aggExpr => {
if (aggExpr.filter.isDefined) {
- val exprNode = ExpressionConverter
- .replaceWithExpressionTransformer(aggExpr.filter.get, child.output)
- .doTransform(context)
- aggFilterList.add(exprNode)
+ aggFilterList.add(toExpressionNode(aggExpr.filter.get))
} else {
// The number of filters should be aligned with that of aggregate
functions.
aggFilterList.add(null)
@@ -543,33 +494,25 @@ abstract class HashAggregateExecTransformer(
val aggregateFunc = aggExpr.aggregateFunction
val childrenNodes = aggExpr.mode match {
case Partial | Complete =>
- aggregateFunc.children.toList.map(
- expr => {
- ExpressionConverter
- .replaceWithExpressionTransformer(expr,
originalInputAttributes)
- .doTransform(context)
- })
+ aggregateFunc.children.toList.map(toExpressionNode)
case PartialMerge | Final =>
- rewriteAggBufferAttributes(
- aggregateFunc.inputAggBufferAttributes,
- originalInputAttributes).map {
- attr =>
- ExpressionConverter
- .replaceWithExpressionTransformer(attr,
originalInputAttributes)
- .doTransform(context)
+ if (rowConstructed) {
+ // Only occupies one column due to intermediate results are
combined
+ // by previous row construct projection.
+ Seq(toExpressionNode.apply(null))
+ } else {
+ rewriteAggBufferAttributes(
+ aggregateFunc.inputAggBufferAttributes,
+ originalInputAttributes).map(toExpressionNode)
}
case other =>
throw new GlutenNotSupportException(s"$other not supported.")
}
- addFunctionNode(
- context,
- aggregateFunc,
- childrenNodes.asJava,
- aggExpr.mode,
- aggregateFunctionList)
+ aggregateFunctionList.add(
+ makeFunctionNode(context, aggregateFunc, childrenNodes.asJava,
aggExpr.mode))
})
- val extensionNode = getAdvancedExtension(validation,
originalInputAttributes)
+ val extensionNode = getAdvancedExtension(validation && !rowConstructed,
originalInputAttributes)
RelBuilder.makeAggregateRel(
input,
groupingList,
@@ -584,7 +527,7 @@ abstract class HashAggregateExecTransformer(
validation: Boolean = false,
originalInputAttributes: Seq[Attribute] = Seq.empty):
AdvancedExtensionNode = {
val enhancement = if (validation) {
- // Use a extension node to send the input types through Substrait plan
for validation.
+ // Use an extension node to send the input types through Substrait plan
for validation.
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]