This is an automated email from the ASF dual-hosted git repository.

zhangzc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 6e0b11943 [GLUTEN-7054][CH] Fix cse alias issues (#7084)
6e0b11943 is described below

commit 6e0b11943242d18699634df818c1755bf3b5aa40
Author: 李扬 <[email protected]>
AuthorDate: Tue Sep 3 17:22:38 2024 +0800

    [GLUTEN-7054][CH] Fix cse alias issues (#7084)
    
    * fix cse alias issues
    
    * fix issue https://github.com/apache/incubator-gluten/issues/7054
    
    * fix uts
---
 .../CommonSubexpressionEliminateRule.scala         | 12 +++-
 .../hive/GlutenClickHouseHiveTableSuite.scala      | 84 +++++++++++++++++++---
 2 files changed, 84 insertions(+), 12 deletions(-)

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
index a3b74366f..52e278b3d 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala
@@ -28,6 +28,11 @@ import org.apache.spark.sql.internal.SQLConf
 
 import scala.collection.mutable
 
+// If you want to debug CommonSubexpressionEliminateRule, you can:
+// 1. replace all `logTrace` with `logError`
+// 2. append two options to the Spark config
+//    --conf spark.sql.planChangeLog.level=error
+//    --conf spark.sql.planChangeLog.batches=all
 class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf)
   extends Rule[LogicalPlan]
   with Logging {
@@ -121,7 +126,12 @@ class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf)
         if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) {
           addToEquivalentExpressions(expr, equivalentExpressions)
         } else {
-          equivalentExpressions.addExprTree(expr)
+          expr match {
+            case alias: Alias =>
+              equivalentExpressions.addExprTree(alias.child)
+            case _ =>
+              equivalentExpressions.addExprTree(expr)
+          }
         }
       })
 
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
index f165d7aef..cbc3aed36 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala
@@ -985,19 +985,19 @@ class GlutenClickHouseHiveTableSuite
     }
   }
 
-  test("GLUTEN-4333: fix CSE in aggregate operator") {
-    def checkOperatorCount[T <: TransformSupport](count: Int)(df: DataFrame)(implicit
-        tag: ClassTag[T]): Unit = {
-      if (sparkVersion.equals("3.3")) {
-        assert(
-          getExecutedPlan(df).count(
-            plan => {
-              plan.getClass == tag.runtimeClass
-            }) == count,
-          s"executed plan: ${getExecutedPlan(df)}")
-      }
+  def checkOperatorCount[T <: TransformSupport](count: Int)(df: DataFrame)(implicit
+      tag: ClassTag[T]): Unit = {
+    if (sparkVersion.equals("3.3")) {
+      assert(
+        getExecutedPlan(df).count(
+          plan => {
+            plan.getClass == tag.runtimeClass
+          }) == count,
+        s"executed plan: ${getExecutedPlan(df)}")
     }
+  }
 
+  test("GLUTEN-4333: fix CSE in aggregate operator") {
     val createTableSql =
       """
         |CREATE TABLE `test_cse`(
@@ -1262,4 +1262,66 @@ class GlutenClickHouseHiveTableSuite
     compareResultsAgainstVanillaSpark(selectSql, true, _ => {})
     sql(s"drop table if exists $tableName")
   }
+
+  test("GLUTEN-7054: Fix exception when CSE meets common alias expression") {
+    val createTableSql = """
+                           |CREATE TABLE test_tbl_7054 (
+                           |  day STRING,
+                           |  event_id STRING,
+                           |  event STRUCT<
+                           |    event_info: MAP<STRING, STRING>
+                           |  >
+                           |) STORED AS PARQUET;
+                           |""".stripMargin
+
+    val insertDataSql = """
+                          |INSERT INTO test_tbl_7054
+                          |VALUES
+                          |  ('2024-08-27', '011441004',
+                          |     STRUCT(MAP('type', '1', 'action', '8', 'value_vmoney', '100'))),
+                          |  ('2024-08-27', '011441004',
+                          |     STRUCT(MAP('type', '2', 'action', '8', 'value_vmoney', '200'))),
+                          |  ('2024-08-27', '011441004',
+                          |     STRUCT(MAP('type', '4', 'action', '8', 'value_vmoney', '300')));
+                          |""".stripMargin
+
+    val selectSql = """
+                      |SELECT
+                      |  COALESCE(day, 'all') AS daytime,
+                      |  COALESCE(type, 'all') AS type,
+                      |  COALESCE(value_money, 'all') AS value_vmoney,
+                      |  SUM(CASE
+                      |      WHEN type IN (1, 2) AND action = 8 THEN value_vmoney
+                      |      ELSE 0
+                      |  END) / 60 AS total_value_vmoney
+                      |FROM (
+                      |  SELECT
+                      |    day,
+                      |    type,
+                      |    NVL(CAST(value_vmoney AS BIGINT), 0) AS value_money,
+                      |    action,
+                      |    type,
+                      |    CAST(value_vmoney AS BIGINT) AS value_vmoney
+                      |  FROM (
+                      |    SELECT
+                      |      day,
+                      |      event.event_info["type"] AS type,
+                      |      event.event_info["action"] AS action,
+                      |      event.event_info["value_vmoney"] AS value_vmoney
+                      |    FROM test_tbl_7054
+                      |    WHERE
+                      |      day = '2024-08-27'
+                      |      AND event_id = '011441004'
+                      |      AND event.event_info["type"] IN (1, 2, 4)
+                      |  ) a
+                      |) b
+                      |GROUP BY
+                      |  day, type, value_money
+                      |""".stripMargin
+
+    spark.sql(createTableSql)
+    spark.sql(insertDataSql)
+    runQueryAndCompare(selectSql)(df => checkOperatorCount[ProjectExecTransformer](3)(df))
+    spark.sql("DROP TABLE test_tbl_7054")
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to