This is an automated email from the ASF dual-hosted git repository.
liujiayi771 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new cce8d24ca9 [GLUTEN-9279] Not pulling out expression from PartialMerge
aggregate function to avoid invalid reference binding in ProjectExecTransformer
(#9280)
cce8d24ca9 is described below
commit cce8d24ca956050be18dbfa171ac4d8f5511ca07
Author: z1wu <[email protected]>
AuthorDate: Mon Apr 21 09:39:38 2025 +0800
[GLUTEN-9279] Not pulling out expression from PartialMerge aggregate
function to avoid invalid reference binding in ProjectExecTransformer (#9280)
---
.../apache/gluten/utils/PullOutProjectHelper.scala | 24 ++++++++++--------
.../org/apache/spark/sql/GlutenQueryTest.scala | 22 ++++++++++++++++
.../GlutenExtensionRewriteRuleSuite.scala | 29 +++++++++++++++++++++-
3 files changed, 64 insertions(+), 11 deletions(-)
diff --git
a/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala
b/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala
index e4fc114410..4bb09bfcab 100644
---
a/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala
+++
b/gluten-substrait/src/main/scala/org/apache/gluten/utils/PullOutProjectHelper.scala
@@ -20,7 +20,7 @@ import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException}
import org.apache.spark.sql.catalyst.expressions._
-import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
AggregateFunction}
+import
org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression,
AggregateFunction, Complete, Partial}
import org.apache.spark.sql.execution.aggregate._
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.types.{ByteType, DateType, IntegerType, LongType,
ShortType}
@@ -137,16 +137,20 @@ trait PullOutProjectHelper {
protected def rewriteAggregateExpression(
ae: AggregateExpression,
expressionMap: mutable.HashMap[Expression, NamedExpression]):
AggregateExpression = {
- val newAggFuncChildren = ae.aggregateFunction.children.map {
- case literal: Literal => literal
- case other => replaceExpressionWithAttribute(other, expressionMap)
+ ae.mode match {
+ case Partial | Complete =>
+ val newAggFuncChildren = ae.aggregateFunction.children.map {
+ case literal: Literal => literal
+ case other => replaceExpressionWithAttribute(other, expressionMap)
+ }
+ val newAggFunc = ae.aggregateFunction
+ .withNewChildren(newAggFuncChildren)
+ .asInstanceOf[AggregateFunction]
+ val newFilter =
+ ae.filter.map(replaceExpressionWithAttribute(_, expressionMap))
+ ae.copy(aggregateFunction = newAggFunc, filter = newFilter)
+ case _ => ae
}
- val newAggFunc = ae.aggregateFunction
- .withNewChildren(newAggFuncChildren)
- .asInstanceOf[AggregateFunction]
- val newFilter =
- ae.filter.map(replaceExpressionWithAttribute(_, expressionMap))
- ae.copy(aggregateFunction = newAggFunc, filter = newFilter)
}
private def needPreComputeRangeFrameBoundary(bound: Expression): Boolean = {
diff --git
a/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala
b/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala
index 21185a3581..23fb44d5ca 100644
--- a/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala
+++ b/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenQueryTest.scala
@@ -429,6 +429,28 @@ abstract class GlutenQueryTest extends PlanTest {
val executedPlan = getExecutedPlan(df)
assert(executedPlan.exists(plan => tag.runtimeClass.isInstance(plan)))
}
+
+ /**
+ * Check whether the executed plan of a dataframe contains expected number
of expected plans.
+ *
+ * @param df:
+ * the input dataframe.
+ * @param count:
+ * expected number of expected plan.
+ * @param tag:
+ * class of the expected plan.
+ * @tparam T:
+ * type of the expected plan.
+ */
+ def checkGlutenOperatorCount[T <: GlutenPlan](df: DataFrame, count:
Int)(implicit
+ tag: ClassTag[T]): Unit = {
+ val executedPlan = getExecutedPlan(df)
+ assert(
+ executedPlan.count(plan => tag.runtimeClass.isInstance(plan)) == count,
+ s"Expect $count ${tag.runtimeClass.getSimpleName} " +
+ s"in executedPlan:\n ${executedPlan.last}"
+ )
+ }
}
object GlutenQueryTest extends Assertions {
diff --git
a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala
b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala
index a295508150..837d37236b 100644
---
a/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala
+++
b/gluten-ut/test/src/test/scala/org/apache/gluten/extension/GlutenExtensionRewriteRuleSuite.scala
@@ -16,7 +16,7 @@
*/
package org.apache.gluten.extension
-import org.apache.gluten.execution.{ProjectExecTransformer,
WholeStageTransformerSuite}
+import org.apache.gluten.execution.{HashAggregateExecBaseTransformer,
ProjectExecTransformer, WholeStageTransformerSuite}
import org.apache.gluten.utils.BackendTestUtils
import org.apache.spark.SparkConf
@@ -62,4 +62,31 @@ class GlutenExtensionRewriteRuleSuite extends
WholeStageTransformerSuite {
}
)
}
+
+ test("GLUTEN-9279 - Not Pull out expression to avoid invalid reference
binding") {
+ withTable("t") {
+ sql("CREATE TABLE t(f1 String, f2 String, f3 String, f4 String) USING
PARQUET")
+ sql("INSERT INTO t values ('1', '2', '3', '4'), ('11' ,'22', '33', '4')")
+ var expectedProjectCount = 3
+ var noFallback = false
+ if (BackendTestUtils.isCHBackendLoaded()) {
+ // The `RewriteMultiChildrenCount` rule in the Velox-backend is the
root cause of the
+ // additional ProjectExecTransformer, which leads to the invalid
reference binding issue.
+ // We still conduct tests on the CH-backend here to ensure that the
introduced modification
+ // in `PullOutPreProject` has no side effect on the CH-backend.
+ expectedProjectCount = 2
+ noFallback = true
+ }
+ runQueryAndCompare(
+ """
+ |SELECT SUM(f1) / COUNT(DISTINCT f2, f3) FROM t GROUP BY f4;
+ |""".stripMargin,
+ noFallBack = noFallback
+ )(
+ df => {
+ checkGlutenOperatorCount[ProjectExecTransformer](df,
expectedProjectCount)
+ checkGlutenOperatorCount[HashAggregateExecBaseTransformer](df, 4)
+ })
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]