[ https://issues.apache.org/jira/browse/SPARK-37682?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]

Kevin Liu updated SPARK-37682:
------------------------------
    Description: 
In some cases, the current RewriteDistinctAggregates rule unnecessarily duplicates input data across distinct groups.
This wastes a lot of memory and hurts performance.
We could apply the 'merged column' and 'bit vector' tricks to alleviate the problem. For example:
{code:sql}
SELECT
  COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_filter_cnt_dist,
  COUNT(DISTINCT cat2) FILTER (WHERE id > 2) as cat2_filter_cnt_dist,
  COUNT(DISTINCT IF(value > 5, cat1, null)) as cat1_if_cnt_dist,
  COUNT(DISTINCT id) as id_cnt_dist,
  SUM(DISTINCT value) as id_sum_dist
FROM data
GROUP BY key
{code}

The current rule rewrites the above SQL query into the following (pseudo) logical plan:
{noformat}
Aggregate(
   key = ['key]
   functions = [
       count('cat1) FILTER (WHERE (('gid = 1) AND 'max(id > 1))),
       count('(IF((value > 5), cat1, null))) FILTER (WHERE ('gid = 5)),
       count('cat2) FILTER (WHERE (('gid = 3) AND 'max(id > 2))),
       count('id) FILTER (WHERE ('gid = 2)),
       sum('value) FILTER (WHERE ('gid = 4))
   ]
   output = ['key, 'cat1_filter_cnt_dist, 'cat2_filter_cnt_dist, 'cat1_if_cnt_dist,
             'id_cnt_dist, 'id_sum_dist])
  Aggregate(
     key = ['key, 'cat1, 'value, 'cat2, '(IF((value > 5), cat1, null)), 'id, 'gid]
     functions = [max('id > 1), max('id > 2)]
     output = ['key, 'cat1, 'value, 'cat2, '(IF((value > 5), cat1, null)), 'id, 'gid,
               'max(id > 1), 'max(id > 2)])
    Expand(
       projections = [
         ('key, 'cat1, null, null, null, null, 1, ('id > 1), null),
         ('key, null, null, null, null, 'id, 2, null, null),
         ('key, null, null, 'cat2, null, null, 3, null, ('id > 2)),
         ('key, null, 'value, null, null, null, 4, null, null),
         ('key, null, null, null, if (('value > 5)) 'cat1 else null, null, 5, null, null)
       ]
       output = ['key, 'cat1, 'value, 'cat2, '(IF((value > 5), cat1, null)), 'id,
                 'gid, '(id > 1), '(id > 2)])
      LocalTableScan [...]
{noformat}
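To make the cost of the current rewrite concrete, here is a minimal Python sketch (not Spark code) that simulates the Expand and the two Aggregate nodes above on four hypothetical input rows. The sample data, the helper names (expand_current, max_ignoring_null, aggregate_current) and the use of None for SQL NULL are assumptions made for illustration only.

```python
# Hypothetical sample rows: (key, cat1, cat2, id, value).
ROWS = [
    ("k1", "a", "x", 1, 3),
    ("k1", "a", "y", 2, 7),
    ("k1", "b", "x", 3, 9),
    ("k2", "c", "z", 4, 1),
]


def expand_current(rows):
    """Mimic the current Expand: every input row is copied once per distinct
    group (gid 1..5). Output columns:
    (key, cat1, value, cat2, if_col, id, gid, id > 1, id > 2)."""
    out = []
    for key, cat1, cat2, id_, value in rows:
        if_col = cat1 if value > 5 else None
        out.append((key, cat1, None, None, None, None, 1, id_ > 1, None))
        out.append((key, None, None, None, None, id_, 2, None, None))
        out.append((key, None, None, cat2, None, None, 3, None, id_ > 2))
        out.append((key, None, value, None, None, None, 4, None, None))
        out.append((key, None, None, None, if_col, None, 5, None, None))
    return out


def max_ignoring_null(acc, v):
    # SQL max() skips nulls; None stands in for SQL NULL.
    if v is None:
        return acc
    return v if acc is None else max(acc, v)


def aggregate_current(expanded):
    # Inner Aggregate: group by every column except the two conditions and
    # compute max(id > 1) and max(id > 2) per group.
    groups = {}
    for *keys, c1, c2 in expanded:
        m1, m2 = groups.get(tuple(keys), (None, None))
        groups[tuple(keys)] = (max_ignoring_null(m1, c1), max_ignoring_null(m2, c2))

    # Outer Aggregate: count/sum each distinct group under its FILTER.
    # Result order follows the SELECT list; 0 stands in where SQL SUM would be NULL.
    result = {}
    for (key, cat1, value, cat2, if_col, id_, gid), (m1, m2) in groups.items():
        r = result.setdefault(key, [0, 0, 0, 0, 0])
        if gid == 1 and m1 and cat1 is not None:
            r[0] += 1  # cat1_filter_cnt_dist
        if gid == 3 and m2 and cat2 is not None:
            r[1] += 1  # cat2_filter_cnt_dist
        if gid == 5 and if_col is not None:
            r[2] += 1  # cat1_if_cnt_dist
        if gid == 2 and id_ is not None:
            r[3] += 1  # id_cnt_dist
        if gid == 4 and value is not None:
            r[4] += value  # id_sum_dist
    return result


expanded = expand_current(ROWS)
print(len(expanded))  # 20: 4 input rows x 5 groups
print(aggregate_current(expanded))  # {'k1': [2, 1, 2, 3, 19], 'k2': [1, 1, 0, 1, 1]}
```

Every input row becomes five expanded rows of nine columns each, and most of those cells are null.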

After applying the 'merged column' and 'bit vector' tricks, the logical plan becomes:
{noformat}
Aggregate(
   key = ['key]
   functions = [
       count('merged_string_1) FILTER (WHERE (('gid = 1) AND NOT (('bit_or(vector_1) & 1) = 0))),
       count('merged_string_1) FILTER (WHERE (('gid = 1) AND NOT (('bit_or(vector_1) & 2) = 0))),
       count('merged_string_1) FILTER (WHERE (('gid = 2) AND NOT (('bit_or(vector_1) & 1) = 0))),
       count('merged_integer_1) FILTER (WHERE ('gid = 3)),
       sum('merged_integer_1) FILTER (WHERE ('gid = 4))
   ]
   output = ['key, 'cat1_filter_cnt_dist, 'cat2_filter_cnt_dist, 'cat1_if_cnt_dist,
             'id_cnt_dist, 'id_sum_dist])
  Aggregate(
     key = ['key, 'merged_string_1, 'merged_integer_1, 'gid]
     functions = [bit_or('vector_1)]
     output = ['key, 'merged_string_1, 'merged_integer_1, 'gid, 'bit_or(vector_1)])
    Expand(
       projections = [
         ('key, 'cat1, null, 1, (if (('id > 1)) 1 else 0 | if (('value > 5)) 2 else 0)),
         ('key, 'cat2, null, 2, if (('id > 2)) 1 else 0),
         ('key, null, 'id, 3, null),
         ('key, null, 'value, 4, null)
       ]
       output = ['key, 'merged_string_1, 'merged_integer_1, 'gid, 'vector_1])
      LocalTableScan [...]
{noformat}
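A minimal Python sketch (not Spark code, and not the actual implementation) can simulate the proposed plan above on four hypothetical input rows. The sample data and the helpers expand_merged, bit_set, aggregate_merged and direct are assumptions; direct computes the expected answer straight from the SQL semantics so the rewritten plan can be checked against it.

```python
# Hypothetical sample rows: (key, cat1, cat2, id, value).
ROWS = [
    ("k1", "a", "x", 1, 3),
    ("k1", "a", "y", 2, 7),
    ("k1", "b", "x", 3, 9),
    ("k2", "c", "z", 4, 1),
]


def expand_merged(rows):
    """Mimic the proposed Expand. Output columns:
    (key, merged_string_1, merged_integer_1, gid, vector_1)."""
    out = []
    for key, cat1, cat2, id_, value in rows:
        # gid 1 serves two aggregates on cat1: bit 1 = (id > 1), bit 2 = (value > 5).
        out.append((key, cat1, None, 1, (1 if id_ > 1 else 0) | (2 if value > 5 else 0)))
        out.append((key, cat2, None, 2, 1 if id_ > 2 else 0))
        out.append((key, None, id_, 3, None))
        out.append((key, None, value, 4, None))
    return out


def bit_set(bits, mask):
    # NOT ((bits & mask) = 0), treating a NULL vector as "not set".
    return bits is not None and (bits & mask) != 0


def aggregate_merged(expanded):
    # Inner Aggregate: group by (key, merged_string_1, merged_integer_1, gid)
    # and OR the vectors together (bit_or skips nulls).
    groups = {}
    for key, ms, mi, gid, vec in expanded:
        acc = groups.get((key, ms, mi, gid))
        if vec is not None:
            acc = vec if acc is None else acc | vec
        groups[(key, ms, mi, gid)] = acc

    # Outer Aggregate. Result order follows the SELECT list;
    # 0 stands in where SQL SUM would return NULL.
    result = {}
    for (key, ms, mi, gid), bits in groups.items():
        r = result.setdefault(key, [0, 0, 0, 0, 0])
        if gid == 1 and ms is not None and bit_set(bits, 1):
            r[0] += 1  # cat1_filter_cnt_dist
        if gid == 2 and ms is not None and bit_set(bits, 1):
            r[1] += 1  # cat2_filter_cnt_dist
        if gid == 1 and ms is not None and bit_set(bits, 2):
            r[2] += 1  # cat1_if_cnt_dist
        if gid == 3 and mi is not None:
            r[3] += 1  # id_cnt_dist
        if gid == 4 and mi is not None:
            r[4] += mi  # id_sum_dist
    return result


def direct(rows):
    # Reference answer straight from the SQL (the sample contains no NULLs).
    out = {}
    for key in dict.fromkeys(r[0] for r in rows):
        g = [r for r in rows if r[0] == key]
        out[key] = [
            len({c1 for _, c1, _, i, _ in g if i > 1}),  # cat1_filter_cnt_dist
            len({c2 for _, _, c2, i, _ in g if i > 2}),  # cat2_filter_cnt_dist
            len({c1 for _, c1, _, _, v in g if v > 5}),  # cat1_if_cnt_dist
            len({i for _, _, _, i, _ in g}),             # id_cnt_dist
            sum({v for *_, v in g}),                     # id_sum_dist
        ]
    return out


expanded = expand_merged(ROWS)
print(len(expanded))  # 16: 4 input rows x 4 groups
print(aggregate_merged(expanded) == direct(ROWS))  # True
```

Here each input row expands into four rows of five columns, compared with five rows of nine columns in the current plan, while the aggregate results stay the same.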

1. merged column: Children of the same data type from different aggregate functions can share the same projected column (e.g. cat1 and cat2 share merged_string_1, while id and value share merged_integer_1).
2. bit vector: If multiple aggregate functions apply conditions to the same child column, either through a FILTER clause or through a conditional expression that outputs the column when the condition is true and null when it is false (e.g. IF(value > 5, cat1, null)), these aggregate functions can share one row group. The result of each condition is stored as one bit in the bit vector column, which reduces the number of rows produced by the expansion (e.g. cat1_filter_cnt_dist and cat1_if_cnt_dist share gid 1). The detailed logic is in RewriteDistinctAggregates.groupDistinctAggExpr at the following GitHub link.
When a query contains many similar distinct aggregate functions, with or without FILTER clauses, these tricks can save a large amount of memory and improve performance.
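The grouping decision behind both tricks can be illustrated with a hypothetical planning function. This is a sketch under simplifying assumptions, not RewriteDistinctAggregates.groupDistinctAggExpr: DistinctAgg, RowGroup and plan_groups are invented names, aggregates are grouped only by the column under DISTINCT, and each distinct condition within a group receives the next free bit. Applied to the example query, it reproduces the gid, merged column and bit assignments of the plan above.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass(frozen=True)
class DistinctAgg:
    alias: str                # output name, e.g. "cat1_filter_cnt_dist"
    column: str               # column under DISTINCT (after unwrapping IF)
    dtype: str                # data type of that column
    condition: Optional[str]  # FILTER / IF condition, or None


@dataclass
class RowGroup:
    gid: int
    column: str
    slot: str              # merged column that carries the value
    bits: Dict[str, int]   # condition -> bit in vector_1


def plan_groups(aggs: List[DistinctAgg]) -> List[RowGroup]:
    """Hypothetical planning step: one row group per DISTINCT column (bit
    vector trick) and one merged slot per data type (merged column trick)."""
    groups: Dict[str, RowGroup] = {}
    for agg in aggs:
        group = groups.get(agg.column)
        if group is None:
            group = RowGroup(len(groups) + 1, agg.column, f"merged_{agg.dtype}_1", {})
            groups[agg.column] = group
        if agg.condition is not None and agg.condition not in group.bits:
            group.bits[agg.condition] = 1 << len(group.bits)
    return list(groups.values())


AGGS = [
    DistinctAgg("cat1_filter_cnt_dist", "cat1", "string", "id > 1"),
    DistinctAgg("cat2_filter_cnt_dist", "cat2", "string", "id > 2"),
    DistinctAgg("cat1_if_cnt_dist", "cat1", "string", "value > 5"),
    DistinctAgg("id_cnt_dist", "id", "integer", None),
    DistinctAgg("id_sum_dist", "value", "integer", None),
]

for g in plan_groups(AGGS):
    print(g.gid, g.column, g.slot, g.bits)
# 1 cat1 merged_string_1 {'id > 1': 1, 'value > 5': 2}
# 2 cat2 merged_string_1 {'id > 2': 1}
# 3 id merged_integer_1 {}
# 4 value merged_integer_1 {}
```

Because every row group carries exactly one value column, a single merged column per data type is enough in this sketch.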

  was:
In some cases, the current RewriteDistinctAggregates rule unnecessarily duplicates input data across distinct groups.
This wastes a lot of memory and hurts performance.
We could apply the 'merged column' and 'bit vector' tricks to alleviate the problem. For example:
{code:sql}
SELECT
  COUNT(DISTINCT cat1) FILTER (WHERE id > 1) as cat1_filter_cnt_dist,
  COUNT(DISTINCT cat2) FILTER (WHERE id > 2) as cat2_filter_cnt_dist,
  COUNT(DISTINCT IF(value > 5, cat1, null)) as cat1_if_cnt_dist,
  COUNT(DISTINCT id) as id_cnt_dist,
  SUM(DISTINCT value) as id_sum_dist
FROM data
GROUP BY key
{code}

The current rule rewrites the above SQL query into the following (pseudo) logical plan:
{noformat}
Aggregate(
   key = ['key]
   functions = [
       count('cat1) FILTER (WHERE (('gid = 1) AND 'max(id > 1))),
       count('(IF((value > 5), cat1, null))) FILTER (WHERE ('gid = 5)),
       count('cat2) FILTER (WHERE (('gid = 3) AND 'max(id > 2))),
       count('id) FILTER (WHERE ('gid = 2)),
       sum('value) FILTER (WHERE ('gid = 4))
   ]
   output = ['key, 'cat1_filter_cnt_dist, 'cat2_filter_cnt_dist, 'cat1_if_cnt_dist,
             'id_cnt_dist, 'id_sum_dist])
  Aggregate(
     key = ['key, 'cat1, 'value, 'cat2, '(IF((value > 5), cat1, null)), 'id, 'gid]
     functions = [max('id > 1), max('id > 2)]
     output = ['key, 'cat1, 'value, 'cat2, '(IF((value > 5), cat1, null)), 'id, 'gid,
               'max(id > 1), 'max(id > 2)])
    Expand(
       projections = [
         ('key, 'cat1, null, null, null, null, 1, ('id > 1), null),
         ('key, null, null, null, null, 'id, 2, null, null),
         ('key, null, null, 'cat2, null, null, 3, null, ('id > 2)),
         ('key, null, 'value, null, null, null, 4, null, null),
         ('key, null, null, null, if (('value > 5)) 'cat1 else null, null, 5, null, null)
       ]
       output = ['key, 'cat1, 'value, 'cat2, '(IF((value > 5), cat1, null)), 'id,
                 'gid, '(id > 1), '(id > 2)])
      LocalTableScan [...]
{noformat}

After applying the 'merged column' and 'bit vector' tricks, the logical plan becomes:
{noformat}
Aggregate(
   key = ['key]
   functions = [
       count(if (NOT (('bit_or(vector_1) & 1) = 0)) 'merged_string_1 else null)
         FILTER (WHERE ('gid = 1)),
       count(if (NOT (('bit_or(vector_1) & 2) = 0)) 'merged_string_1 else null)
         FILTER (WHERE ('gid = 1)),
       count(if (NOT (('bit_or(vector_1) & 1) = 0)) 'merged_string_1 else null)
         FILTER (WHERE ('gid = 2)),
       count('merged_integer_1) FILTER (WHERE ('gid = 3)),
       sum('merged_integer_1) FILTER (WHERE ('gid = 4))
   ]
   output = ['key, 'cat1_filter_cnt_dist, 'cat2_filter_cnt_dist, 'cat1_if_cnt_dist,
             'id_cnt_dist, 'id_sum_dist])
  Aggregate(
     key = ['key, 'merged_string_1, 'merged_integer_1, 'gid]
     functions = [bit_or('vector_1)]
     output = ['key, 'merged_string_1, 'merged_integer_1, 'gid, 'bit_or(vector_1)])
    Expand(
       projections = [
         ('key, 'cat1, null, 1, (if (('id > 1)) 1 else 0 | if (('value > 5)) 2 else 0)),
         ('key, 'cat2, null, 2, if (('id > 2)) 1 else 0),
         ('key, null, 'id, 3, null),
         ('key, null, 'value, 4, null)
       ]
       output = ['key, 'merged_string_1, 'merged_integer_1, 'gid, 'vector_1])
      LocalTableScan [...]
{noformat}
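This earlier version of the plan tests the bit inside count(if ... else null) and filters only on gid, whereas the current description moves the bit test into the FILTER clause. A short Python sketch on hypothetical inner-Aggregate output rows (None stands for SQL NULL; the row values and helper names are assumptions) shows that the two forms count the same rows, because count ignores nulls either way.

```python
from typing import List, Optional, Tuple

# Hypothetical inner-Aggregate rows: (gid, merged_string_1, bit_or(vector_1)).
Row = Tuple[int, Optional[str], Optional[int]]

INNER: List[Row] = [
    (1, "a", 3),
    (1, "b", 1),
    (1, None, 2),  # cat1 was NULL for this group
    (2, "x", 1),
    (2, "y", 0),
]


def bit_set(bits: Optional[int], mask: int) -> bool:
    # NOT ((bits & mask) = 0), treating a NULL vector as "not set".
    return bits is not None and (bits & mask) != 0


def count_if_form(rows: List[Row], gid: int, mask: int) -> int:
    # count(if (bit set) merged_string_1 else null) FILTER (WHERE gid = ...)
    vals = [ms if bit_set(b, mask) else None for g, ms, b in rows if g == gid]
    return sum(v is not None for v in vals)


def count_filter_form(rows: List[Row], gid: int, mask: int) -> int:
    # count(merged_string_1) FILTER (WHERE gid = ... AND bit set)
    return sum(ms is not None for g, ms, b in rows if g == gid and bit_set(b, mask))


for gid, mask in [(1, 1), (1, 2), (2, 1)]:
    print(gid, mask, count_if_form(INNER, gid, mask), count_filter_form(INNER, gid, mask))
# 1 1 2 2
# 1 2 1 1
# 2 1 1 1
```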

1. merged column: Children of the same data type from different aggregate functions can share the same projected column (e.g. cat1 and cat2 share merged_string_1, while id and value share merged_integer_1).
2. bit vector: If multiple aggregate functions apply conditions to the same child column, either through a FILTER clause or through a conditional expression that outputs the column when the condition is true and null when it is false (e.g. IF(value > 5, cat1, null)), these aggregate functions can share one row group. The result of each condition is stored as one bit in the bit vector column, which reduces the number of rows produced by the expansion (e.g. cat1_filter_cnt_dist and cat1_if_cnt_dist share gid 1). The detailed logic is in RewriteDistinctAggregates.groupDistinctAggExpr at the following GitHub link.
When a query contains many similar distinct aggregate functions, with or without FILTER clauses, these tricks can save a large amount of memory and improve performance.


> Reduce memory pressure of RewriteDistinctAggregates
> ---------------------------------------------------
>
>                 Key: SPARK-37682
>                 URL: https://issues.apache.org/jira/browse/SPARK-37682
>             Project: Spark
>          Issue Type: Improvement
>          Components: SQL
>    Affects Versions: 3.2.0
>            Reporter: Kevin Liu
>            Priority: Major
>              Labels: performance
>



--
This message was sent by Atlassian Jira
(v8.20.1#820001)
