This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new d594e6257b Relax join keys constraint from Column to any physical
expression for physical join operators (#8991)
d594e6257b is described below
commit d594e6257b34a5ad47112e26d41516aaeb19e6dd
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Jan 29 10:02:22 2024 -0800
Relax join keys constraint from Column to any physical expression for
physical join operators (#8991)
* Relax SortMergeJoin join keys
* More
* More
* More
* More
* Fix clippy
* Fix more clippy
* More
* More
* Fix
* Fix
* Use collect_columns
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
.../src/physical_optimizer/enforce_distribution.rs | 291 ++++++++++++---------
.../core/src/physical_optimizer/enforce_sorting.rs | 19 +-
.../core/src/physical_optimizer/join_selection.rs | 79 +++---
.../src/physical_optimizer/projection_pushdown.rs | 49 +++-
.../replace_with_order_preserving_variants.rs | 2 +-
datafusion/core/src/physical_planner.rs | 18 +-
datafusion/core/tests/fuzz_cases/join_fuzz.rs | 8 +-
datafusion/physical-expr/src/equivalence/class.rs | 26 +-
.../physical-expr/src/equivalence/properties.rs | 7 +-
datafusion/physical-plan/src/joins/hash_join.rs | 197 +++++++-------
.../physical-plan/src/joins/sort_merge_join.rs | 152 +++++------
.../physical-plan/src/joins/symmetric_hash_join.rs | 76 +++---
datafusion/physical-plan/src/joins/test_utils.rs | 12 +-
datafusion/physical-plan/src/joins/utils.rs | 177 +++++++++----
datafusion/proto/proto/datafusion.proto | 4 +-
datafusion/proto/src/generated/prost.rs | 4 +-
datafusion/proto/src/physical_plan/mod.rs | 73 ++++--
.../proto/tests/cases/roundtrip_physical_plan.rs | 8 +-
18 files changed, 691 insertions(+), 511 deletions(-)
diff --git a/datafusion/core/src/physical_optimizer/enforce_distribution.rs
b/datafusion/core/src/physical_optimizer/enforce_distribution.rs
index 0c5c2d78b6..fab26c49c2 100644
--- a/datafusion/core/src/physical_optimizer/enforce_distribution.rs
+++ b/datafusion/core/src/physical_optimizer/enforce_distribution.rs
@@ -51,7 +51,7 @@ use datafusion_physical_expr::expressions::{Column, NoOp};
use datafusion_physical_expr::utils::map_columns_before_projection;
use datafusion_physical_expr::{
physical_exprs_equal, EquivalenceProperties, LexRequirementRef,
PhysicalExpr,
- PhysicalSortRequirement,
+ PhysicalExprRef, PhysicalSortRequirement,
};
use datafusion_physical_plan::sorts::sort::SortExec;
use datafusion_physical_plan::unbounded_output;
@@ -285,19 +285,21 @@ fn adjust_input_keys_ordering(
{
match mode {
PartitionMode::Partitioned => {
- let join_constructor =
- |new_conditions: (Vec<(Column, Column)>,
Vec<SortOptions>)| {
- HashJoinExec::try_new(
- left.clone(),
- right.clone(),
- new_conditions.0,
- filter.clone(),
- join_type,
- PartitionMode::Partitioned,
- *null_equals_null,
- )
- .map(|e| Arc::new(e) as _)
- };
+ let join_constructor = |new_conditions: (
+ Vec<(PhysicalExprRef, PhysicalExprRef)>,
+ Vec<SortOptions>,
+ )| {
+ HashJoinExec::try_new(
+ left.clone(),
+ right.clone(),
+ new_conditions.0,
+ filter.clone(),
+ join_type,
+ PartitionMode::Partitioned,
+ *null_equals_null,
+ )
+ .map(|e| Arc::new(e) as _)
+ };
return reorder_partitioned_join_keys(
requirements,
on,
@@ -346,18 +348,20 @@ fn adjust_input_keys_ordering(
..
}) = plan.as_any().downcast_ref::<SortMergeJoinExec>()
{
- let join_constructor =
- |new_conditions: (Vec<(Column, Column)>, Vec<SortOptions>)| {
- SortMergeJoinExec::try_new(
- left.clone(),
- right.clone(),
- new_conditions.0,
- *join_type,
- new_conditions.1,
- *null_equals_null,
- )
- .map(|e| Arc::new(e) as _)
- };
+ let join_constructor = |new_conditions: (
+ Vec<(PhysicalExprRef, PhysicalExprRef)>,
+ Vec<SortOptions>,
+ )| {
+ SortMergeJoinExec::try_new(
+ left.clone(),
+ right.clone(),
+ new_conditions.0,
+ *join_type,
+ new_conditions.1,
+ *null_equals_null,
+ )
+ .map(|e| Arc::new(e) as _)
+ };
return reorder_partitioned_join_keys(
requirements,
on,
@@ -408,12 +412,14 @@ fn adjust_input_keys_ordering(
fn reorder_partitioned_join_keys<F>(
mut join_plan: PlanWithKeyRequirements,
- on: &[(Column, Column)],
+ on: &[(PhysicalExprRef, PhysicalExprRef)],
sort_options: Vec<SortOptions>,
join_constructor: &F,
) -> Result<PlanWithKeyRequirements>
where
- F: Fn((Vec<(Column, Column)>, Vec<SortOptions>)) -> Result<Arc<dyn
ExecutionPlan>>,
+ F: Fn(
+ (Vec<(PhysicalExprRef, PhysicalExprRef)>, Vec<SortOptions>),
+ ) -> Result<Arc<dyn ExecutionPlan>>,
{
let parent_required = &join_plan.data;
let join_key_pairs = extract_join_keys(on);
@@ -788,10 +794,10 @@ fn expected_expr_positions(
Some(indexes)
}
-fn extract_join_keys(on: &[(Column, Column)]) -> JoinKeyPairs {
+fn extract_join_keys(on: &[(PhysicalExprRef, PhysicalExprRef)]) ->
JoinKeyPairs {
let (left_keys, right_keys) = on
.iter()
- .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _))
+ .map(|(l, r)| (l.clone() as _, r.clone() as _))
.unzip();
JoinKeyPairs {
left_keys,
@@ -802,16 +808,11 @@ fn extract_join_keys(on: &[(Column, Column)]) ->
JoinKeyPairs {
fn new_join_conditions(
new_left_keys: &[Arc<dyn PhysicalExpr>],
new_right_keys: &[Arc<dyn PhysicalExpr>],
-) -> Vec<(Column, Column)> {
+) -> Vec<(PhysicalExprRef, PhysicalExprRef)> {
new_left_keys
.iter()
.zip(new_right_keys.iter())
- .map(|(l_key, r_key)| {
- (
- l_key.as_any().downcast_ref::<Column>().unwrap().clone(),
- r_key.as_any().downcast_ref::<Column>().unwrap().clone(),
- )
- })
+ .map(|(l_key, r_key)| (l_key.clone(), r_key.clone()))
.collect()
}
@@ -1886,8 +1887,8 @@ pub(crate) mod tests {
// Join on (a == b1)
let join_on = vec![(
- Column::new_with_schema("a", &schema()).unwrap(),
- Column::new_with_schema("b1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap())
as _,
)];
for join_type in join_types {
@@ -1905,8 +1906,9 @@ pub(crate) mod tests {
| JoinType::LeftAnti => {
// Join on (a == c)
let top_join_on = vec![(
- Column::new_with_schema("a", &join.schema()).unwrap(),
- Column::new_with_schema("c", &schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a",
&join.schema()).unwrap())
+ as _,
+ Arc::new(Column::new_with_schema("c",
&schema()).unwrap()) as _,
)];
let top_join = hash_join_exec(
join.clone(),
@@ -1966,8 +1968,9 @@ pub(crate) mod tests {
// This time we use (b1 == c) for top join
// Join on (b1 == c)
let top_join_on = vec![(
- Column::new_with_schema("b1", &join.schema()).unwrap(),
- Column::new_with_schema("c", &schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b1",
&join.schema()).unwrap())
+ as _,
+ Arc::new(Column::new_with_schema("c",
&schema()).unwrap()) as _,
)];
let top_join =
@@ -2031,8 +2034,8 @@ pub(crate) mod tests {
// Join on (a == b)
let join_on = vec![(
- Column::new_with_schema("a", &schema()).unwrap(),
- Column::new_with_schema("b", &schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as _,
)];
let join = hash_join_exec(left, right.clone(), &join_on,
&JoinType::Inner);
@@ -2045,8 +2048,8 @@ pub(crate) mod tests {
// Join on (a1 == c)
let top_join_on = vec![(
- Column::new_with_schema("a1", &projection.schema()).unwrap(),
- Column::new_with_schema("c", &schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a1",
&projection.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _,
)];
let top_join = hash_join_exec(
@@ -2076,8 +2079,8 @@ pub(crate) mod tests {
// Join on (a2 == c)
let top_join_on = vec![(
- Column::new_with_schema("a2", &projection.schema()).unwrap(),
- Column::new_with_schema("c", &schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a2",
&projection.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _,
)];
let top_join = hash_join_exec(projection, right, &top_join_on,
&JoinType::Inner);
@@ -2110,8 +2113,8 @@ pub(crate) mod tests {
// Join on (a == b)
let join_on = vec![(
- Column::new_with_schema("a", &schema()).unwrap(),
- Column::new_with_schema("b", &schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as _,
)];
let join = hash_join_exec(left, right.clone(), &join_on,
&JoinType::Inner);
@@ -2128,8 +2131,8 @@ pub(crate) mod tests {
// Join on (a == c)
let top_join_on = vec![(
- Column::new_with_schema("a", &projection2.schema()).unwrap(),
- Column::new_with_schema("c", &schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a",
&projection2.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as _,
)];
let top_join = hash_join_exec(projection2, right, &top_join_on,
&JoinType::Inner);
@@ -2174,8 +2177,8 @@ pub(crate) mod tests {
// Join on (a1 == a2)
let join_on = vec![(
- Column::new_with_schema("a1", &left.schema()).unwrap(),
- Column::new_with_schema("a2", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a1", &left.schema()).unwrap())
as _,
+ Arc::new(Column::new_with_schema("a2", &right.schema()).unwrap())
as _,
)];
let join = hash_join_exec(left, right.clone(), &join_on,
&JoinType::Inner);
@@ -2221,12 +2224,12 @@ pub(crate) mod tests {
// Join on (b1 == b && a1 == a)
let join_on = vec![
(
- Column::new_with_schema("b1", &left.schema()).unwrap(),
- Column::new_with_schema("b", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b1",
&left.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("b",
&right.schema()).unwrap()) as _,
),
(
- Column::new_with_schema("a1", &left.schema()).unwrap(),
- Column::new_with_schema("a", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a1",
&left.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("a",
&right.schema()).unwrap()) as _,
),
];
let join = hash_join_exec(left, right.clone(), &join_on,
&JoinType::Inner);
@@ -2265,16 +2268,16 @@ pub(crate) mod tests {
// Join on (a == a1 and b == b1 and c == c1)
let join_on = vec![
(
- Column::new_with_schema("a", &schema()).unwrap(),
- Column::new_with_schema("a1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("a1",
&right.schema()).unwrap()) as _,
),
(
- Column::new_with_schema("b", &schema()).unwrap(),
- Column::new_with_schema("b1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("b1",
&right.schema()).unwrap()) as _,
),
(
- Column::new_with_schema("c", &schema()).unwrap(),
- Column::new_with_schema("c1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("c1",
&right.schema()).unwrap()) as _,
),
];
let bottom_left_join =
@@ -2293,16 +2296,16 @@ pub(crate) mod tests {
// Join on (c == c1 and b == b1 and a == a1)
let join_on = vec![
(
- Column::new_with_schema("c", &schema()).unwrap(),
- Column::new_with_schema("c1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("c1",
&right.schema()).unwrap()) as _,
),
(
- Column::new_with_schema("b", &schema()).unwrap(),
- Column::new_with_schema("b1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("b1",
&right.schema()).unwrap()) as _,
),
(
- Column::new_with_schema("a", &schema()).unwrap(),
- Column::new_with_schema("a1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("a1",
&right.schema()).unwrap()) as _,
),
];
let bottom_right_join =
@@ -2311,16 +2314,31 @@ pub(crate) mod tests {
// Join on (B == b1 and C == c and AA = a1)
let top_join_on = vec![
(
- Column::new_with_schema("B",
&bottom_left_projection.schema()).unwrap(),
- Column::new_with_schema("b1",
&bottom_right_join.schema()).unwrap(),
+ Arc::new(
+ Column::new_with_schema("B",
&bottom_left_projection.schema())
+ .unwrap(),
+ ) as _,
+ Arc::new(
+ Column::new_with_schema("b1",
&bottom_right_join.schema()).unwrap(),
+ ) as _,
),
(
- Column::new_with_schema("C",
&bottom_left_projection.schema()).unwrap(),
- Column::new_with_schema("c",
&bottom_right_join.schema()).unwrap(),
+ Arc::new(
+ Column::new_with_schema("C",
&bottom_left_projection.schema())
+ .unwrap(),
+ ) as _,
+ Arc::new(
+ Column::new_with_schema("c",
&bottom_right_join.schema()).unwrap(),
+ ) as _,
),
(
- Column::new_with_schema("AA",
&bottom_left_projection.schema()).unwrap(),
- Column::new_with_schema("a1",
&bottom_right_join.schema()).unwrap(),
+ Arc::new(
+ Column::new_with_schema("AA",
&bottom_left_projection.schema())
+ .unwrap(),
+ ) as _,
+ Arc::new(
+ Column::new_with_schema("a1",
&bottom_right_join.schema()).unwrap(),
+ ) as _,
),
];
@@ -2382,16 +2400,16 @@ pub(crate) mod tests {
// Join on (a == a1 and b == b1 and c == c1)
let join_on = vec![
(
- Column::new_with_schema("a", &schema()).unwrap(),
- Column::new_with_schema("a1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("a1",
&right.schema()).unwrap()) as _,
),
(
- Column::new_with_schema("b", &schema()).unwrap(),
- Column::new_with_schema("b1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("b1",
&right.schema()).unwrap()) as _,
),
(
- Column::new_with_schema("c", &schema()).unwrap(),
- Column::new_with_schema("c1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("c1",
&right.schema()).unwrap()) as _,
),
];
@@ -2414,16 +2432,16 @@ pub(crate) mod tests {
// Join on (c == c1 and b == b1 and a == a1)
let join_on = vec![
(
- Column::new_with_schema("c", &schema()).unwrap(),
- Column::new_with_schema("c1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("c1",
&right.schema()).unwrap()) as _,
),
(
- Column::new_with_schema("b", &schema()).unwrap(),
- Column::new_with_schema("b1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("b1",
&right.schema()).unwrap()) as _,
),
(
- Column::new_with_schema("a", &schema()).unwrap(),
- Column::new_with_schema("a1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("a1",
&right.schema()).unwrap()) as _,
),
];
let bottom_right_join = ensure_distribution_helper(
@@ -2435,16 +2453,31 @@ pub(crate) mod tests {
// Join on (B == b1 and C == c and AA = a1)
let top_join_on = vec![
(
- Column::new_with_schema("B",
&bottom_left_projection.schema()).unwrap(),
- Column::new_with_schema("b1",
&bottom_right_join.schema()).unwrap(),
+ Arc::new(
+ Column::new_with_schema("B",
&bottom_left_projection.schema())
+ .unwrap(),
+ ) as _,
+ Arc::new(
+ Column::new_with_schema("b1",
&bottom_right_join.schema()).unwrap(),
+ ) as _,
),
(
- Column::new_with_schema("C",
&bottom_left_projection.schema()).unwrap(),
- Column::new_with_schema("c",
&bottom_right_join.schema()).unwrap(),
+ Arc::new(
+ Column::new_with_schema("C",
&bottom_left_projection.schema())
+ .unwrap(),
+ ) as _,
+ Arc::new(
+ Column::new_with_schema("c",
&bottom_right_join.schema()).unwrap(),
+ ) as _,
),
(
- Column::new_with_schema("AA",
&bottom_left_projection.schema()).unwrap(),
- Column::new_with_schema("a1",
&bottom_right_join.schema()).unwrap(),
+ Arc::new(
+ Column::new_with_schema("AA",
&bottom_left_projection.schema())
+ .unwrap(),
+ ) as _,
+ Arc::new(
+ Column::new_with_schema("a1",
&bottom_right_join.schema()).unwrap(),
+ ) as _,
),
];
@@ -2512,12 +2545,12 @@ pub(crate) mod tests {
// Join on (a == a1 and b == b1)
let join_on = vec![
(
- Column::new_with_schema("a", &schema()).unwrap(),
- Column::new_with_schema("a1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("a1",
&right.schema()).unwrap()) as _,
),
(
- Column::new_with_schema("b", &schema()).unwrap(),
- Column::new_with_schema("b1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("b1",
&right.schema()).unwrap()) as _,
),
];
let bottom_left_join = ensure_distribution_helper(
@@ -2539,16 +2572,16 @@ pub(crate) mod tests {
// Join on (c == c1 and b == b1 and a == a1)
let join_on = vec![
(
- Column::new_with_schema("c", &schema()).unwrap(),
- Column::new_with_schema("c1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("c1",
&right.schema()).unwrap()) as _,
),
(
- Column::new_with_schema("b", &schema()).unwrap(),
- Column::new_with_schema("b1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("b1",
&right.schema()).unwrap()) as _,
),
(
- Column::new_with_schema("a", &schema()).unwrap(),
- Column::new_with_schema("a1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("a1",
&right.schema()).unwrap()) as _,
),
];
let bottom_right_join = ensure_distribution_helper(
@@ -2560,16 +2593,31 @@ pub(crate) mod tests {
// Join on (B == b1 and C == c and AA = a1)
let top_join_on = vec![
(
- Column::new_with_schema("B",
&bottom_left_projection.schema()).unwrap(),
- Column::new_with_schema("b1",
&bottom_right_join.schema()).unwrap(),
+ Arc::new(
+ Column::new_with_schema("B",
&bottom_left_projection.schema())
+ .unwrap(),
+ ) as _,
+ Arc::new(
+ Column::new_with_schema("b1",
&bottom_right_join.schema()).unwrap(),
+ ) as _,
),
(
- Column::new_with_schema("C",
&bottom_left_projection.schema()).unwrap(),
- Column::new_with_schema("c",
&bottom_right_join.schema()).unwrap(),
+ Arc::new(
+ Column::new_with_schema("C",
&bottom_left_projection.schema())
+ .unwrap(),
+ ) as _,
+ Arc::new(
+ Column::new_with_schema("c",
&bottom_right_join.schema()).unwrap(),
+ ) as _,
),
(
- Column::new_with_schema("AA",
&bottom_left_projection.schema()).unwrap(),
- Column::new_with_schema("a1",
&bottom_right_join.schema()).unwrap(),
+ Arc::new(
+ Column::new_with_schema("AA",
&bottom_left_projection.schema())
+ .unwrap(),
+ ) as _,
+ Arc::new(
+ Column::new_with_schema("a1",
&bottom_right_join.schema()).unwrap(),
+ ) as _,
),
];
@@ -2648,8 +2696,8 @@ pub(crate) mod tests {
// Join on (a == b1)
let join_on = vec![(
- Column::new_with_schema("a", &schema()).unwrap(),
- Column::new_with_schema("b1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a", &schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap())
as _,
)];
for join_type in join_types {
@@ -2660,8 +2708,8 @@ pub(crate) mod tests {
// Top join on (a == c)
let top_join_on = vec![(
- Column::new_with_schema("a", &join.schema()).unwrap(),
- Column::new_with_schema("c", &schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a",
&join.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("c", &schema()).unwrap()) as
_,
)];
let top_join = sort_merge_join_exec(
join.clone(),
@@ -2783,8 +2831,9 @@ pub(crate) mod tests {
// This time we use (b1 == c) for top join
// Join on (b1 == c)
let top_join_on = vec![(
- Column::new_with_schema("b1", &join.schema()).unwrap(),
- Column::new_with_schema("c", &schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b1",
&join.schema()).unwrap())
+ as _,
+ Arc::new(Column::new_with_schema("c",
&schema()).unwrap()) as _,
)];
let top_join = sort_merge_join_exec(
join,
@@ -2933,12 +2982,12 @@ pub(crate) mod tests {
// Join on (b3 == b2 && a3 == a2)
let join_on = vec![
(
- Column::new_with_schema("b3", &left.schema()).unwrap(),
- Column::new_with_schema("b2", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b3",
&left.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("b2",
&right.schema()).unwrap()) as _,
),
(
- Column::new_with_schema("a3", &left.schema()).unwrap(),
- Column::new_with_schema("a2", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a3",
&left.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("a2",
&right.schema()).unwrap()) as _,
),
];
let join = sort_merge_join_exec(left, right.clone(), &join_on,
&JoinType::Inner);
diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs
b/datafusion/core/src/physical_optimizer/enforce_sorting.rs
index 3aa9cdad18..5c46e64a22 100644
--- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs
+++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs
@@ -985,8 +985,8 @@ mod tests {
let right_input = parquet_exec_sorted(&right_schema,
parquet_sort_exprs);
let on = vec![(
- Column::new_with_schema("col_a", &left_schema)?,
- Column::new_with_schema("c", &right_schema)?,
+ Arc::new(Column::new_with_schema("col_a", &left_schema)?) as _,
+ Arc::new(Column::new_with_schema("c", &right_schema)?) as _,
)];
let join = hash_join_exec(left_input, right_input, on, None,
&JoinType::Inner)?;
let physical_plan = sort_exec(vec![sort_expr("a", &join.schema())],
join);
@@ -1639,8 +1639,9 @@ mod tests {
// Join on (nullable_col == col_a)
let join_on = vec![(
- Column::new_with_schema("nullable_col", &left.schema()).unwrap(),
- Column::new_with_schema("col_a", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("nullable_col",
&left.schema()).unwrap())
+ as _,
+ Arc::new(Column::new_with_schema("col_a",
&right.schema()).unwrap()) as _,
)];
let join_types = vec![
@@ -1711,8 +1712,9 @@ mod tests {
// Join on (nullable_col == col_a)
let join_on = vec![(
- Column::new_with_schema("nullable_col", &left.schema()).unwrap(),
- Column::new_with_schema("col_a", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("nullable_col",
&left.schema()).unwrap())
+ as _,
+ Arc::new(Column::new_with_schema("col_a",
&right.schema()).unwrap()) as _,
)];
let join_types = vec![
@@ -1785,8 +1787,9 @@ mod tests {
// Join on (nullable_col == col_a)
let join_on = vec![(
- Column::new_with_schema("nullable_col", &left.schema()).unwrap(),
- Column::new_with_schema("col_a", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("nullable_col",
&left.schema()).unwrap())
+ as _,
+ Arc::new(Column::new_with_schema("col_a",
&right.schema()).unwrap()) as _,
)];
let join = sort_merge_join_exec(left, right, &join_on,
&JoinType::Inner);
diff --git a/datafusion/core/src/physical_optimizer/join_selection.rs
b/datafusion/core/src/physical_optimizer/join_selection.rs
index 083cd5ecab..02626056f6 100644
--- a/datafusion/core/src/physical_optimizer/join_selection.rs
+++ b/datafusion/core/src/physical_optimizer/join_selection.rs
@@ -690,7 +690,7 @@ mod tests_statistical {
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{stats::Precision, JoinType, ScalarValue};
use datafusion_physical_expr::expressions::Column;
- use datafusion_physical_expr::PhysicalExpr;
+ use datafusion_physical_expr::{PhysicalExpr, PhysicalExprRef};
/// Return statistcs for empty table
fn empty_statistics() -> Statistics {
@@ -860,8 +860,10 @@ mod tests_statistical {
Arc::clone(&big),
Arc::clone(&small),
vec![(
- Column::new_with_schema("big_col", &big.schema()).unwrap(),
- Column::new_with_schema("small_col",
&small.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("big_col",
&big.schema()).unwrap()),
+ Arc::new(
+ Column::new_with_schema("small_col",
&small.schema()).unwrap(),
+ ),
)],
None,
&JoinType::Left,
@@ -914,8 +916,10 @@ mod tests_statistical {
Arc::clone(&small),
Arc::clone(&big),
vec![(
- Column::new_with_schema("small_col",
&small.schema()).unwrap(),
- Column::new_with_schema("big_col", &big.schema()).unwrap(),
+ Arc::new(
+ Column::new_with_schema("small_col",
&small.schema()).unwrap(),
+ ),
+ Arc::new(Column::new_with_schema("big_col",
&big.schema()).unwrap()),
)],
None,
&JoinType::Left,
@@ -970,8 +974,13 @@ mod tests_statistical {
Arc::clone(&big),
Arc::clone(&small),
vec![(
- Column::new_with_schema("big_col",
&big.schema()).unwrap(),
- Column::new_with_schema("small_col",
&small.schema()).unwrap(),
+ Arc::new(
+ Column::new_with_schema("big_col",
&big.schema()).unwrap(),
+ ),
+ Arc::new(
+ Column::new_with_schema("small_col",
&small.schema())
+ .unwrap(),
+ ),
)],
None,
&join_type,
@@ -1040,8 +1049,8 @@ mod tests_statistical {
Arc::clone(&big),
Arc::clone(&small),
vec![(
- Column::new_with_schema("big_col", &big.schema()).unwrap(),
- Column::new_with_schema("small_col", &small.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("big_col",
&big.schema()).unwrap()),
+ Arc::new(Column::new_with_schema("small_col",
&small.schema()).unwrap()),
)],
None,
&JoinType::Inner,
@@ -1056,8 +1065,10 @@ mod tests_statistical {
Arc::clone(&medium),
Arc::new(child_join),
vec![(
- Column::new_with_schema("medium_col",
&medium.schema()).unwrap(),
- Column::new_with_schema("small_col", &child_schema).unwrap(),
+ Arc::new(
+ Column::new_with_schema("medium_col",
&medium.schema()).unwrap(),
+ ),
+ Arc::new(Column::new_with_schema("small_col",
&child_schema).unwrap()),
)],
None,
&JoinType::Left,
@@ -1094,8 +1105,10 @@ mod tests_statistical {
Arc::clone(&small),
Arc::clone(&big),
vec![(
- Column::new_with_schema("small_col",
&small.schema()).unwrap(),
- Column::new_with_schema("big_col", &big.schema()).unwrap(),
+ Arc::new(
+ Column::new_with_schema("small_col",
&small.schema()).unwrap(),
+ ),
+ Arc::new(Column::new_with_schema("big_col",
&big.schema()).unwrap()),
)],
None,
&JoinType::Inner,
@@ -1178,8 +1191,8 @@ mod tests_statistical {
));
let join_on = vec![(
- Column::new_with_schema("small_col", &small.schema()).unwrap(),
- Column::new_with_schema("big_col", &big.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("small_col",
&small.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("big_col",
&big.schema()).unwrap()) as _,
)];
check_join_partition_mode(
small.clone(),
@@ -1190,8 +1203,8 @@ mod tests_statistical {
);
let join_on = vec![(
- Column::new_with_schema("big_col", &big.schema()).unwrap(),
- Column::new_with_schema("small_col", &small.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("big_col",
&big.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("small_col",
&small.schema()).unwrap()) as _,
)];
check_join_partition_mode(
big.clone(),
@@ -1202,8 +1215,8 @@ mod tests_statistical {
);
let join_on = vec![(
- Column::new_with_schema("small_col", &small.schema()).unwrap(),
- Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("small_col",
&small.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("empty_col",
&empty.schema()).unwrap()) as _,
)];
check_join_partition_mode(
small.clone(),
@@ -1214,8 +1227,8 @@ mod tests_statistical {
);
let join_on = vec![(
- Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
- Column::new_with_schema("small_col", &small.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("empty_col",
&empty.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("small_col",
&small.schema()).unwrap()) as _,
)];
check_join_partition_mode(
empty.clone(),
@@ -1244,8 +1257,9 @@ mod tests_statistical {
));
let join_on = vec![(
- Column::new_with_schema("big_col", &big.schema()).unwrap(),
- Column::new_with_schema("bigger_col", &bigger.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("big_col",
&big.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("bigger_col",
&bigger.schema()).unwrap())
+ as _,
)];
check_join_partition_mode(
big.clone(),
@@ -1256,8 +1270,9 @@ mod tests_statistical {
);
let join_on = vec![(
- Column::new_with_schema("bigger_col", &bigger.schema()).unwrap(),
- Column::new_with_schema("big_col", &big.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("bigger_col",
&bigger.schema()).unwrap())
+ as _,
+ Arc::new(Column::new_with_schema("big_col",
&big.schema()).unwrap()) as _,
)];
check_join_partition_mode(
bigger.clone(),
@@ -1268,8 +1283,8 @@ mod tests_statistical {
);
let join_on = vec![(
- Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
- Column::new_with_schema("big_col", &big.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("empty_col",
&empty.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("big_col",
&big.schema()).unwrap()) as _,
)];
check_join_partition_mode(
empty.clone(),
@@ -1280,8 +1295,8 @@ mod tests_statistical {
);
let join_on = vec![(
- Column::new_with_schema("big_col", &big.schema()).unwrap(),
- Column::new_with_schema("empty_col", &empty.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("big_col",
&big.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("empty_col",
&empty.schema()).unwrap()) as _,
)];
check_join_partition_mode(big, empty, join_on, false,
PartitionMode::Partitioned);
}
@@ -1289,7 +1304,7 @@ mod tests_statistical {
fn check_join_partition_mode(
left: Arc<StatisticsExec>,
right: Arc<StatisticsExec>,
- on: Vec<(Column, Column)>,
+ on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
is_swapped: bool,
expected_mode: PartitionMode,
) {
@@ -1748,8 +1763,8 @@ mod hash_join_tests {
Arc::clone(&left_exec),
Arc::clone(&right_exec),
vec![(
- Column::new_with_schema("a", &left_exec.schema())?,
- Column::new_with_schema("b", &right_exec.schema())?,
+ Arc::new(Column::new_with_schema("a", &left_exec.schema())?),
+ Arc::new(Column::new_with_schema("b", &right_exec.schema())?),
)],
None,
&t.initial_join_type,
diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs
b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
index 2d20c487e4..301a97bba4 100644
--- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs
+++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs
@@ -44,10 +44,11 @@ use crate::physical_plan::{Distribution, ExecutionPlan};
use arrow_schema::SchemaRef;
use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNode, VisitRecursion};
-use datafusion_common::JoinSide;
+use datafusion_common::{DataFusionError, JoinSide};
use datafusion_physical_expr::expressions::{Column, Literal};
use datafusion_physical_expr::{
- Partitioning, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement,
+ Partitioning, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr,
+ PhysicalSortRequirement,
};
use datafusion_physical_plan::streaming::StreamingTableExec;
use datafusion_physical_plan::union::UnionExec;
@@ -1000,8 +1001,8 @@ fn join_table_borders(
fn update_join_on(
proj_left_exprs: &[(Column, String)],
proj_right_exprs: &[(Column, String)],
- hash_join_on: &[(Column, Column)],
-) -> Option<Vec<(Column, Column)>> {
+ hash_join_on: &[(PhysicalExprRef, PhysicalExprRef)],
+) -> Option<Vec<(PhysicalExprRef, PhysicalExprRef)>> {
// TODO: Clippy wants the "map" call removed, but doing so generates
// a compilation error. Remove the clippy directive once this
// issue is fixed.
@@ -1024,17 +1025,41 @@ fn update_join_on(
/// operation based on a set of equi-join conditions (`hash_join_on`) and a
/// list of projection expressions (`projection_exprs`).
fn new_columns_for_join_on(
- hash_join_on: &[&Column],
+ hash_join_on: &[&PhysicalExprRef],
projection_exprs: &[(Column, String)],
-) -> Option<Vec<Column>> {
+) -> Option<Vec<PhysicalExprRef>> {
let new_columns = hash_join_on
.iter()
.filter_map(|on| {
- projection_exprs
- .iter()
- .enumerate()
- .find(|(_, (proj_column, _))| on.name() == proj_column.name())
- .map(|(index, (_, alias))| Column::new(alias, index))
+ // Rewrite all columns in `on`
+ (*on)
+ .clone()
+ .transform(&|expr| {
+ if let Some(column) =
expr.as_any().downcast_ref::<Column>() {
+ // Find the column in the projection expressions
+ let new_column = projection_exprs
+ .iter()
+ .enumerate()
+ .find(|(_, (proj_column, _))| {
+ column.name() == proj_column.name()
+ })
+ .map(|(index, (_, alias))| Column::new(alias,
index));
+ if let Some(new_column) = new_column {
+ Ok(Transformed::Yes(Arc::new(new_column)))
+ } else {
+ // If the column is not found in the projection
expressions,
+ // it means that the column is not projected. In
this case,
+ // we cannot push the projection down.
+ Err(DataFusionError::Internal(format!(
+ "Column {:?} not found in projection
expressions",
+ column
+ )))
+ }
+ } else {
+ Ok(Transformed::No(expr))
+ }
+ })
+ .ok()
})
.collect::<Vec<_>>();
(new_columns.len() == hash_join_on.len()).then_some(new_columns)
@@ -2018,7 +2043,7 @@ mod tests {
let join: Arc<dyn ExecutionPlan> =
Arc::new(SymmetricHashJoinExec::try_new(
left_csv,
right_csv,
- vec![(Column::new("b", 1), Column::new("c", 2))],
+ vec![(Arc::new(Column::new("b", 1)), Arc::new(Column::new("c",
2)))],
// b_left-(1+a_right)<=a_right+c_left
Some(JoinFilter::new(
Arc::new(BinaryExpr::new(
diff --git
a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs
b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs
index 4656b5b270..bc9bd0010d 100644
---
a/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs
+++
b/datafusion/core/src/physical_optimizer/replace_with_order_preserving_variants.rs
@@ -1440,7 +1440,7 @@ mod tests {
HashJoinExec::try_new(
left,
right,
- vec![(left_col.clone(), right_col.clone())],
+ vec![(Arc::new(left_col.clone()),
Arc::new(right_col.clone()))],
None,
&JoinType::Inner,
PartitionMode::Partitioned,
diff --git a/datafusion/core/src/physical_planner.rs
b/datafusion/core/src/physical_planner.rs
index d383ddce92..d4ef40493d 100644
--- a/datafusion/core/src/physical_planner.rs
+++ b/datafusion/core/src/physical_planner.rs
@@ -1036,15 +1036,21 @@ impl DefaultPhysicalPlanner {
let [physical_left, physical_right]: [Arc<dyn
ExecutionPlan>; 2] = left_right.try_into().map_err(|_|
DataFusionError::Internal("`create_initial_plan_multi` is
broken".to_string()))?;
let left_df_schema = left.schema();
let right_df_schema = right.schema();
+ let execution_props = session_state.execution_props();
let join_on = keys
.iter()
.map(|(l, r)| {
- let l = l.try_into_col()?;
- let r = r.try_into_col()?;
- Ok((
- Column::new(&l.name,
left_df_schema.index_of_column(&l)?),
- Column::new(&r.name,
right_df_schema.index_of_column(&r)?),
- ))
+ let l = create_physical_expr(
+ l,
+ left_df_schema,
+ execution_props
+ )?;
+ let r = create_physical_expr(
+ r,
+ right_df_schema,
+ execution_props
+ )?;
+ Ok((l, r))
})
.collect::<Result<join_utils::JoinOn>>()?;
diff --git a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
index ac86364f42..1c819ac466 100644
--- a/datafusion/core/tests/fuzz_cases/join_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/join_fuzz.rs
@@ -109,12 +109,12 @@ async fn run_join_test(
let schema2 = input2[0].schema();
let on_columns = vec![
(
- Column::new_with_schema("a", &schema1).unwrap(),
- Column::new_with_schema("a", &schema2).unwrap(),
+ Arc::new(Column::new_with_schema("a", &schema1).unwrap()) as _,
+ Arc::new(Column::new_with_schema("a", &schema2).unwrap()) as _,
),
(
- Column::new_with_schema("b", &schema1).unwrap(),
- Column::new_with_schema("b", &schema2).unwrap(),
+ Arc::new(Column::new_with_schema("b", &schema1).unwrap()) as _,
+ Arc::new(Column::new_with_schema("b", &schema2).unwrap()) as _,
),
];
diff --git a/datafusion/physical-expr/src/equivalence/class.rs
b/datafusion/physical-expr/src/equivalence/class.rs
index f0bd1740d5..1f79701871 100644
--- a/datafusion/physical-expr/src/equivalence/class.rs
+++ b/datafusion/physical-expr/src/equivalence/class.rs
@@ -19,7 +19,7 @@ use super::{add_offset_to_expr, collapse_lex_req,
ProjectionMapping};
use crate::{
expressions::Column, physical_expr::deduplicate_physical_exprs,
physical_exprs_bag_equal, physical_exprs_contains, LexOrdering,
LexOrderingRef,
- LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalSortExpr,
+ LexRequirement, LexRequirementRef, PhysicalExpr, PhysicalExprRef,
PhysicalSortExpr,
PhysicalSortRequirement,
};
use datafusion_common::tree_node::TreeNode;
@@ -427,7 +427,7 @@ impl EquivalenceGroup {
right_equivalences: &Self,
join_type: &JoinType,
left_size: usize,
- on: &[(Column, Column)],
+ on: &[(PhysicalExprRef, PhysicalExprRef)],
) -> Self {
match join_type {
JoinType::Inner | JoinType::Left | JoinType::Full |
JoinType::Right => {
@@ -445,9 +445,25 @@ impl EquivalenceGroup {
// are equal in the resulting table.
if join_type == &JoinType::Inner {
for (lhs, rhs) in on.iter() {
- let index = rhs.index() + left_size;
- let new_lhs = Arc::new(lhs.clone()) as _;
- let new_rhs = Arc::new(Column::new(rhs.name(), index))
as _;
+ let new_lhs = lhs.clone() as _;
+ // Rewrite rhs to point to the right side of the join:
+ let new_rhs = rhs
+ .clone()
+ .transform(&|expr| {
+ if let Some(column) =
+ expr.as_any().downcast_ref::<Column>()
+ {
+ let new_column = Arc::new(Column::new(
+ column.name(),
+ column.index() + left_size,
+ ))
+ as _;
+ return Ok(Transformed::Yes(new_column));
+ }
+
+ Ok(Transformed::No(expr))
+ })
+ .unwrap();
result.add_equal_conditions(&new_lhs, &new_rhs);
}
}
diff --git a/datafusion/physical-expr/src/equivalence/properties.rs
b/datafusion/physical-expr/src/equivalence/properties.rs
index cd0ae09a92..2471d9249e 100644
--- a/datafusion/physical-expr/src/equivalence/properties.rs
+++ b/datafusion/physical-expr/src/equivalence/properties.rs
@@ -23,11 +23,12 @@ use super::ordering::collapse_lex_ordering;
use crate::equivalence::{
collapse_lex_req, EquivalenceGroup, OrderingEquivalenceClass,
ProjectionMapping,
};
-use crate::expressions::{Column, Literal};
+use crate::expressions::Literal;
use crate::sort_properties::{ExprOrdering, SortProperties};
use crate::{
physical_exprs_contains, LexOrdering, LexOrderingRef, LexRequirement,
- LexRequirementRef, PhysicalExpr, PhysicalSortExpr, PhysicalSortRequirement,
+ LexRequirementRef, PhysicalExpr, PhysicalExprRef, PhysicalSortExpr,
+ PhysicalSortRequirement,
};
use arrow_schema::SchemaRef;
@@ -1099,7 +1100,7 @@ pub fn join_equivalence_properties(
join_schema: SchemaRef,
maintains_input_order: &[bool],
probe_side: Option<JoinSide>,
- on: &[(Column, Column)],
+ on: &[(PhysicalExprRef, PhysicalExprRef)],
) -> EquivalenceProperties {
let left_size = left.schema.fields.len();
let mut result = EquivalenceProperties::new(join_schema);
diff --git a/datafusion/physical-plan/src/joins/hash_join.rs
b/datafusion/physical-plan/src/joins/hash_join.rs
index 0c213f4257..cd8b17d135 100644
--- a/datafusion/physical-plan/src/joins/hash_join.rs
+++ b/datafusion/physical-plan/src/joins/hash_join.rs
@@ -30,7 +30,6 @@ use crate::joins::utils::{
};
use crate::{
coalesce_partitions::CoalescePartitionsExec,
- expressions::Column,
expressions::PhysicalSortExpr,
hash_utils::create_hashes,
joins::utils::{
@@ -39,8 +38,8 @@ use crate::{
BuildProbeJoinMetrics, ColumnIndex, JoinFilter, JoinOn,
StatefulStreamResult,
},
metrics::{ExecutionPlanMetricsSet, MetricsSet},
- DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr,
- RecordBatchStream, SendableRecordBatchStream, Statistics,
+ DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
RecordBatchStream,
+ SendableRecordBatchStream, Statistics,
};
use crate::{handle_state, DisplayAs};
@@ -67,7 +66,7 @@ use datafusion_common::{
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
-use datafusion_physical_expr::EquivalenceProperties;
+use datafusion_physical_expr::{EquivalenceProperties, PhysicalExprRef};
use ahash::RandomState;
use futures::{ready, Stream, StreamExt, TryStreamExt};
@@ -278,7 +277,7 @@ pub struct HashJoinExec {
/// right (probe) side which are filtered by the hash table
pub right: Arc<dyn ExecutionPlan>,
/// Set of equijoin columns from the relations: `(left_col, right_col)`
- pub on: Vec<(Column, Column)>,
+ pub on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
/// Filters which are applied while finding matching rows
pub filter: Option<JoinFilter>,
/// How the join is performed (`OUTER`, `INNER`, etc)
@@ -369,7 +368,7 @@ impl HashJoinExec {
}
/// Set of common columns used to join on
- pub fn on(&self) -> &[(Column, Column)] {
+ pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
&self.on
}
@@ -451,16 +450,8 @@ impl ExecutionPlan for HashJoinExec {
Distribution::UnspecifiedDistribution,
],
PartitionMode::Partitioned => {
- let (left_expr, right_expr) = self
- .on
- .iter()
- .map(|(l, r)| {
- (
- Arc::new(l.clone()) as Arc<dyn PhysicalExpr>,
- Arc::new(r.clone()) as Arc<dyn PhysicalExpr>,
- )
- })
- .unzip();
+ let (left_expr, right_expr) =
+ self.on.iter().map(|(l, r)| (l.clone(),
r.clone())).unzip();
vec![
Distribution::HashPartitioned(left_expr),
Distribution::HashPartitioned(right_expr),
@@ -697,7 +688,7 @@ async fn collect_left_input(
partition: Option<usize>,
random_state: RandomState,
left: Arc<dyn ExecutionPlan>,
- on_left: Vec<Column>,
+ on_left: Vec<PhysicalExprRef>,
context: Arc<TaskContext>,
metrics: BuildProbeJoinMetrics,
reservation: MemoryReservation,
@@ -793,7 +784,7 @@ async fn collect_left_input(
/// as a chain head for rows with equal hash values.
#[allow(clippy::too_many_arguments)]
pub fn update_hash<T>(
- on: &[Column],
+ on: &[PhysicalExprRef],
batch: &RecordBatch,
hash_map: &mut T,
offset: usize,
@@ -955,9 +946,9 @@ struct HashJoinStream {
/// Input schema
schema: Arc<Schema>,
/// equijoin columns from the left (build side)
- on_left: Vec<Column>,
+ on_left: Vec<PhysicalExprRef>,
/// equijoin columns from the right (probe side)
- on_right: Vec<Column>,
+ on_right: Vec<PhysicalExprRef>,
/// optional join filter
filter: Option<JoinFilter>,
/// type of the join (left, right, semi, etc)
@@ -1043,8 +1034,8 @@ fn lookup_join_hashmap(
build_hashmap: &JoinHashMap,
build_input_buffer: &RecordBatch,
probe_batch: &RecordBatch,
- build_on: &[Column],
- probe_on: &[Column],
+ build_on: &[PhysicalExprRef],
+ probe_on: &[PhysicalExprRef],
null_equals_null: bool,
hashes_buffer: &[u64],
limit: usize,
@@ -1437,6 +1428,7 @@ mod tests {
use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
+ use datafusion_physical_expr::PhysicalExpr;
use hashbrown::raw::RawTable;
use rstest::*;
@@ -1529,15 +1521,8 @@ mod tests {
) -> Result<(Vec<String>, Vec<RecordBatch>)> {
let partition_count = 4;
- let (left_expr, right_expr) = on
- .iter()
- .map(|(l, r)| {
- (
- Arc::new(l.clone()) as Arc<dyn PhysicalExpr>,
- Arc::new(r.clone()) as Arc<dyn PhysicalExpr>,
- )
- })
- .unzip();
+ let (left_expr, right_expr) =
+ on.iter().map(|(l, r)| (l.clone(), r.clone())).unzip();
let join = HashJoinExec::try_new(
Arc::new(RepartitionExec::try_new(
@@ -1588,8 +1573,8 @@ mod tests {
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b1", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
)];
let (columns, batches) = join_collect(
@@ -1635,8 +1620,8 @@ mod tests {
("c2", &vec![70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b1", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
)];
let (columns, batches) = partitioned_join_collect(
@@ -1679,8 +1664,8 @@ mod tests {
("c2", &vec![70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let (columns, batches) =
@@ -1718,8 +1703,8 @@ mod tests {
("c2", &vec![80, 90, 70]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let (columns, batches) =
@@ -1760,12 +1745,12 @@ mod tests {
);
let on = vec![
(
- Column::new_with_schema("a1", &left.schema())?,
- Column::new_with_schema("a1", &right.schema())?,
+ Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
),
(
- Column::new_with_schema("b2", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
),
];
@@ -1822,12 +1807,12 @@ mod tests {
);
let on = vec![
(
- Column::new_with_schema("a1", &left.schema())?,
- Column::new_with_schema("a1", &right.schema())?,
+ Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
),
(
- Column::new_with_schema("b2", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
),
];
@@ -1884,8 +1869,8 @@ mod tests {
("c2", &vec![80, 90, 70]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let (columns, batches) =
@@ -1934,8 +1919,8 @@ mod tests {
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b1", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
)];
let join = join(left, right, on, &JoinType::Inner, false)?;
@@ -2016,8 +2001,8 @@ mod tests {
("c2", &vec![70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema()).unwrap(),
- Column::new_with_schema("b1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap())
as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap())
as _,
)];
let join = join(left, right, on, &JoinType::Left, false).unwrap();
@@ -2059,8 +2044,8 @@ mod tests {
("c2", &vec![70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema()).unwrap(),
- Column::new_with_schema("b2", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap())
as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap())
as _,
)];
let join = join(left, right, on, &JoinType::Full, false).unwrap();
@@ -2099,8 +2084,8 @@ mod tests {
);
let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2",
&vec![]));
let on = vec![(
- Column::new_with_schema("b1", &left.schema()).unwrap(),
- Column::new_with_schema("b1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap())
as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap())
as _,
)];
let schema = right.schema();
let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema,
None).unwrap());
@@ -2136,8 +2121,8 @@ mod tests {
);
let right = build_table_i32(("a2", &vec![]), ("b2", &vec![]), ("c2",
&vec![]));
let on = vec![(
- Column::new_with_schema("b1", &left.schema()).unwrap(),
- Column::new_with_schema("b2", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap())
as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap())
as _,
)];
let schema = right.schema();
let right = Arc::new(MemoryExec::try_new(&[vec![right]], schema,
None).unwrap());
@@ -2177,8 +2162,8 @@ mod tests {
("c2", &vec![70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b1", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
)];
let (columns, batches) = join_collect(
@@ -2221,8 +2206,8 @@ mod tests {
("c2", &vec![70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b1", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
)];
let (columns, batches) = partitioned_join_collect(
@@ -2278,8 +2263,8 @@ mod tests {
let right = build_semi_anti_right_table();
// left_table left semi join right_table on left_table.b1 =
right_table.b2
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let join = join(left, right, on, &JoinType::LeftSemi, false)?;
@@ -2314,8 +2299,8 @@ mod tests {
// left_table left semi join right_table on left_table.b1 =
right_table.b2 and right_table.a2 != 10
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let column_indices = vec![ColumnIndex {
@@ -2401,8 +2386,8 @@ mod tests {
// left_table right semi join right_table on left_table.b1 =
right_table.b2
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let join = join(left, right, on, &JoinType::RightSemi, false)?;
@@ -2438,8 +2423,8 @@ mod tests {
// left_table right semi join right_table on left_table.b1 =
right_table.b2 on left_table.a1!=9
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let column_indices = vec![ColumnIndex {
@@ -2527,8 +2512,8 @@ mod tests {
let right = build_semi_anti_right_table();
// left_table left anti join right_table on left_table.b1 =
right_table.b2
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let join = join(left, right, on, &JoinType::LeftAnti, false)?;
@@ -2561,8 +2546,8 @@ mod tests {
let right = build_semi_anti_right_table();
// left_table left anti join right_table on left_table.b1 =
right_table.b2 and right_table.a2!=8
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let column_indices = vec![ColumnIndex {
@@ -2654,8 +2639,8 @@ mod tests {
let left = build_semi_anti_left_table();
let right = build_semi_anti_right_table();
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let join = join(left, right, on, &JoinType::RightAnti, false)?;
@@ -2689,8 +2674,8 @@ mod tests {
let right = build_semi_anti_right_table();
// left_table right anti join right_table on left_table.b1 =
right_table.b2 and left_table.a1!=13
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let column_indices = vec![ColumnIndex {
@@ -2797,8 +2782,8 @@ mod tests {
("c2", &vec![70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b1", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
)];
let (columns, batches) =
@@ -2836,8 +2821,8 @@ mod tests {
("c2", &vec![70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b1", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
)];
let (columns, batches) =
@@ -2876,8 +2861,8 @@ mod tests {
("c2", &vec![70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema()).unwrap(),
- Column::new_with_schema("b2", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap())
as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap())
as _,
)];
let join = join(left, right, on, &JoinType::Full, false)?;
@@ -2930,7 +2915,7 @@ mod tests {
);
// Join key column for both join sides
- let key_column = Column::new("a", 0);
+ let key_column: PhysicalExprRef = Arc::new(Column::new("a", 0)) as _;
let join_hash_map = JoinHashMap::new(hashmap_left, next);
@@ -2981,8 +2966,8 @@ mod tests {
);
let on = vec![(
// join on a=b so there are duplicate column names on unjoined
columns
- Column::new_with_schema("a", &left.schema()).unwrap(),
- Column::new_with_schema("b", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("b", &right.schema()).unwrap())
as _,
)];
let join = join(left, right, on, &JoinType::Inner, false)?;
@@ -3045,8 +3030,8 @@ mod tests {
("c", &vec![7, 5, 6, 4]),
);
let on = vec![(
- Column::new_with_schema("a", &left.schema()).unwrap(),
- Column::new_with_schema("b", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("b", &right.schema()).unwrap())
as _,
)];
let filter = prepare_join_filter();
@@ -3086,8 +3071,8 @@ mod tests {
("c", &vec![7, 5, 6, 4]),
);
let on = vec![(
- Column::new_with_schema("a", &left.schema()).unwrap(),
- Column::new_with_schema("b", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("b", &right.schema()).unwrap())
as _,
)];
let filter = prepare_join_filter();
@@ -3130,8 +3115,8 @@ mod tests {
("c", &vec![7, 5, 6, 4]),
);
let on = vec![(
- Column::new_with_schema("a", &left.schema()).unwrap(),
- Column::new_with_schema("b", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("b", &right.schema()).unwrap())
as _,
)];
let filter = prepare_join_filter();
@@ -3173,8 +3158,8 @@ mod tests {
("c", &vec![7, 5, 6, 4]),
);
let on = vec![(
- Column::new_with_schema("a", &left.schema()).unwrap(),
- Column::new_with_schema("b", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as
_,
+ Arc::new(Column::new_with_schema("b", &right.schema()).unwrap())
as _,
)];
let filter = prepare_join_filter();
@@ -3223,8 +3208,8 @@ mod tests {
let right = Arc::new(MemoryExec::try_new(&[vec![batch]], schema,
None).unwrap());
let on = vec![(
- Column::new_with_schema("date", &left.schema()).unwrap(),
- Column::new_with_schema("date", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("date", &left.schema()).unwrap())
as _,
+ Arc::new(Column::new_with_schema("date",
&right.schema()).unwrap()) as _,
)];
let join = join(left, right, on, &JoinType::Inner, false)?;
@@ -3261,8 +3246,8 @@ mod tests {
let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2",
&vec![]));
let on = vec![(
- Column::new_with_schema("b1", &left.schema()).unwrap(),
- Column::new_with_schema("b1", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap())
as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema()).unwrap())
as _,
)];
let schema = right.schema();
let right = build_table_i32(("a2", &vec![]), ("b1", &vec![]), ("c2",
&vec![]));
@@ -3317,8 +3302,8 @@ mod tests {
("c2", &vec![0, 0, 0, 0, 0]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema()).unwrap(),
- Column::new_with_schema("b2", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap())
as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap())
as _,
)];
let join_types = vec![
@@ -3451,8 +3436,8 @@ mod tests {
("c2", &vec![14, 15]),
);
let on = vec![(
- Column::new_with_schema("a1", &left.schema()).unwrap(),
- Column::new_with_schema("b2", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("a1", &left.schema()).unwrap())
as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap())
as _,
)];
let join_types = vec![
@@ -3520,8 +3505,8 @@ mod tests {
.unwrap(),
);
let on = vec![(
- Column::new_with_schema("b1", &left_batch.schema())?,
- Column::new_with_schema("b2", &right_batch.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left_batch.schema())?) as
_,
+ Arc::new(Column::new_with_schema("b2", &right_batch.schema())?) as
_,
)];
let join_types = vec![
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join.rs
b/datafusion/physical-plan/src/joins/sort_merge_join.rs
index f6fdc6d77c..675e90fb63 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -30,7 +30,7 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
-use crate::expressions::{Column, PhysicalSortExpr};
+use crate::expressions::PhysicalSortExpr;
use crate::joins::utils::{
build_join_schema, calculate_join_output_ordering, check_join_is_valid,
estimate_join_statistics, partitioned_join_output_partitioning, JoinOn,
@@ -52,7 +52,9 @@ use datafusion_common::{
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::equivalence::join_equivalence_properties;
-use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortRequirement};
+use datafusion_physical_expr::{
+ EquivalenceProperties, PhysicalExprRef, PhysicalSortRequirement,
+};
use futures::{Stream, StreamExt};
@@ -120,11 +122,11 @@ impl SortMergeJoinExec {
.zip(sort_options.iter())
.map(|((l, r), sort_op)| {
let left = PhysicalSortExpr {
- expr: Arc::new(l.clone()) as Arc<dyn PhysicalExpr>,
+ expr: l.clone(),
options: *sort_op,
};
let right = PhysicalSortExpr {
- expr: Arc::new(r.clone()) as Arc<dyn PhysicalExpr>,
+ expr: r.clone(),
options: *sort_op,
};
(left, right)
@@ -189,7 +191,7 @@ impl SortMergeJoinExec {
}
/// Set of common columns used to join on
- pub fn on(&self) -> &[(Column, Column)] {
+ pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
&self.on
}
@@ -236,16 +238,8 @@ impl ExecutionPlan for SortMergeJoinExec {
}
fn required_input_distribution(&self) -> Vec<Distribution> {
- let (left_expr, right_expr) = self
- .on
- .iter()
- .map(|(l, r)| {
- (
- Arc::new(l.clone()) as Arc<dyn PhysicalExpr>,
- Arc::new(r.clone()) as Arc<dyn PhysicalExpr>,
- )
- })
- .unzip();
+ let (left_expr, right_expr) =
+ self.on.iter().map(|(l, r)| (l.clone(), r.clone())).unzip();
vec![
Distribution::HashPartitioned(left_expr),
Distribution::HashPartitioned(right_expr),
@@ -483,7 +477,7 @@ struct StreamedBatch {
}
impl StreamedBatch {
- fn new(batch: RecordBatch, on_column: &[Column]) -> Self {
+ fn new(batch: RecordBatch, on_column: &[Arc<dyn PhysicalExpr>]) -> Self {
let join_arrays = join_arrays(&batch, on_column);
StreamedBatch {
batch,
@@ -547,7 +541,11 @@ struct BufferedBatch {
}
impl BufferedBatch {
- fn new(batch: RecordBatch, range: Range<usize>, on_column: &[Column]) ->
Self {
+ fn new(
+ batch: RecordBatch,
+ range: Range<usize>,
+ on_column: &[PhysicalExprRef],
+ ) -> Self {
let join_arrays = join_arrays(&batch, on_column);
// Estimation is calculated as
@@ -609,9 +607,9 @@ struct SMJStream {
/// The comparison result of current streamed row and buffered batches
pub current_ordering: Ordering,
/// Join key columns of streamed
- pub on_streamed: Vec<Column>,
+ pub on_streamed: Vec<PhysicalExprRef>,
/// Join key columns of buffered
- pub on_buffered: Vec<Column>,
+ pub on_buffered: Vec<PhysicalExprRef>,
/// Staging output array builders
pub output_record_batches: Vec<RecordBatch>,
/// Staging output size, including output batches and staging joined
results
@@ -736,8 +734,8 @@ impl SMJStream {
null_equals_null: bool,
streamed: SendableRecordBatchStream,
buffered: SendableRecordBatchStream,
- on_streamed: Vec<Column>,
- on_buffered: Vec<Column>,
+ on_streamed: Vec<Arc<dyn PhysicalExpr>>,
+ on_buffered: Vec<Arc<dyn PhysicalExpr>>,
join_type: JoinType,
batch_size: usize,
join_metrics: SortMergeJoinMetrics,
@@ -1218,10 +1216,14 @@ impl BufferedData {
}
/// Get join array refs of given batch and join columns
-fn join_arrays(batch: &RecordBatch, on_column: &[Column]) -> Vec<ArrayRef> {
+fn join_arrays(batch: &RecordBatch, on_column: &[PhysicalExprRef]) ->
Vec<ArrayRef> {
on_column
.iter()
- .map(|c| batch.column(c.index()).clone())
+ .map(|c| {
+ let num_rows = batch.num_rows();
+ let c = c.evaluate(batch).unwrap();
+ c.into_array(num_rows).unwrap()
+ })
.collect()
}
@@ -1582,8 +1584,8 @@ mod tests {
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b1", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
)];
let (_, batches) = join_collect(left, right, on,
JoinType::Inner).await?;
@@ -1616,12 +1618,12 @@ mod tests {
);
let on = vec![
(
- Column::new_with_schema("a1", &left.schema())?,
- Column::new_with_schema("a1", &right.schema())?,
+ Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
),
(
- Column::new_with_schema("b2", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
),
];
@@ -1654,12 +1656,12 @@ mod tests {
);
let on = vec![
(
- Column::new_with_schema("a1", &left.schema())?,
- Column::new_with_schema("a1", &right.schema())?,
+ Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
),
(
- Column::new_with_schema("b2", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
),
];
@@ -1693,12 +1695,12 @@ mod tests {
);
let on = vec![
(
- Column::new_with_schema("a1", &left.schema())?,
- Column::new_with_schema("a1", &right.schema())?,
+ Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
),
(
- Column::new_with_schema("b2", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
),
];
@@ -1731,12 +1733,12 @@ mod tests {
);
let on = vec![
(
- Column::new_with_schema("a1", &left.schema())?,
- Column::new_with_schema("a1", &right.schema())?,
+ Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
),
(
- Column::new_with_schema("b2", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
),
];
let (_, batches) = join_collect_with_options(
@@ -1783,12 +1785,12 @@ mod tests {
);
let on = vec![
(
- Column::new_with_schema("a1", &left.schema())?,
- Column::new_with_schema("a1", &right.schema())?,
+ Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("a1", &right.schema())?) as _,
),
(
- Column::new_with_schema("b2", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b2", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
),
];
@@ -1824,8 +1826,8 @@ mod tests {
("c2", &vec![70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b1", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
)];
let (_, batches) = join_collect(left, right, on,
JoinType::Left).await?;
@@ -1856,8 +1858,8 @@ mod tests {
("c2", &vec![70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b1", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
)];
let (_, batches) = join_collect(left, right, on,
JoinType::Right).await?;
@@ -1888,8 +1890,8 @@ mod tests {
("c2", &vec![70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema()).unwrap(),
- Column::new_with_schema("b2", &right.schema()).unwrap(),
+ Arc::new(Column::new_with_schema("b1", &left.schema()).unwrap())
as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema()).unwrap())
as _,
)];
let (_, batches) = join_collect(left, right, on,
JoinType::Full).await?;
@@ -1920,8 +1922,8 @@ mod tests {
("c2", &vec![70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b1", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
)];
let (_, batches) = join_collect(left, right, on,
JoinType::LeftAnti).await?;
@@ -1951,8 +1953,8 @@ mod tests {
("c2", &vec![70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b1", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
)];
let (_, batches) = join_collect(left, right, on,
JoinType::LeftSemi).await?;
@@ -1984,8 +1986,8 @@ mod tests {
);
let on = vec![(
// join on a=b so there are duplicate column names on unjoined
columns
- Column::new_with_schema("a", &left.schema())?,
- Column::new_with_schema("b", &right.schema())?,
+ Arc::new(Column::new_with_schema("a", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b", &right.schema())?) as _,
)];
let (_, batches) = join_collect(left, right, on,
JoinType::Inner).await?;
@@ -2016,8 +2018,8 @@ mod tests {
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b1", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
)];
let (_, batches) = join_collect(left, right, on,
JoinType::Inner).await?;
@@ -2048,8 +2050,8 @@ mod tests {
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b1", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
)];
let (_, batches) = join_collect(left, right, on,
JoinType::Inner).await?;
@@ -2079,8 +2081,8 @@ mod tests {
("c2", &vec![50, 60, 70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let (_, batches) = join_collect(left, right, on,
JoinType::Left).await?;
@@ -2115,8 +2117,8 @@ mod tests {
("c2", &vec![60, 70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let (_, batches) = join_collect(left, right, on,
JoinType::Right).await?;
@@ -2159,8 +2161,8 @@ mod tests {
let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
let right = build_table_from_batches(vec![right_batch_1,
right_batch_2]);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let (_, batches) = join_collect(left, right, on,
JoinType::Left).await?;
@@ -2208,8 +2210,8 @@ mod tests {
let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
let right = build_table_from_batches(vec![right_batch_1,
right_batch_2]);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let (_, batches) = join_collect(left, right, on,
JoinType::Right).await?;
@@ -2257,8 +2259,8 @@ mod tests {
let left = build_table_from_batches(vec![left_batch_1, left_batch_2]);
let right = build_table_from_batches(vec![right_batch_1,
right_batch_2]);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let (_, batches) = join_collect(left, right, on,
JoinType::Full).await?;
@@ -2296,8 +2298,8 @@ mod tests {
("c2", &vec![50, 60, 70, 80, 90]),
);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let sort_options = vec![SortOptions::default(); on.len()];
@@ -2376,8 +2378,8 @@ mod tests {
let right =
build_table_from_batches(vec![right_batch_1, right_batch_2,
right_batch_3]);
let on = vec![(
- Column::new_with_schema("b1", &left.schema())?,
- Column::new_with_schema("b2", &right.schema())?,
+ Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let sort_options = vec![SortOptions::default(); on.len()];
diff --git a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
index 00950f0825..3f907930d6 100644
--- a/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
+++ b/datafusion/physical-plan/src/joins/symmetric_hash_join.rs
@@ -46,11 +46,11 @@ use crate::joins::utils::{
JoinHashMapType, JoinOn, StatefulStreamResult,
};
use crate::{
- expressions::{Column, PhysicalSortExpr},
+ expressions::PhysicalSortExpr,
joins::StreamJoinPartitionMode,
metrics::{ExecutionPlanMetricsSet, MetricsSet},
DisplayAs, DisplayFormatType, Distribution, EquivalenceProperties,
ExecutionPlan,
- Partitioning, PhysicalExpr, RecordBatchStream, SendableRecordBatchStream,
Statistics,
+ Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics,
};
use arrow::array::{
@@ -72,7 +72,7 @@ use
datafusion_physical_expr::equivalence::join_equivalence_properties;
use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph;
use ahash::RandomState;
-use datafusion_physical_expr::PhysicalSortRequirement;
+use datafusion_physical_expr::{PhysicalExprRef, PhysicalSortRequirement};
use futures::Stream;
use hashbrown::HashSet;
use parking_lot::Mutex;
@@ -171,7 +171,7 @@ pub struct SymmetricHashJoinExec {
/// Right side stream
pub(crate) right: Arc<dyn ExecutionPlan>,
/// Set of common columns used to join on
- pub(crate) on: Vec<(Column, Column)>,
+ pub(crate) on: Vec<(PhysicalExprRef, PhysicalExprRef)>,
/// Filters applied when finding matching rows
pub(crate) filter: Option<JoinFilter>,
/// How the join is performed
@@ -261,7 +261,7 @@ impl SymmetricHashJoinExec {
}
/// Set of common columns used to join on
- pub fn on(&self) -> &[(Column, Column)] {
+ pub fn on(&self) -> &[(PhysicalExprRef, PhysicalExprRef)] {
&self.on
}
@@ -367,7 +367,7 @@ impl ExecutionPlan for SymmetricHashJoinExec {
let (left_expr, right_expr) = self
.on
.iter()
- .map(|(l, r)| (Arc::new(l.clone()) as _,
Arc::new(r.clone()) as _))
+ .map(|(l, r)| (l.clone() as _, r.clone() as _))
.unzip();
vec![
Distribution::HashPartitioned(left_expr),
@@ -874,8 +874,8 @@ fn lookup_join_hashmap(
build_hashmap: &PruningJoinHashMap,
build_batch: &RecordBatch,
probe_batch: &RecordBatch,
- build_on: &[Column],
- probe_on: &[Column],
+ build_on: &[PhysicalExprRef],
+ probe_on: &[PhysicalExprRef],
random_state: &RandomState,
null_equals_null: bool,
hashes_buffer: &mut Vec<u64>,
@@ -952,7 +952,7 @@ pub struct OneSideHashJoiner {
/// Input record batch buffer
pub input_buffer: RecordBatch,
/// Columns from the side
- pub(crate) on: Vec<Column>,
+ pub(crate) on: Vec<PhysicalExprRef>,
/// Hashmap
pub(crate) hashmap: PruningJoinHashMap,
/// Reuse the hashes buffer
@@ -979,7 +979,11 @@ impl OneSideHashJoiner {
size += std::mem::size_of_val(&self.deleted_offset);
size
}
- pub fn new(build_side: JoinSide, on: Vec<Column>, schema: SchemaRef) ->
Self {
+ pub fn new(
+ build_side: JoinSide,
+ on: Vec<PhysicalExprRef>,
+ schema: SchemaRef,
+ ) -> Self {
Self {
build_side,
input_buffer: RecordBatch::new_empty(schema),
@@ -1447,8 +1451,8 @@ mod tests {
)?;
let on = vec![(
- Column::new_with_schema("lc1", left_schema)?,
- Column::new_with_schema("rc1", right_schema)?,
+ Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+ Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
@@ -1515,8 +1519,8 @@ mod tests {
)?;
let on = vec![(
- Column::new_with_schema("lc1", left_schema)?,
- Column::new_with_schema("rc1", right_schema)?,
+ Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+ Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
@@ -1569,8 +1573,8 @@ mod tests {
create_memory_table(left_partition, right_partition, vec![],
vec![])?;
let on = vec![(
- Column::new_with_schema("lc1", left_schema)?,
- Column::new_with_schema("rc1", right_schema)?,
+ Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+ Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
@@ -1621,8 +1625,8 @@ mod tests {
create_memory_table(left_partition, right_partition, vec![],
vec![])?;
let on = vec![(
- Column::new_with_schema("lc1", left_schema)?,
- Column::new_with_schema("rc1", right_schema)?,
+ Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+ Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
experiment(left, right, None, join_type, on, task_ctx).await?;
Ok(())
@@ -1670,8 +1674,8 @@ mod tests {
)?;
let on = vec![(
- Column::new_with_schema("lc1", left_schema)?,
- Column::new_with_schema("rc1", right_schema)?,
+ Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+ Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
@@ -1731,8 +1735,8 @@ mod tests {
)?;
let on = vec![(
- Column::new_with_schema("lc1", left_schema)?,
- Column::new_with_schema("rc1", right_schema)?,
+ Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+ Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
@@ -1792,8 +1796,8 @@ mod tests {
)?;
let on = vec![(
- Column::new_with_schema("lc1", left_schema)?,
- Column::new_with_schema("rc1", right_schema)?,
+ Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+ Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
@@ -1855,8 +1859,8 @@ mod tests {
)?;
let on = vec![(
- Column::new_with_schema("lc1", left_schema)?,
- Column::new_with_schema("rc1", right_schema)?,
+ Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+ Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
@@ -1914,8 +1918,8 @@ mod tests {
)?;
let on = vec![(
- Column::new_with_schema("lc1", left_schema)?,
- Column::new_with_schema("rc1", right_schema)?,
+ Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+ Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
@@ -1981,8 +1985,8 @@ mod tests {
)?;
let on = vec![(
- Column::new_with_schema("lc1", left_schema)?,
- Column::new_with_schema("rc1", right_schema)?,
+ Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+ Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
@@ -2040,8 +2044,8 @@ mod tests {
let left_schema = &left_partition[0].schema();
let right_schema = &right_partition[0].schema();
let on = vec![(
- Column::new_with_schema("lc1", left_schema)?,
- Column::new_with_schema("rc1", right_schema)?,
+ Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+ Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let left_sorted = vec![PhysicalSortExpr {
expr: col("lt1", left_schema)?,
@@ -2124,8 +2128,8 @@ mod tests {
let left_schema = &left_partition[0].schema();
let right_schema = &right_partition[0].schema();
let on = vec![(
- Column::new_with_schema("lc1", left_schema)?,
- Column::new_with_schema("rc1", right_schema)?,
+ Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+ Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let left_sorted = vec![PhysicalSortExpr {
expr: col("li1", left_schema)?,
@@ -2217,8 +2221,8 @@ mod tests {
)?;
let on = vec![(
- Column::new_with_schema("lc1", left_schema)?,
- Column::new_with_schema("rc1", right_schema)?,
+ Arc::new(Column::new_with_schema("lc1", left_schema)?) as _,
+ Arc::new(Column::new_with_schema("rc1", right_schema)?) as _,
)];
let intermediate_schema = Schema::new(vec![
diff --git a/datafusion/physical-plan/src/joins/test_utils.rs
b/datafusion/physical-plan/src/joins/test_utils.rs
index 477e2de421..37faae8737 100644
--- a/datafusion/physical-plan/src/joins/test_utils.rs
+++ b/datafusion/physical-plan/src/joins/test_utils.rs
@@ -78,15 +78,9 @@ pub async fn partitioned_sym_join_with_filter(
) -> Result<Vec<RecordBatch>> {
let partition_count = 4;
- let left_expr = on
- .iter()
- .map(|(l, _)| Arc::new(l.clone()) as _)
- .collect::<Vec<_>>();
+ let left_expr = on.iter().map(|(l, _)| l.clone() as _).collect::<Vec<_>>();
- let right_expr = on
- .iter()
- .map(|(_, r)| Arc::new(r.clone()) as _)
- .collect::<Vec<_>>();
+ let right_expr = on.iter().map(|(_, r)| r.clone() as
_).collect::<Vec<_>>();
let join = SymmetricHashJoinExec::try_new(
Arc::new(RepartitionExec::try_new(
@@ -133,7 +127,7 @@ pub async fn partitioned_hash_join_with_filter(
let partition_count = 4;
let (left_expr, right_expr) = on
.iter()
- .map(|(l, r)| (Arc::new(l.clone()) as _, Arc::new(r.clone()) as _))
+ .map(|(l, r)| (l.clone() as _, r.clone() as _))
.unzip();
let join = Arc::new(HashJoinExec::try_new(
diff --git a/datafusion/physical-plan/src/joins/utils.rs
b/datafusion/physical-plan/src/joins/utils.rs
index cd987ab40d..e6e3f83fd7 100644
--- a/datafusion/physical-plan/src/joins/utils.rs
+++ b/datafusion/physical-plan/src/joins/utils.rs
@@ -45,11 +45,12 @@ use datafusion_common::{
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_physical_expr::equivalence::add_offset_to_expr;
use datafusion_physical_expr::expressions::Column;
-use datafusion_physical_expr::utils::merge_vectors;
+use datafusion_physical_expr::utils::{collect_columns, merge_vectors};
use datafusion_physical_expr::{
- LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalSortExpr,
+ LexOrdering, LexOrderingRef, PhysicalExpr, PhysicalExprRef,
PhysicalSortExpr,
};
+use datafusion_common::tree_node::{Transformed, TreeNode};
use futures::future::{BoxFuture, Shared};
use futures::{ready, FutureExt};
use hashbrown::raw::RawTable;
@@ -377,9 +378,9 @@ impl fmt::Debug for JoinHashMap {
}
/// The on clause of the join, as vector of (left, right) columns.
-pub type JoinOn = Vec<(Column, Column)>;
+pub type JoinOn = Vec<(PhysicalExprRef, PhysicalExprRef)>;
/// Reference for JoinOn.
-pub type JoinOnRef<'a> = &'a [(Column, Column)];
+pub type JoinOnRef<'a> = &'a [(PhysicalExprRef, PhysicalExprRef)];
/// Checks whether the schemas "left" and "right" and columns "on" represent a
valid join.
/// They are valid whenever their columns' intersection equals the set `on`
@@ -405,12 +406,18 @@ pub fn check_join_is_valid(left: &Schema, right: &Schema,
on: JoinOnRef) -> Resu
fn check_join_set_is_valid(
left: &HashSet<Column>,
right: &HashSet<Column>,
- on: &[(Column, Column)],
+ on: &[(PhysicalExprRef, PhysicalExprRef)],
) -> Result<()> {
- let on_left = &on.iter().map(|on| on.0.clone()).collect::<HashSet<_>>();
+ let on_left = &on
+ .iter()
+ .flat_map(|on| collect_columns(&on.0))
+ .collect::<HashSet<_>>();
let left_missing = on_left.difference(left).collect::<HashSet<_>>();
- let on_right = &on.iter().map(|on| on.1.clone()).collect::<HashSet<_>>();
+ let on_right = &on
+ .iter()
+ .flat_map(|on| collect_columns(&on.1))
+ .collect::<HashSet<_>>();
let right_missing = on_right.difference(right).collect::<HashSet<_>>();
if !left_missing.is_empty() | !right_missing.is_empty() {
@@ -466,21 +473,41 @@ pub fn adjust_right_output_partitioning(
/// Replaces the right column (first index in the `on_column` tuple) with
/// the left column (zeroth index in the tuple) inside `right_ordering`.
fn replace_on_columns_of_right_ordering(
- on_columns: &[(Column, Column)],
+ on_columns: &[(PhysicalExprRef, PhysicalExprRef)],
right_ordering: &mut [PhysicalSortExpr],
- left_columns_len: usize,
-) {
+) -> Result<()> {
for (left_col, right_col) in on_columns {
- let right_col =
- Column::new(right_col.name(), right_col.index() +
left_columns_len);
for item in right_ordering.iter_mut() {
- if let Some(col) = item.expr.as_any().downcast_ref::<Column>() {
- if right_col.eq(col) {
- item.expr = Arc::new(left_col.clone()) as _;
+ let new_expr = item.expr.clone().transform(&|e| {
+ if e.eq(right_col) {
+ Ok(Transformed::Yes(left_col.clone()))
+ } else {
+ Ok(Transformed::No(e))
}
- }
+ })?;
+ item.expr = new_expr;
}
}
+ Ok(())
+}
+
+fn offset_ordering(
+ ordering: LexOrderingRef,
+ join_type: &JoinType,
+ offset: usize,
+) -> Vec<PhysicalSortExpr> {
+ match join_type {
+ // In the case below, right ordering should be offseted with the left
+ // side length, since we append the right table to the left table.
+ JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Right =>
ordering
+ .iter()
+ .map(|sort_expr| PhysicalSortExpr {
+ expr: add_offset_to_expr(sort_expr.expr.clone(), offset),
+ options: sort_expr.options,
+ })
+ .collect(),
+ _ => ordering.to_vec(),
+ }
}
/// Calculate the output ordering of a given join operation.
@@ -488,35 +515,24 @@ pub fn calculate_join_output_ordering(
left_ordering: LexOrderingRef,
right_ordering: LexOrderingRef,
join_type: JoinType,
- on_columns: &[(Column, Column)],
+ on_columns: &[(PhysicalExprRef, PhysicalExprRef)],
left_columns_len: usize,
maintains_input_order: &[bool],
probe_side: Option<JoinSide>,
) -> Option<LexOrdering> {
- let mut right_ordering = match join_type {
- // In the case below, right ordering should be offseted with the left
- // side length, since we append the right table to the left table.
- JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full =>
{
- right_ordering
- .iter()
- .map(|sort_expr| PhysicalSortExpr {
- expr: add_offset_to_expr(sort_expr.expr.clone(),
left_columns_len),
- options: sort_expr.options,
- })
- .collect()
- }
- _ => right_ordering.to_vec(),
- };
let output_ordering = match maintains_input_order {
[true, false] => {
// Special case, we can prefix ordering of right side with the
ordering of left side.
if join_type == JoinType::Inner && probe_side ==
Some(JoinSide::Left) {
replace_on_columns_of_right_ordering(
on_columns,
- &mut right_ordering,
- left_columns_len,
- );
- merge_vectors(left_ordering, &right_ordering)
+ &mut right_ordering.to_vec(),
+ )
+ .ok()?;
+ merge_vectors(
+ left_ordering,
+ &offset_ordering(right_ordering, &join_type,
left_columns_len),
+ )
} else {
left_ordering.to_vec()
}
@@ -526,12 +542,15 @@ pub fn calculate_join_output_ordering(
if join_type == JoinType::Inner && probe_side ==
Some(JoinSide::Right) {
replace_on_columns_of_right_ordering(
on_columns,
- &mut right_ordering,
- left_columns_len,
- );
- merge_vectors(&right_ordering, left_ordering)
+ &mut right_ordering.to_vec(),
+ )
+ .ok()?;
+ merge_vectors(
+ &offset_ordering(right_ordering, &join_type,
left_columns_len),
+ left_ordering,
+ )
} else {
- right_ordering.to_vec()
+ offset_ordering(right_ordering, &join_type, left_columns_len)
}
}
// Doesn't maintain ordering, output ordering is None.
@@ -810,10 +829,19 @@ fn estimate_join_cardinality(
let (left_col_stats, right_col_stats) = on
.iter()
.map(|(left, right)| {
- (
- left_stats.column_statistics[left.index()].clone(),
- right_stats.column_statistics[right.index()].clone(),
- )
+ match (
+ left.as_any().downcast_ref::<Column>(),
+ right.as_any().downcast_ref::<Column>(),
+ ) {
+ (Some(left), Some(right)) => (
+ left_stats.column_statistics[left.index()].clone(),
+
right_stats.column_statistics[right.index()].clone(),
+ ),
+ _ => (
+ ColumnStatistics::new_unknown(),
+ ColumnStatistics::new_unknown(),
+ ),
+ }
})
.unzip::<_, _, Vec<_>, Vec<_>>();
@@ -1476,7 +1504,11 @@ mod tests {
use datafusion_common::stats::Precision::{Absent, Exact, Inexact};
use datafusion_common::{arrow_datafusion_err, arrow_err, ScalarValue};
- fn check(left: &[Column], right: &[Column], on: &[(Column, Column)]) ->
Result<()> {
+ fn check(
+ left: &[Column],
+ right: &[Column],
+ on: &[(PhysicalExprRef, PhysicalExprRef)],
+ ) -> Result<()> {
let left = left
.iter()
.map(|x| x.to_owned())
@@ -1492,7 +1524,10 @@ mod tests {
fn check_valid() -> Result<()> {
let left = vec![Column::new("a", 0), Column::new("b1", 1)];
let right = vec![Column::new("a", 0), Column::new("b2", 1)];
- let on = &[(Column::new("a", 0), Column::new("a", 0))];
+ let on = &[(
+ Arc::new(Column::new("a", 0)) as _,
+ Arc::new(Column::new("a", 0)) as _,
+ )];
check(&left, &right, on)?;
Ok(())
@@ -1502,7 +1537,10 @@ mod tests {
fn check_not_in_right() {
let left = vec![Column::new("a", 0), Column::new("b", 1)];
let right = vec![Column::new("b", 0)];
- let on = &[(Column::new("a", 0), Column::new("a", 0))];
+ let on = &[(
+ Arc::new(Column::new("a", 0)) as _,
+ Arc::new(Column::new("a", 0)) as _,
+ )];
assert!(check(&left, &right, on).is_err());
}
@@ -1544,7 +1582,10 @@ mod tests {
fn check_not_in_left() {
let left = vec![Column::new("b", 0)];
let right = vec![Column::new("a", 0)];
- let on = &[(Column::new("a", 0), Column::new("a", 0))];
+ let on = &[(
+ Arc::new(Column::new("a", 0)) as _,
+ Arc::new(Column::new("a", 0)) as _,
+ )];
assert!(check(&left, &right, on).is_err());
}
@@ -1554,7 +1595,10 @@ mod tests {
// column "a" would appear both in left and right
let left = vec![Column::new("a", 0), Column::new("c", 1)];
let right = vec![Column::new("a", 0), Column::new("b", 1)];
- let on = &[(Column::new("a", 0), Column::new("b", 1))];
+ let on = &[(
+ Arc::new(Column::new("a", 0)) as _,
+ Arc::new(Column::new("b", 1)) as _,
+ )];
assert!(check(&left, &right, on).is_ok());
}
@@ -1563,7 +1607,10 @@ mod tests {
fn check_in_right() {
let left = vec![Column::new("a", 0), Column::new("c", 1)];
let right = vec![Column::new("b", 0)];
- let on = &[(Column::new("a", 0), Column::new("b", 0))];
+ let on = &[(
+ Arc::new(Column::new("a", 0)) as _,
+ Arc::new(Column::new("b", 0)) as _,
+ )];
assert!(check(&left, &right, on).is_ok());
}
@@ -1835,7 +1882,10 @@ mod tests {
// We should also be able to use join_cardinality to get the same
results
let join_type = JoinType::Inner;
- let join_on = vec![(Column::new("a", 0), Column::new("b", 0))];
+ let join_on = vec![(
+ Arc::new(Column::new("a", 0)) as _,
+ Arc::new(Column::new("b", 0)) as _,
+ )];
let partial_join_stats = estimate_join_cardinality(
&join_type,
create_stats(Some(left_num_rows), left_col_stats.clone(),
false),
@@ -1957,8 +2007,14 @@ mod tests {
for (join_type, expected_num_rows) in cases {
let join_on = vec![
- (Column::new("a", 0), Column::new("c", 0)),
- (Column::new("b", 1), Column::new("d", 1)),
+ (
+ Arc::new(Column::new("a", 0)) as _,
+ Arc::new(Column::new("c", 0)) as _,
+ ),
+ (
+ Arc::new(Column::new("b", 1)) as _,
+ Arc::new(Column::new("d", 1)) as _,
+ ),
];
let partial_join_stats = estimate_join_cardinality(
@@ -2005,8 +2061,14 @@ mod tests {
];
let join_on = vec![
- (Column::new("a", 0), Column::new("c", 0)),
- (Column::new("x", 2), Column::new("y", 2)),
+ (
+ Arc::new(Column::new("a", 0)) as _,
+ Arc::new(Column::new("c", 0)) as _,
+ ),
+ (
+ Arc::new(Column::new("x", 2)) as _,
+ Arc::new(Column::new("y", 2)) as _,
+ ),
];
let cases = vec![
@@ -2071,7 +2133,10 @@ mod tests {
},
];
let join_type = JoinType::Inner;
- let on_columns = [(Column::new("b", 1), Column::new("x", 0))];
+ let on_columns = [(
+ Arc::new(Column::new("b", 1)) as _,
+ Arc::new(Column::new("x", 0)) as _,
+ )];
let left_columns_len = 5;
let maintains_input_orders = [[true, false], [false, true]];
let probe_sides = [Some(JoinSide::Left), Some(JoinSide::Right)];
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index c8468e1709..1d5ca59171 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1581,8 +1581,8 @@ message PhysicalColumn {
}
message JoinOn {
- PhysicalColumn left = 1;
- PhysicalColumn right = 2;
+ PhysicalExprNode left = 1;
+ PhysicalExprNode right = 2;
}
message EmptyExecNode {
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index a5582cc2dc..485dbd48b8 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -2244,9 +2244,9 @@ pub struct PhysicalColumn {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct JoinOn {
#[prost(message, optional, tag = "1")]
- pub left: ::core::option::Option<PhysicalColumn>,
+ pub left: ::core::option::Option<PhysicalExprNode>,
#[prost(message, optional, tag = "2")]
- pub right: ::core::option::Option<PhysicalColumn>,
+ pub right: ::core::option::Option<PhysicalExprNode>,
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index f39f885b78..d2961875d8 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -31,6 +31,7 @@ use datafusion::datasource::physical_plan::ParquetExec;
use datafusion::datasource::physical_plan::{AvroExec, CsvExec};
use datafusion::execution::runtime_env::RuntimeEnv;
use datafusion::execution::FunctionRegistry;
+use datafusion::physical_expr::PhysicalExprRef;
use datafusion::physical_plan::aggregates::{create_aggregate_expr,
AggregateMode};
use datafusion::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy};
use datafusion::physical_plan::analyze::AnalyzeExec;
@@ -38,7 +39,7 @@ use
datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::empty::EmptyExec;
use datafusion::physical_plan::explain::ExplainExec;
-use datafusion::physical_plan::expressions::{Column, PhysicalSortExpr};
+use datafusion::physical_plan::expressions::PhysicalSortExpr;
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::insert::FileSinkExec;
use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
@@ -64,6 +65,7 @@ use prost::Message;
use crate::common::str_to_byte;
use crate::common::{byte_to_string, proto_error};
+use crate::convert_required;
use crate::physical_plan::from_proto::{
parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs,
parse_protobuf_file_scan_config,
@@ -75,7 +77,6 @@ use crate::protobuf::repartition_exec_node::PartitionMethod;
use crate::protobuf::{
self, window_agg_exec_node, PhysicalPlanNode,
PhysicalSortExprNodeCollection,
};
-use crate::{convert_required, into_required};
use self::from_proto::parse_physical_window_expr;
@@ -506,12 +507,22 @@ impl AsExecutionPlan for PhysicalPlanNode {
runtime,
extension_codec,
)?;
- let on: Vec<(Column, Column)> = hashjoin
+ let left_schema = left.schema();
+ let right_schema = right.schema();
+ let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = hashjoin
.on
.iter()
.map(|col| {
- let left = into_required!(col.left)?;
- let right = into_required!(col.right)?;
+ let left = parse_physical_expr(
+ &col.left.clone().unwrap(),
+ registry,
+ left_schema.as_ref(),
+ )?;
+ let right = parse_physical_expr(
+ &col.right.clone().unwrap(),
+ registry,
+ right_schema.as_ref(),
+ )?;
Ok((left, right))
})
.collect::<Result<_>>()?;
@@ -595,12 +606,22 @@ impl AsExecutionPlan for PhysicalPlanNode {
runtime,
extension_codec,
)?;
+ let left_schema = left.schema();
+ let right_schema = right.schema();
let on = sym_join
.on
.iter()
.map(|col| {
- let left = into_required!(col.left)?;
- let right = into_required!(col.right)?;
+ let left = parse_physical_expr(
+ &col.left.clone().unwrap(),
+ registry,
+ left_schema.as_ref(),
+ )?;
+ let right = parse_physical_expr(
+ &col.right.clone().unwrap(),
+ registry,
+ right_schema.as_ref(),
+ )?;
Ok((left, right))
})
.collect::<Result<_>>()?;
@@ -647,7 +668,6 @@ impl AsExecutionPlan for PhysicalPlanNode {
})
.map_or(Ok(None), |v: Result<JoinFilter>| v.map(Some))?;
- let left_schema = left.schema();
let left_sort_exprs = parse_physical_sort_exprs(
&sym_join.left_sort_exprs,
registry,
@@ -659,7 +679,6 @@ impl AsExecutionPlan for PhysicalPlanNode {
Some(left_sort_exprs)
};
- let right_schema = right.schema();
let right_sort_exprs = parse_physical_sort_exprs(
&sym_join.right_sort_exprs,
registry,
@@ -1144,17 +1163,15 @@ impl AsExecutionPlan for PhysicalPlanNode {
let on: Vec<protobuf::JoinOn> = exec
.on()
.iter()
- .map(|tuple| protobuf::JoinOn {
- left: Some(protobuf::PhysicalColumn {
- name: tuple.0.name().to_string(),
- index: tuple.0.index() as u32,
- }),
- right: Some(protobuf::PhysicalColumn {
- name: tuple.1.name().to_string(),
- index: tuple.1.index() as u32,
- }),
+ .map(|tuple| {
+ let l = tuple.0.to_owned().try_into()?;
+ let r = tuple.1.to_owned().try_into()?;
+ Ok::<_, DataFusionError>(protobuf::JoinOn {
+ left: Some(l),
+ right: Some(r),
+ })
})
- .collect();
+ .collect::<Result<_>>()?;
let join_type: protobuf::JoinType =
exec.join_type().to_owned().into();
let filter = exec
.filter()
@@ -1214,17 +1231,15 @@ impl AsExecutionPlan for PhysicalPlanNode {
let on = exec
.on()
.iter()
- .map(|tuple| protobuf::JoinOn {
- left: Some(protobuf::PhysicalColumn {
- name: tuple.0.name().to_string(),
- index: tuple.0.index() as u32,
- }),
- right: Some(protobuf::PhysicalColumn {
- name: tuple.1.name().to_string(),
- index: tuple.1.index() as u32,
- }),
+ .map(|tuple| {
+ let l = tuple.0.to_owned().try_into()?;
+ let r = tuple.1.to_owned().try_into()?;
+ Ok::<_, DataFusionError>(protobuf::JoinOn {
+ left: Some(l),
+ right: Some(r),
+ })
})
- .collect();
+ .collect::<Result<_>>()?;
let join_type: protobuf::JoinType =
exec.join_type().to_owned().into();
let filter = exec
.filter()
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index eba3db298f..f2f1b0ea0d 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -191,8 +191,8 @@ fn roundtrip_hash_join() -> Result<()> {
let schema_left = Schema::new(vec![field_a.clone()]);
let schema_right = Schema::new(vec![field_a]);
let on = vec![(
- Column::new("col", schema_left.index_of("col")?),
- Column::new("col", schema_right.index_of("col")?),
+ Arc::new(Column::new("col", schema_left.index_of("col")?)) as _,
+ Arc::new(Column::new("col", schema_right.index_of("col")?)) as _,
)];
let schema_left = Arc::new(schema_left);
@@ -916,8 +916,8 @@ fn roundtrip_sym_hash_join() -> Result<()> {
let schema_left = Schema::new(vec![field_a.clone()]);
let schema_right = Schema::new(vec![field_a]);
let on = vec![(
- Column::new("col", schema_left.index_of("col")?),
- Column::new("col", schema_right.index_of("col")?),
+ Arc::new(Column::new("col", schema_left.index_of("col")?)) as _,
+ Arc::new(Column::new("col", schema_right.index_of("col")?)) as _,
)];
let schema_left = Arc::new(schema_left);