This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 0ec292f454 Make Logical Plans more readable by removing extra aliases (#10832)
0ec292f454 is described below

commit 0ec292f45404359356ab9125b1b0f5b21a135ab8
Author: Mohamed Abdeen <[email protected]>
AuthorDate: Tue Jun 11 21:37:10 2024 +0300

    Make Logical Plans more readable by removing extra aliases (#10832)
    
    * logical plan: remove unnecessary aliases
    
    * revert EnterMark
    
    * fix docs and benchmarks
    
    * revert id_array change
    
    * add alias counter
    
    * fix alias counter bug
    
    * fix slt test
    
    * fix benchmark results
    
    * revert alias/unalias changes
    
    * remove TODO
    
    * minor fix
    
    * fix benchmark
---
 .../optimizer/src/common_subexpr_eliminate.rs      | 73 ++++++++++++++++------
 datafusion/sqllogictest/test_files/group_by.slt    | 18 +++---
 .../sqllogictest/test_files/tpch/q1.slt.part       |  4 +-
 3 files changed, 65 insertions(+), 30 deletions(-)

diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs 
b/datafusion/optimizer/src/common_subexpr_eliminate.rs
index 6820ba04f0..3ed1309f15 100644
--- a/datafusion/optimizer/src/common_subexpr_eliminate.rs
+++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs
@@ -128,7 +128,7 @@ impl CommonSubexprEliminate {
     fn rewrite_exprs_list(
         &self,
         exprs_list: &[&[Expr]],
-        arrays_list: &[&[Vec<(usize, String)>]],
+        arrays_list: &[&[IdArray]],
         expr_stats: &ExprStats,
         common_exprs: &mut CommonExprs,
     ) -> Result<Vec<Vec<Expr>>> {
@@ -159,7 +159,7 @@ impl CommonSubexprEliminate {
     fn rewrite_expr(
         &self,
         exprs_list: &[&[Expr]],
-        arrays_list: &[&[Vec<(usize, String)>]],
+        arrays_list: &[&[IdArray]],
         input: &LogicalPlan,
         expr_stats: &ExprStats,
         config: &dyn OptimizerConfig,
@@ -480,7 +480,7 @@ fn to_arrays(
     input_schema: DFSchemaRef,
     expr_stats: &mut ExprStats,
     expr_mask: ExprMask,
-) -> Result<Vec<Vec<(usize, String)>>> {
+) -> Result<Vec<IdArray>> {
     expr.iter()
         .map(|e| {
             let mut id_array = vec![];
@@ -739,7 +739,7 @@ fn expr_identifier(expr: &Expr, sub_expr_identifier: 
Identifier) -> Identifier {
 fn expr_to_identifier(
     expr: &Expr,
     expr_stats: &mut ExprStats,
-    id_array: &mut Vec<(usize, Identifier)>,
+    id_array: &mut IdArray,
     input_schema: DFSchemaRef,
     expr_mask: ExprMask,
 ) -> Result<()> {
@@ -769,15 +769,28 @@ struct CommonSubexprRewriter<'a> {
     common_exprs: &'a mut CommonExprs,
     // preorder index, starts from 0.
     down_index: usize,
+    // how many aliases have we seen so far
+    alias_counter: usize,
 }
 
 impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
     type Node = Expr;
 
+    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Self::Node>> {
+        if matches!(expr, Expr::Alias(_)) {
+            self.alias_counter -= 1
+        }
+        Ok(Transformed::no(expr))
+    }
+
     fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
         // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to 
generate
         // the `id_array`, which records the expr's identifier used to rewrite 
expr. So if we
         // skip an expr in `ExprIdentifierVisitor`, we should skip it here, 
too.
+        if matches!(expr, Expr::Alias(_)) {
+            self.alias_counter += 1;
+        }
+
         if expr.short_circuits() || expr.is_volatile()? {
             return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump));
         }
@@ -801,15 +814,16 @@ impl TreeNodeRewriter for CommonSubexprRewriter<'_> {
 
             let expr_name = expr.display_name()?;
             self.common_exprs.insert(expr_id.clone(), expr);
-            // Alias this `Column` expr to it original "expr name",
-            // `projection_push_down` optimizer use "expr name" to eliminate 
useless
-            // projections.
-            // TODO: do we really need to alias here?
-            Ok(Transformed::new(
-                col(expr_id).alias(expr_name),
-                true,
-                TreeNodeRecursion::Jump,
-            ))
+
+            // alias the expressions without an `Alias` ancestor node
+            let rewritten = if self.alias_counter > 0 {
+                col(expr_id)
+            } else {
+                self.alias_counter += 1;
+                col(expr_id).alias(expr_name)
+            };
+
+            Ok(Transformed::new(rewritten, true, TreeNodeRecursion::Jump))
         } else {
             Ok(Transformed::no(expr))
         }
@@ -829,6 +843,7 @@ fn replace_common_expr(
         id_array,
         common_exprs,
         down_index: 0,
+        alias_counter: 0,
     })
     .data()
 }
@@ -962,6 +977,26 @@ mod test {
         Ok(())
     }
 
+    #[test]
+    fn nested_aliases() -> Result<()> {
+        let table_scan = test_table_scan()?;
+
+        let plan = LogicalPlanBuilder::from(table_scan.clone())
+            .project(vec![
+                (col("a") + col("b") - col("c")).alias("alias1") * (col("a") + 
col("b")),
+                col("a") + col("b"),
+            ])?
+            .build()?;
+
+        let expected = "Projection: {test.a + test.b|{test.b}|{test.a}} - 
test.c AS alias1 * {test.a + test.b|{test.b}|{test.a}} AS test.a + test.b, 
{test.a + test.b|{test.b}|{test.a}} AS test.a + test.b\
+        \n  Projection: test.a + test.b AS {test.a + 
test.b|{test.b}|{test.a}}, test.a, test.b, test.c\
+        \n    TableScan: test";
+
+        assert_optimized_plan_eq(expected, &plan);
+
+        Ok(())
+    }
+
     #[test]
     fn aggregate() -> Result<()> {
         let table_scan = test_table_scan()?;
@@ -1006,7 +1041,7 @@ mod test {
             )?
             .build()?;
 
-        let expected = "Projection: {AVG(test.a)|{test.a}} AS AVG(test.a) AS 
col1, {AVG(test.a)|{test.a}} AS AVG(test.a) AS col2, col3, {AVG(test.c)} AS 
AVG(test.c), {my_agg(test.a)|{test.a}} AS my_agg(test.a) AS col4, 
{my_agg(test.a)|{test.a}} AS my_agg(test.a) AS col5, col6, {my_agg(test.c)} AS 
my_agg(test.c)\
+        let expected = "Projection: {AVG(test.a)|{test.a}} AS col1, 
{AVG(test.a)|{test.a}} AS col2, col3, {AVG(test.c)} AS AVG(test.c), 
{my_agg(test.a)|{test.a}} AS col4, {my_agg(test.a)|{test.a}} AS col5, col6, 
{my_agg(test.c)} AS my_agg(test.c)\
         \n  Aggregate: groupBy=[[]], aggr=[[AVG(test.a) AS 
{AVG(test.a)|{test.a}}, my_agg(test.a) AS {my_agg(test.a)|{test.a}}, 
AVG(test.b) AS col3, AVG(test.c) AS {AVG(test.c)}, my_agg(test.b) AS col6, 
my_agg(test.c) AS {my_agg(test.c)}]]\
         \n    TableScan: test";
 
@@ -1042,7 +1077,7 @@ mod test {
             )?
             .build()?;
 
-        let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + 
test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col1, my_agg({UInt32(1) 
+ test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col2]]\n  Projection: 
UInt32(1) + test.a AS {UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, 
test.b, test.c\n    TableScan: test";
+        let expected = "Aggregate: groupBy=[[]], aggr=[[AVG({UInt32(1) + 
test.a|{test.a}|{UInt32(1)}}) AS col1, my_agg({UInt32(1) + 
test.a|{test.a}|{UInt32(1)}}) AS col2]]\n  Projection: UInt32(1) + test.a AS 
{UInt32(1) + test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\n    
TableScan: test";
 
         assert_optimized_plan_eq(expected, &plan);
 
@@ -1057,7 +1092,7 @@ mod test {
             )?
             .build()?;
 
-        let expected = "Aggregate: groupBy=[[{UInt32(1) + 
test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + 
test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col1, my_agg({UInt32(1) 
+ test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS col2]]\
+        let expected = "Aggregate: groupBy=[[{UInt32(1) + 
test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + 
test.a|{test.a}|{UInt32(1)}}) AS col1, my_agg({UInt32(1) + 
test.a|{test.a}|{UInt32(1)}}) AS col2]]\
         \n  Projection: UInt32(1) + test.a AS {UInt32(1) + 
test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\
         \n    TableScan: test";
 
@@ -1078,8 +1113,8 @@ mod test {
             )?
             .build()?;
 
-        let expected = "Projection: UInt32(1) + test.a, UInt32(1) + 
{AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + 
test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + 
test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32(1) + test.a) 
AS col1, UInt32(1) - {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS 
UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + 
test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}} AS AVG(UInt32( [...]
-        \n  Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS 
UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS 
UInt32(1) + test.a) AS {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS 
UInt32(1) + test.a)|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + 
test.a|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}}, my_agg({UInt32(1) + 
test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS {my_agg({UInt32(1) + 
test.a|{test.a}|{UInt32(1)}} AS UIn [...]
+        let expected = "Projection: UInt32(1) + test.a, UInt32(1) + 
{AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + 
test.a|{test.a}|{UInt32(1)}}}} AS col1, UInt32(1) - {AVG({UInt32(1) + 
test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} AS 
col2, {AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a)} AS 
AVG(UInt32(1) + test.a), UInt32(1) + {my_agg({UInt32(1) + 
test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}} A 
[...]
+        \n  Aggregate: groupBy=[[{UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS 
UInt32(1) + test.a]], aggr=[[AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}}) AS 
{AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + 
test.a|{test.a}|{UInt32(1)}}}}, my_agg({UInt32(1) + 
test.a|{test.a}|{UInt32(1)}}) AS {my_agg({UInt32(1) + 
test.a|{test.a}|{UInt32(1)}})|{{UInt32(1) + test.a|{test.a}|{UInt32(1)}}}}, 
AVG({UInt32(1) + test.a|{test.a}|{UInt32(1)}} AS UInt32(1) + test.a) AS 
{AVG({UInt32(1) + t [...]
         \n    Projection: UInt32(1) + test.a AS {UInt32(1) + 
test.a|{test.a}|{UInt32(1)}}, test.a, test.b, test.c\
         \n      TableScan: test";
 
@@ -1126,7 +1161,7 @@ mod test {
             ])?
             .build()?;
 
-        let expected = "Projection: {Int32(1) + test.a|{test.a}|{Int32(1)}} AS 
Int32(1) + test.a AS first, {Int32(1) + test.a|{test.a}|{Int32(1)}} AS Int32(1) 
+ test.a AS second\
+        let expected = "Projection: {Int32(1) + test.a|{test.a}|{Int32(1)}} AS 
first, {Int32(1) + test.a|{test.a}|{Int32(1)}} AS second\
         \n  Projection: Int32(1) + test.a AS {Int32(1) + 
test.a|{test.a}|{Int32(1)}}, test.a, test.b, test.c\
         \n    TableScan: test";
 
diff --git a/datafusion/sqllogictest/test_files/group_by.slt 
b/datafusion/sqllogictest/test_files/group_by.slt
index 24a301d4a7..9e8a2450e0 100644
--- a/datafusion/sqllogictest/test_files/group_by.slt
+++ b/datafusion/sqllogictest/test_files/group_by.slt
@@ -4538,7 +4538,7 @@ CREATE EXTERNAL TABLE timestamp_table (
        c2 INT,
 )
 STORED AS CSV
-LOCATION 'test_files/scratch/group_by/timestamp_table' 
+LOCATION 'test_files/scratch/group_by/timestamp_table'
 OPTIONS ('format.has_header' 'true');
 
 # Group By using date_trunc
@@ -4611,7 +4611,7 @@ DROP TABLE timestamp_table;
 
 # Table with an int column and Dict<Int8> column:
 statement ok
-CREATE TABLE int8_dict AS VALUES 
+CREATE TABLE int8_dict AS VALUES
 (1, arrow_cast('A', 'Dictionary(Int8, Utf8)')),
 (2, arrow_cast('B', 'Dictionary(Int8, Utf8)')),
 (2, arrow_cast('A', 'Dictionary(Int8, Utf8)')),
@@ -4649,7 +4649,7 @@ DROP TABLE int8_dict;
 
 # Table with an int column and Dict<Int16> column:
 statement ok
-CREATE TABLE int16_dict AS VALUES 
+CREATE TABLE int16_dict AS VALUES
 (1, arrow_cast('A', 'Dictionary(Int16, Utf8)')),
 (2, arrow_cast('B', 'Dictionary(Int16, Utf8)')),
 (2, arrow_cast('A', 'Dictionary(Int16, Utf8)')),
@@ -4687,7 +4687,7 @@ DROP TABLE int16_dict;
 
 # Table with an int column and Dict<Int32> column:
 statement ok
-CREATE TABLE int32_dict AS VALUES 
+CREATE TABLE int32_dict AS VALUES
 (1, arrow_cast('A', 'Dictionary(Int32, Utf8)')),
 (2, arrow_cast('B', 'Dictionary(Int32, Utf8)')),
 (2, arrow_cast('A', 'Dictionary(Int32, Utf8)')),
@@ -4725,7 +4725,7 @@ DROP TABLE int32_dict;
 
 # Table with an int column and Dict<Int64> column:
 statement ok
-CREATE TABLE int64_dict AS VALUES 
+CREATE TABLE int64_dict AS VALUES
 (1, arrow_cast('A', 'Dictionary(Int64, Utf8)')),
 (2, arrow_cast('B', 'Dictionary(Int64, Utf8)')),
 (2, arrow_cast('A', 'Dictionary(Int64, Utf8)')),
@@ -4763,7 +4763,7 @@ DROP TABLE int64_dict;
 
 # Table with an int column and Dict<UInt8> column:
 statement ok
-CREATE TABLE uint8_dict AS VALUES 
+CREATE TABLE uint8_dict AS VALUES
 (1, arrow_cast('A', 'Dictionary(UInt8, Utf8)')),
 (2, arrow_cast('B', 'Dictionary(UInt8, Utf8)')),
 (2, arrow_cast('A', 'Dictionary(UInt8, Utf8)')),
@@ -4801,7 +4801,7 @@ DROP TABLE uint8_dict;
 
 # Table with an int column and Dict<UInt16> column:
 statement ok
-CREATE TABLE uint16_dict AS VALUES 
+CREATE TABLE uint16_dict AS VALUES
 (1, arrow_cast('A', 'Dictionary(UInt16, Utf8)')),
 (2, arrow_cast('B', 'Dictionary(UInt16, Utf8)')),
 (2, arrow_cast('A', 'Dictionary(UInt16, Utf8)')),
@@ -4839,7 +4839,7 @@ DROP TABLE uint16_dict;
 
 # Table with an int column and Dict<UInt32> column:
 statement ok
-CREATE TABLE uint32_dict AS VALUES 
+CREATE TABLE uint32_dict AS VALUES
 (1, arrow_cast('A', 'Dictionary(UInt32, Utf8)')),
 (2, arrow_cast('B', 'Dictionary(UInt32, Utf8)')),
 (2, arrow_cast('A', 'Dictionary(UInt32, Utf8)')),
@@ -4877,7 +4877,7 @@ DROP TABLE uint32_dict;
 
 # Table with an int column and Dict<UInt64> column:
 statement ok
-CREATE TABLE uint64_dict AS VALUES 
+CREATE TABLE uint64_dict AS VALUES
 (1, arrow_cast('A', 'Dictionary(UInt64, Utf8)')),
 (2, arrow_cast('B', 'Dictionary(UInt64, Utf8)')),
 (2, arrow_cast('A', 'Dictionary(UInt64, Utf8)')),
diff --git a/datafusion/sqllogictest/test_files/tpch/q1.slt.part 
b/datafusion/sqllogictest/test_files/tpch/q1.slt.part
index 0583c6ef07..5e0930b992 100644
--- a/datafusion/sqllogictest/test_files/tpch/q1.slt.part
+++ b/datafusion/sqllogictest/test_files/tpch/q1.slt.part
@@ -42,7 +42,7 @@ explain select
 logical_plan
 01)Sort: lineitem.l_returnflag ASC NULLS LAST, lineitem.l_linestatus ASC NULLS 
LAST
 02)--Projection: lineitem.l_returnflag, lineitem.l_linestatus, 
sum(lineitem.l_quantity) AS sum_qty, sum(lineitem.l_extendedprice) AS 
sum_base_price, sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) 
AS sum_disc_price, sum(lineitem.l_extendedprice * Int64(1) - 
lineitem.l_discount * Int64(1) + lineitem.l_tax) AS sum_charge, 
AVG(lineitem.l_quantity) AS avg_qty, AVG(lineitem.l_extendedprice) AS 
avg_price, AVG(lineitem.l_discount) AS avg_disc, COUNT(*) AS count_order
-03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], 
aggr=[[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), 
sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - 
lineitem.l_discount)|{Decimal128(Some(1),20,0) - 
lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}}
 AS lineitem.l_extendedprice * Decimal128(Some(1),20,0) - lineitem.l_discount) 
AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount),  [...]
+03)----Aggregate: groupBy=[[lineitem.l_returnflag, lineitem.l_linestatus]], 
aggr=[[sum(lineitem.l_quantity), sum(lineitem.l_extendedprice), 
sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - 
lineitem.l_discount)|{Decimal128(Some(1),20,0) - 
lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}})
 AS sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount), 
sum({lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - lineitem.l_discou 
[...]
 04)------Projection: lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - 
lineitem.l_discount) AS {lineitem.l_extendedprice * (Decimal128(Some(1),20,0) - 
lineitem.l_discount)|{Decimal128(Some(1),20,0) - 
lineitem.l_discount|{lineitem.l_discount}|{Decimal128(Some(1),20,0)}}|{lineitem.l_extendedprice}},
 lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount, 
lineitem.l_tax, lineitem.l_returnflag, lineitem.l_linestatus
 05)--------Filter: lineitem.l_shipdate <= Date32("1998-09-02")
 06)----------TableScan: lineitem projection=[l_quantity, l_extendedprice, 
l_discount, l_tax, l_returnflag, l_linestatus, l_shipdate], 
partial_filters=[lineitem.l_shipdate <= Date32("1998-09-02")]
@@ -80,7 +80,7 @@ group by
     l_linestatus
 order by
     l_returnflag,
-    l_linestatus;
+       l_linestatus;
 ----
 A F 3774200 5320753880.69 5054096266.6828 5256751331.449234 25.537587 
36002.123829 0.050144 147790
 N F 95257 133737795.84 127132372.6512 132286291.229445 25.300664 35521.326916 
0.049394 3765


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to