This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 5d08325165 Count wildcard alias (#14927)
5d08325165 is described below
commit 5d08325165c1a7b32e5e35164919e83d46735e98
Author: Jay Zhan <[email protected]>
AuthorDate: Wed Mar 5 09:44:42 2025 +0800
Count wildcard alias (#14927)
* fix alias
* append the string
* window count
* add column
* fmt
* rm todo
* fixed partitioned
* fix test
* update doc
* Suggestion to reduce API surface area
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
.../core/tests/dataframe/dataframe_functions.rs | 6 +-
datafusion/core/tests/dataframe/mod.rs | 306 ++++++++++++++++++---
datafusion/functions-aggregate/src/count.rs | 51 +++-
datafusion/sqllogictest/test_files/subquery.slt | 38 +++
4 files changed, 360 insertions(+), 41 deletions(-)
diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs
index 28c0740ca7..fec3ab786f 100644
--- a/datafusion/core/tests/dataframe/dataframe_functions.rs
+++ b/datafusion/core/tests/dataframe/dataframe_functions.rs
@@ -1145,9 +1145,9 @@ async fn test_count_wildcard() -> Result<()> {
.build()
.unwrap();
- let expected = "Sort: count(Int64(1)) ASC NULLS LAST [count(Int64(1)):Int64]\
- \n Projection: count(Int64(1)) [count(Int64(1)):Int64]\
- \n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1))]] [b:UInt32, count(Int64(1)):Int64]\
+ let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\
+ \n Projection: count(*) [count(*):Int64]\
+ \n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] [b:UInt32, count(*):Int64]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
let formatted_plan = plan.display_indent_schema().to_string();
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index 1875180d50..43428d6846 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -32,8 +32,7 @@ use arrow::datatypes::{
};
use arrow::error::ArrowError;
use arrow::util::pretty::pretty_format_batches;
-use datafusion_expr::utils::COUNT_STAR_EXPANSION;
-use datafusion_functions_aggregate::count::{count_all, count_udaf};
+use datafusion_functions_aggregate::count::{count_all, count_all_window};
use datafusion_functions_aggregate::expr_fn::{
array_agg, avg, count, count_distinct, max, median, min, sum,
};
@@ -2455,7 +2454,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
let ctx = create_join_context()?;
let sql_results = ctx
- .sql("select b,count(1) from t1 group by b order by count(1)")
+ .sql("select b, count(*) from t1 group by b order by count(*)")
.await?
.explain(false, false)?
.collect()
@@ -2469,9 +2468,52 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
.explain(false, false)?
.collect()
.await?;
- //make sure sql plan same with df plan
+
+ let expected_sql_result =
"+---------------+------------------------------------------------------------------------------------------------------------+\
+ \n| plan_type | plan
|\
+
\n+---------------+------------------------------------------------------------------------------------------------------------+\
+ \n| logical_plan | Projection: t1.b, count(*)
|\
+ \n| | Sort: count(Int64(1)) AS count(*) AS count(*) ASC
NULLS LAST |\
+ \n| | Projection: t1.b, count(Int64(1)) AS count(*),
count(Int64(1)) |\
+ \n| | Aggregate: groupBy=[[t1.b]],
aggr=[[count(Int64(1))]] |\
+ \n| | TableScan: t1 projection=[b]
|\
+ \n| physical_plan | ProjectionExec: expr=[b@0 as b, count(*)@1 as
count(*)] |\
+ \n| | SortPreservingMergeExec: [count(Int64(1))@2 ASC
NULLS LAST] |\
+ \n| | SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST],
preserve_partitioning=[true] |\
+ \n| | ProjectionExec: expr=[b@0 as b,
count(Int64(1))@1 as count(*), count(Int64(1))@1 as count(Int64(1))] |\
+ \n| | AggregateExec: mode=FinalPartitioned, gby=[b@0
as b], aggr=[count(Int64(1))] |\
+ \n| | CoalesceBatchesExec: target_batch_size=8192
|\
+ \n| | RepartitionExec: partitioning=Hash([b@0],
4), input_partitions=4 |\
+ \n| | RepartitionExec:
partitioning=RoundRobinBatch(4), input_partitions=1 |\
+ \n| | AggregateExec: mode=Partial, gby=[b@0
as b], aggr=[count(Int64(1))] |\
+ \n| | DataSourceExec: partitions=1,
partition_sizes=[1] |\
+ \n| |
|\
+
\n+---------------+------------------------------------------------------------------------------------------------------------+";
+
assert_eq!(
- pretty_format_batches(&sql_results)?.to_string(),
+ expected_sql_result,
+ pretty_format_batches(&sql_results)?.to_string()
+ );
+
+ let expected_df_result =
"+---------------+--------------------------------------------------------------------------------+\
+\n| plan_type | plan
|\
+\n+---------------+--------------------------------------------------------------------------------+\
+\n| logical_plan | Sort: count(*) ASC NULLS LAST
|\
+\n| | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1)) AS
count(*)]] |\
+\n| | TableScan: t1 projection=[b]
|\
+\n| physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST]
|\
+\n| | SortExec: expr=[count(*)@1 ASC NULLS LAST],
preserve_partitioning=[true] |\
+\n| | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b],
aggr=[count(*)] |\
+\n| | CoalesceBatchesExec: target_batch_size=8192
|\
+\n| | RepartitionExec: partitioning=Hash([b@0], 4),
input_partitions=4 |\
+\n| | RepartitionExec:
partitioning=RoundRobinBatch(4), input_partitions=1 |\
+\n| | AggregateExec: mode=Partial, gby=[b@0 as b],
aggr=[count(*)] |\
+\n| | DataSourceExec: partitions=1,
partition_sizes=[1] |\
+\n| |
|\
+\n+---------------+--------------------------------------------------------------------------------+";
+
+ assert_eq!(
+ expected_df_result,
pretty_format_batches(&df_results)?.to_string()
);
Ok(())
@@ -2481,12 +2523,35 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
async fn test_count_wildcard_on_where_in() -> Result<()> {
let ctx = create_join_context()?;
let sql_results = ctx
- .sql("SELECT a,b FROM t1 WHERE a in (SELECT count(1) FROM t2)")
+ .sql("SELECT a, b FROM t1 WHERE a in (SELECT count(*) FROM t2)")
.await?
.explain(false, false)?
.collect()
.await?;
+ let expected_sql_result =
"+---------------+------------------------------------------------------------------------------------------------------------------------+\
+\n| plan_type | plan
|\
+\n+---------------+------------------------------------------------------------------------------------------------------------------------+\
+\n| logical_plan | LeftSemi Join: CAST(t1.a AS Int64) =
__correlated_sq_1.count(*)
|\
+\n| | TableScan: t1 projection=[a, b]
|\
+\n| | SubqueryAlias: __correlated_sq_1
|\
+\n| | Projection: count(Int64(1)) AS count(*)
|\
+\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
|\
+\n| | TableScan: t2 projection=[]
|\
+\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192
|\
+\n| | HashJoinExec: mode=Partitioned, join_type=RightSemi,
on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] |\
+\n| | ProjectionExec: expr=[4 as count(*)]
|\
+\n| | PlaceholderRowExec
|\
+\n| | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS
Int64) as CAST(t1.a AS Int64)] |\
+\n| | DataSourceExec: partitions=1, partition_sizes=[1]
|\
+\n| |
|\
+\n+---------------+------------------------------------------------------------------------------------------------------------------------+";
+
+ assert_eq!(
+ expected_sql_result,
+ pretty_format_batches(&sql_results)?.to_string()
+ );
+
// In the same SessionContext, AliasGenerator will increase subquery_alias id by 1
// https://github.com/apache/datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43
// for compare difference between sql and df logical plan, we need to create a new SessionContext here
@@ -2509,9 +2574,26 @@ async fn test_count_wildcard_on_where_in() -> Result<()> {
.collect()
.await?;
+ let actual_df_result=
"+---------------+------------------------------------------------------------------------------------------------------------------------+\
+\n| plan_type | plan
|\
+\n+---------------+------------------------------------------------------------------------------------------------------------------------+\
+\n| logical_plan | LeftSemi Join: CAST(t1.a AS Int64) =
__correlated_sq_1.count(*)
|\
+\n| | TableScan: t1 projection=[a, b]
|\
+\n| | SubqueryAlias: __correlated_sq_1
|\
+\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS
count(*)]] |\
+\n| | TableScan: t2 projection=[]
|\
+\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192
|\
+\n| | HashJoinExec: mode=Partitioned, join_type=RightSemi,
on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] |\
+\n| | ProjectionExec: expr=[4 as count(*)]
|\
+\n| | PlaceholderRowExec
|\
+\n| | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS
Int64) as CAST(t1.a AS Int64)] |\
+\n| | DataSourceExec: partitions=1, partition_sizes=[1]
|\
+\n| |
|\
+\n+---------------+------------------------------------------------------------------------------------------------------------------------+";
+
// make sure sql plan same with df plan
assert_eq!(
- pretty_format_batches(&sql_results)?.to_string(),
+ actual_df_result,
pretty_format_batches(&df_results)?.to_string()
);
@@ -2522,11 +2604,34 @@ async fn test_count_wildcard_on_where_in() -> Result<()> {
async fn test_count_wildcard_on_where_exist() -> Result<()> {
let ctx = create_join_context()?;
let sql_results = ctx
- .sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(1) FROM t2)")
+ .sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)")
.await?
.explain(false, false)?
.collect()
.await?;
+
+ let actual_sql_result =
+
"+---------------+---------------------------------------------------------+\
+ \n| plan_type | plan
|\
+
\n+---------------+---------------------------------------------------------+\
+ \n| logical_plan | LeftSemi Join:
|\
+ \n| | TableScan: t1 projection=[a, b]
|\
+ \n| | SubqueryAlias: __correlated_sq_1
|\
+ \n| | Projection:
|\
+ \n| | Aggregate: groupBy=[[]],
aggr=[[count(Int64(1))]] |\
+ \n| | TableScan: t2 projection=[]
|\
+ \n| physical_plan | NestedLoopJoinExec: join_type=RightSemi
|\
+ \n| | ProjectionExec: expr=[]
|\
+ \n| | PlaceholderRowExec
|\
+ \n| | DataSourceExec: partitions=1, partition_sizes=[1]
|\
+ \n| |
|\
+
\n+---------------+---------------------------------------------------------+";
+
+ assert_eq!(
+ actual_sql_result,
+ pretty_format_batches(&sql_results)?.to_string()
+ );
+
let df_results = ctx
.table("t1")
.await?
@@ -2545,9 +2650,24 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> {
.collect()
.await?;
- //make sure sql plan same with df plan
+ let actual_df_result =
"+---------------+---------------------------------------------------------------------+\
+ \n| plan_type | plan
|\
+
\n+---------------+---------------------------------------------------------------------+\
+ \n| logical_plan | LeftSemi Join:
|\
+ \n| | TableScan: t1 projection=[a, b]
|\
+ \n| | SubqueryAlias: __correlated_sq_1
|\
+ \n| | Projection:
|\
+ \n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))
AS count(*)]] |\
+ \n| | TableScan: t2 projection=[]
|\
+ \n| physical_plan | NestedLoopJoinExec: join_type=RightSemi
|\
+ \n| | ProjectionExec: expr=[]
|\
+ \n| | PlaceholderRowExec
|\
+ \n| | DataSourceExec: partitions=1, partition_sizes=[1]
|\
+ \n| |
|\
+
\n+---------------+---------------------------------------------------------------------+";
+
assert_eq!(
- pretty_format_batches(&sql_results)?.to_string(),
+ actual_df_result,
pretty_format_batches(&df_results)?.to_string()
);
@@ -2559,34 +2679,62 @@ async fn test_count_wildcard_on_window() -> Result<()> {
let ctx = create_join_context()?;
let sql_results = ctx
- .sql("select count(1) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1")
+ .sql("select count(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1")
.await?
.explain(false, false)?
.collect()
.await?;
+
+ let actual_sql_result =
"+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
[...]
+\n| plan_type | plan
|\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\
+\n| logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS
FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [t1.a
DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING
|\
+\n| | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a
DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]]
|\
+\n| | TableScan: t1 projection=[a]
|\
+\n| physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC
NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(*) ORDER BY
[t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]
|\
+\n| | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY
[t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field {
name: \"count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6
PRECEDING AND 2 FOLLOWING\", data_type: Int64, nullable: false, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range,
start_bound: Preceding(UInt32(6)), end_bound: Following(UInt32(2)), is_causal:
false }], mode=[Sorted] |\
+\n| | SortExec: expr=[a@0 DESC],
preserve_partitioning=[false]
|\
+\n| | DataSourceExec: partitions=1, partition_sizes=[1]
|\
+\n| |
|\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+";
+
+ assert_eq!(
+ actual_sql_result,
+ pretty_format_batches(&sql_results)?.to_string()
+ );
+
let df_results = ctx
.table("t1")
.await?
- .select(vec![Expr::WindowFunction(WindowFunction::new(
- WindowFunctionDefinition::AggregateUDF(count_udaf()),
- vec![Expr::Literal(COUNT_STAR_EXPANSION)],
- ))
- .order_by(vec![Sort::new(col("a"), false, true)])
- .window_frame(WindowFrame::new_bounds(
- WindowFrameUnits::Range,
- WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
- WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
- ))
- .build()
- .unwrap()])?
+ .select(vec![count_all_window()
+ .order_by(vec![Sort::new(col("a"), false, true)])
+ .window_frame(WindowFrame::new_bounds(
+ WindowFrameUnits::Range,
+ WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))),
+ WindowFrameBound::Following(ScalarValue::UInt32(Some(2))),
+ ))
+ .build()
+ .unwrap()])?
.explain(false, false)?
.collect()
.await?;
- //make sure sql plan same with df plan
+ let actual_df_result =
"+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
[...]
+\n| plan_type | plan
|\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\
+\n| logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS
FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING
|\
+\n| | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a
DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]]
|\
+\n| | TableScan: t1 projection=[a]
|\
+\n| physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC
NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(Int64(1))
ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]
|\
+\n| | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY
[t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field {
name: \"count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6
PRECEDING AND 2 FOLLOWING\", data_type: Int64, nullable: false, dict_id: 0,
dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range,
start_bound: Preceding(UInt32(6)), end_bound: Following(UInt32(2)), is_causal:
false }], mode=[Sorted] |\
+\n| | SortExec: expr=[a@0 DESC],
preserve_partitioning=[false]
|\
+\n| | DataSourceExec: partitions=1, partition_sizes=[1]
|\
+\n| |
|\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+";
+
assert_eq!(
- pretty_format_batches(&df_results)?.to_string(),
- pretty_format_batches(&sql_results)?.to_string()
+ actual_df_result,
+ pretty_format_batches(&df_results)?.to_string()
);
Ok(())
@@ -2598,12 +2746,28 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> {
register_alltypes_tiny_pages_parquet(&ctx).await?;
let sql_results = ctx
- .sql("select count(1) from t1")
+ .sql("select count(*) from t1")
.await?
.explain(false, false)?
.collect()
.await?;
+ let actual_sql_result =
+
"+---------------+-----------------------------------------------------+\
+\n| plan_type | plan |\
+\n+---------------+-----------------------------------------------------+\
+\n| logical_plan | Projection: count(Int64(1)) AS count(*) |\
+\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] |\
+\n| | TableScan: t1 projection=[] |\
+\n| physical_plan | ProjectionExec: expr=[4 as count(*)] |\
+\n| | PlaceholderRowExec |\
+\n| | |\
+\n+---------------+-----------------------------------------------------+";
+ assert_eq!(
+ actual_sql_result,
+ pretty_format_batches(&sql_results)?.to_string()
+ );
+
// add `.select(vec![count_wildcard()])?` to make sure we can analyze all node instead of just top node.
let df_results = ctx
.table("t1")
@@ -2614,9 +2778,17 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> {
.collect()
.await?;
- //make sure sql plan same with df plan
+ let actual_df_result =
"+---------------+---------------------------------------------------------------+\
+\n| plan_type | plan
|\
+\n+---------------+---------------------------------------------------------------+\
+\n| logical_plan | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS
count(*)]] |\
+\n| | TableScan: t1 projection=[]
|\
+\n| physical_plan | ProjectionExec: expr=[4 as count(*)]
|\
+\n| | PlaceholderRowExec
|\
+\n| |
|\
+\n+---------------+---------------------------------------------------------------+";
assert_eq!(
- pretty_format_batches(&sql_results)?.to_string(),
+ actual_df_result,
pretty_format_batches(&df_results)?.to_string()
);
@@ -2628,16 +2800,51 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
let ctx = create_join_context()?;
let sql_results = ctx
- .sql("select a,b from t1 where (select count(1) from t2 where t1.a = t2.a)>0;")
+ .sql("select a,b from t1 where (select count(*) from t2 where t1.a = t2.a)>0;")
.await?
.explain(false, false)?
.collect()
.await?;
+ let actual_sql_result =
"+---------------+---------------------------------------------------------------------------------------------------------------------------+\
+\n| plan_type | plan
|\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------+\
+\n| logical_plan | Projection: t1.a, t1.b
|\
+\n| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL
THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |\
+\n| | Projection: t1.a, t1.b, __scalar_sq_1.count(*),
__scalar_sq_1.__always_true |\
+\n| | Left Join: t1.a = __scalar_sq_1.a
|\
+\n| | TableScan: t1 projection=[a, b]
|\
+\n| | SubqueryAlias: __scalar_sq_1
|\
+\n| | Projection: count(Int64(1)) AS count(*), t2.a,
Boolean(true) AS __always_true |\
+\n| | Aggregate: groupBy=[[t2.a]],
aggr=[[count(Int64(1))]]
|\
+\n| | TableScan: t2 projection=[a]
|\
+\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192
|\
+\n| | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0
ELSE count(*)@2 END > 0, projection=[a@0, b@1] |\
+\n| | CoalesceBatchesExec: target_batch_size=8192
|\
+\n| | HashJoinExec: mode=Partitioned, join_type=Left,
on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] |\
+\n| | CoalesceBatchesExec: target_batch_size=8192
|\
+\n| | RepartitionExec: partitioning=Hash([a@0], 4),
input_partitions=1 |\
+\n| | DataSourceExec: partitions=1,
partition_sizes=[1]
|\
+\n| | ProjectionExec: expr=[count(Int64(1))@1 as
count(*), a@0 as a, true as __always_true] |\
+\n| | AggregateExec: mode=FinalPartitioned, gby=[a@0
as a], aggr=[count(Int64(1))] |\
+\n| | CoalesceBatchesExec: target_batch_size=8192
|\
+\n| | RepartitionExec: partitioning=Hash([a@0],
4), input_partitions=4 |\
+\n| | RepartitionExec:
partitioning=RoundRobinBatch(4), input_partitions=1
|\
+\n| | AggregateExec: mode=Partial, gby=[a@0 as
a], aggr=[count(Int64(1))] |\
+\n| | DataSourceExec: partitions=1,
partition_sizes=[1] |\
+\n| |
|\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------+";
+ assert_eq!(
+ actual_sql_result,
+ pretty_format_batches(&sql_results)?.to_string()
+ );
+
// In the same SessionContext, AliasGenerator will increase subquery_alias id by 1
// https://github.com/apache/datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43
// for compare difference between sql and df logical plan, we need to create a new SessionContext here
let ctx = create_join_context()?;
+ let agg_expr = count_all();
+ let agg_expr_col = col(agg_expr.schema_name().to_string());
let df_results = ctx
.table("t1")
.await?
@@ -2646,8 +2853,8 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
ctx.table("t2")
.await?
.filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))?
- .aggregate(vec![], vec![count_all()])?
- .select(vec![col(count_all().to_string())])?
+ .aggregate(vec![], vec![agg_expr])?
+ .select(vec![agg_expr_col])?
.into_unoptimized_plan(),
))
.gt(lit(ScalarValue::UInt8(Some(0)))),
@@ -2657,9 +2864,36 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
.collect()
.await?;
- //make sure sql plan same with df plan
+ let actual_df_result =
"+---------------+---------------------------------------------------------------------------------------------------------------------------+\
+\n| plan_type | plan
|\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------+\
+\n| logical_plan | Projection: t1.a, t1.b
|\
+\n| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL
THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |\
+\n| | Projection: t1.a, t1.b, __scalar_sq_1.count(*),
__scalar_sq_1.__always_true |\
+\n| | Left Join: t1.a = __scalar_sq_1.a
|\
+\n| | TableScan: t1 projection=[a, b]
|\
+\n| | SubqueryAlias: __scalar_sq_1
|\
+\n| | Projection: count(*), t2.a, Boolean(true) AS
__always_true |\
+\n| | Aggregate: groupBy=[[t2.a]],
aggr=[[count(Int64(1)) AS count(*)]]
|\
+\n| | TableScan: t2 projection=[a]
|\
+\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192
|\
+\n| | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0
ELSE count(*)@2 END > 0, projection=[a@0, b@1] |\
+\n| | CoalesceBatchesExec: target_batch_size=8192
|\
+\n| | HashJoinExec: mode=Partitioned, join_type=Left,
on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] |\
+\n| | CoalesceBatchesExec: target_batch_size=8192
|\
+\n| | RepartitionExec: partitioning=Hash([a@0], 4),
input_partitions=1 |\
+\n| | DataSourceExec: partitions=1,
partition_sizes=[1]
|\
+\n| | ProjectionExec: expr=[count(*)@1 as count(*), a@0
as a, true as __always_true] |\
+\n| | AggregateExec: mode=FinalPartitioned, gby=[a@0
as a], aggr=[count(*)] |\
+\n| | CoalesceBatchesExec: target_batch_size=8192
|\
+\n| | RepartitionExec: partitioning=Hash([a@0],
4), input_partitions=4 |\
+\n| | RepartitionExec:
partitioning=RoundRobinBatch(4), input_partitions=1
|\
+\n| | AggregateExec: mode=Partial, gby=[a@0 as
a], aggr=[count(*)] |\
+\n| | DataSourceExec: partitions=1,
partition_sizes=[1] |\
+\n| |
|\
+\n+---------------+---------------------------------------------------------------------------------------------------------------------------+";
assert_eq!(
- pretty_format_batches(&sql_results)?.to_string(),
+ actual_df_result,
pretty_format_batches(&df_results)?.to_string()
);
@@ -4228,7 +4462,9 @@ fn create_join_context() -> Result<SessionContext> {
],
)?;
- let ctx = SessionContext::new();
+ let config = SessionConfig::new().with_target_partitions(4);
+ let ctx = SessionContext::new_with_config(config);
+ // let ctx = SessionContext::new();
ctx.register_batch("t1", batch1)?;
ctx.register_batch("t2", batch2)?;
diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs
index a3339f0fce..2d995b4a41 100644
--- a/datafusion/functions-aggregate/src/count.rs
+++ b/datafusion/functions-aggregate/src/count.rs
@@ -17,6 +17,7 @@
use ahash::RandomState;
use datafusion_common::stats::Precision;
+use datafusion_expr::expr::WindowFunction;
use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
use datafusion_macros::user_doc;
use datafusion_physical_expr::expressions;
@@ -51,7 +52,9 @@ use datafusion_expr::{
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility,
};
-use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature};
+use datafusion_expr::{
+ Expr, ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition,
+};
use datafusion_functions_aggregate_common::aggregate::count_distinct::{
BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
PrimitiveDistinctCountAccumulator,
@@ -79,9 +82,51 @@ pub fn count_distinct(expr: Expr) -> Expr {
))
}
-/// Creates aggregation to count all rows, equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
+/// Creates aggregation to count all rows.
+///
+/// In SQL this is `SELECT COUNT(*) ... `
+///
+/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`, and is
+/// aliased to a column named `"count(*)"` for backward compatibility.
+///
+/// Example
+/// ```
+/// # use datafusion_functions_aggregate::count::count_all;
+/// # use datafusion_expr::col;
+/// // create `count(*)` expression
+/// let expr = count_all();
+/// assert_eq!(expr.schema_name().to_string(), "count(*)");
+/// // if you need to refer to this column, use the `schema_name` function
+/// let expr = col(expr.schema_name().to_string());
+/// ```
pub fn count_all() -> Expr {
- count(Expr::Literal(COUNT_STAR_EXPANSION))
+ count(Expr::Literal(COUNT_STAR_EXPANSION)).alias("count(*)")
+}
+
+/// Creates window aggregation to count all rows.
+///
+/// In SQL this is `SELECT COUNT(*) OVER (..) ... `
+///
+/// The expression is equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)`
+///
+/// Example
+/// ```
+/// # use datafusion_functions_aggregate::count::count_all_window;
+/// # use datafusion_expr::col;
+/// // create `count(*)` OVER ... window function expression
+/// let expr = count_all_window();
+/// assert_eq!(
+/// expr.schema_name().to_string(),
+/// "count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING"
+/// );
+/// // if you need to refer to this column, use the `schema_name` function
+/// let expr = col(expr.schema_name().to_string());
+/// ```
+pub fn count_all_window() -> Expr {
+ Expr::WindowFunction(WindowFunction::new(
+ WindowFunctionDefinition::AggregateUDF(count_udaf()),
+ vec![Expr::Literal(COUNT_STAR_EXPANSION)],
+ ))
}
#[user_doc(
diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt
index 94c9eaf810..207bb72fd5 100644
--- a/datafusion/sqllogictest/test_files/subquery.slt
+++ b/datafusion/sqllogictest/test_files/subquery.slt
@@ -1393,3 +1393,41 @@ item1 1970-01-01T00:00:03 75
statement ok
drop table source_table;
+
+statement count 0
+drop table t1;
+
+statement count 0
+drop table t2;
+
+statement count 0
+drop table t3;
+
+# test count wildcard
+statement count 0
+create table t1(a int) as values (1);
+
+statement count 0
+create table t2(b int) as values (1);
+
+query I
+SELECT a FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)
+----
+1
+
+query TT
+explain SELECT a FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)
+----
+logical_plan
+01)LeftSemi Join:
+02)--TableScan: t1 projection=[a]
+03)--SubqueryAlias: __correlated_sq_1
+04)----Projection:
+05)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]
+06)--------TableScan: t2 projection=[]
+
+statement count 0
+drop table t1;
+
+statement count 0
+drop table t2;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]