This is an automated email from the ASF dual-hosted git repository.
jeffreyvo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 2c9f42bae7 feat: Support SortMergeJoin proto serde (#17296)
2c9f42bae7 is described below
commit 2c9f42bae7b76484b3a49077c0a5c70d36a32817
Author: Marko Milenković <[email protected]>
AuthorDate: Sun Aug 24 15:48:36 2025 +0100
feat: Support SortMergeJoin proto serde (#17296)
* Implement protobuf serialization/deserialization for SortMergeJoinExec
* Add round-trip tests for sort-merge join
* Add a join-filter case to the round-trip test
---
datafusion/proto/proto/datafusion.proto | 11 ++
datafusion/proto/src/generated/pbjson.rs | 214 +++++++++++++++++++++
datafusion/proto/src/generated/prost.rs | 21 +-
datafusion/proto/src/physical_plan/mod.rs | 211 +++++++++++++++++++-
.../proto/tests/cases/roundtrip_physical_plan.rs | 88 ++++++++-
5 files changed, 540 insertions(+), 5 deletions(-)
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 81d5016810..e5f209783a 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -729,6 +729,7 @@ message PhysicalPlanNode {
JsonScanExecNode json_scan = 31;
CooperativeExecNode cooperative = 32;
GenerateSeriesNode generate_series = 33;
+ SortMergeJoinExecNode sort_merge_join = 34;
}
}
@@ -1343,3 +1344,13 @@ message GenerateSeriesNode {
GenerateSeriesArgsDate date_args = 6;
}
}
+
+// Serialized form of a SortMergeJoinExec physical plan node.
+message SortMergeJoinExecNode {
+ // Left input plan; the left side of each `on` pair is parsed against its schema.
+ PhysicalPlanNode left = 1;
+ // Right input plan; the right side of each `on` pair is parsed against its schema.
+ PhysicalPlanNode right = 2;
+ // Equi-join key pairs (left expression, right expression).
+ repeated JoinOn on = 3;
+ datafusion_common.JoinType join_type = 4;
+ // Optional extra join predicate; carries its own intermediate schema and column indices.
+ JoinFilter filter = 5;
+ // Per-key sort options. Only `asc` and `nulls_first` are used; the serializer leaves `expr` unset.
+ repeated SortExprNode sort_options = 6;
+ datafusion_common.NullEquality null_equality = 7;
+}
\ No newline at end of file
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index d274d5c518..5fbfdd5f54 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -16799,6 +16799,9 @@ impl serde::Serialize for PhysicalPlanNode {
physical_plan_node::PhysicalPlanType::GenerateSeries(v) => {
struct_ser.serialize_field("generateSeries", v)?;
}
+ physical_plan_node::PhysicalPlanType::SortMergeJoin(v) => {
+ struct_ser.serialize_field("sortMergeJoin", v)?;
+ }
}
}
struct_ser.end()
@@ -16860,6 +16863,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
"cooperative",
"generate_series",
"generateSeries",
+ "sort_merge_join",
+ "sortMergeJoin",
];
#[allow(clippy::enum_variant_names)]
@@ -16896,6 +16901,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
JsonScan,
Cooperative,
GenerateSeries,
+ SortMergeJoin,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -16949,6 +16955,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
"jsonScan" | "json_scan" =>
Ok(GeneratedField::JsonScan),
"cooperative" => Ok(GeneratedField::Cooperative),
"generateSeries" | "generate_series" =>
Ok(GeneratedField::GenerateSeries),
+ "sortMergeJoin" | "sort_merge_join" =>
Ok(GeneratedField::SortMergeJoin),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -17193,6 +17200,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode
{
return
Err(serde::de::Error::duplicate_field("generateSeries"));
}
physical_plan_type__ =
map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::GenerateSeries)
+;
+ }
+ GeneratedField::SortMergeJoin => {
+ if physical_plan_type__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("sortMergeJoin"));
+ }
+ physical_plan_type__ =
map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SortMergeJoin)
;
}
}
@@ -20535,6 +20549,206 @@ impl<'de> serde::Deserialize<'de> for
SortExprNodeCollection {
deserializer.deserialize_struct("datafusion.SortExprNodeCollection",
FIELDS, GeneratedVisitor)
}
}
+// pbjson serialization of SortMergeJoinExecNode. Fields that hold their protobuf
+// default (unset message, empty list, enum value 0) are skipped, and JSON keys use
+// lowerCamelCase names.
+// NOTE(review): this lives under src/generated/ and appears to be generated from
+// datafusion.proto; hand edits will likely be lost on regeneration.
+impl serde::Serialize for SortMergeJoinExecNode {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ // Count only the fields that will actually be written; the count is passed to serialize_struct.
+ let mut len = 0;
+ if self.left.is_some() {
+ len += 1;
+ }
+ if self.right.is_some() {
+ len += 1;
+ }
+ if !self.on.is_empty() {
+ len += 1;
+ }
+ if self.join_type != 0 {
+ len += 1;
+ }
+ if self.filter.is_some() {
+ len += 1;
+ }
+ if !self.sort_options.is_empty() {
+ len += 1;
+ }
+ if self.null_equality != 0 {
+ len += 1;
+ }
+ let mut struct_ser =
serializer.serialize_struct("datafusion.SortMergeJoinExecNode", len)?;
+ if let Some(v) = self.left.as_ref() {
+ struct_ser.serialize_field("left", v)?;
+ }
+ if let Some(v) = self.right.as_ref() {
+ struct_ser.serialize_field("right", v)?;
+ }
+ if !self.on.is_empty() {
+ struct_ser.serialize_field("on", &self.on)?;
+ }
+ // Enum fields are stored as i32; convert back to the enum so an out-of-range value is an error.
+ if self.join_type != 0 {
+ let v =
super::datafusion_common::JoinType::try_from(self.join_type)
+ .map_err(|_| serde::ser::Error::custom(format!("Invalid
variant {}", self.join_type)))?;
+ struct_ser.serialize_field("joinType", &v)?;
+ }
+ if let Some(v) = self.filter.as_ref() {
+ struct_ser.serialize_field("filter", v)?;
+ }
+ if !self.sort_options.is_empty() {
+ struct_ser.serialize_field("sortOptions", &self.sort_options)?;
+ }
+ if self.null_equality != 0 {
+ let v =
super::datafusion_common::NullEquality::try_from(self.null_equality)
+ .map_err(|_| serde::ser::Error::custom(format!("Invalid
variant {}", self.null_equality)))?;
+ struct_ser.serialize_field("nullEquality", &v)?;
+ }
+ struct_ser.end()
+ }
+}
+// pbjson deserialization of SortMergeJoinExecNode. Both the proto field names
+// (snake_case) and the JSON names (lowerCamelCase) are accepted.
+// NOTE(review): this lives under src/generated/ and appears to be generated from
+// datafusion.proto; hand edits will likely be lost on regeneration.
+impl<'de> serde::Deserialize<'de> for SortMergeJoinExecNode {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ // Every accepted key spelling; also used in unknown-field error messages.
+ const FIELDS: &[&str] = &[
+ "left",
+ "right",
+ "on",
+ "join_type",
+ "joinType",
+ "filter",
+ "sort_options",
+ "sortOptions",
+ "null_equality",
+ "nullEquality",
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ Left,
+ Right,
+ On,
+ JoinType,
+ Filter,
+ SortOptions,
+ NullEquality,
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut
std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) ->
std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ match value {
+ "left" => Ok(GeneratedField::Left),
+ "right" => Ok(GeneratedField::Right),
+ "on" => Ok(GeneratedField::On),
+ "joinType" | "join_type" =>
Ok(GeneratedField::JoinType),
+ "filter" => Ok(GeneratedField::Filter),
+ "sortOptions" | "sort_options" =>
Ok(GeneratedField::SortOptions),
+ "nullEquality" | "null_equality" =>
Ok(GeneratedField::NullEquality),
+ _ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = SortMergeJoinExecNode;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) ->
std::fmt::Result {
+ formatter.write_str("struct datafusion.SortMergeJoinExecNode")
+ }
+
+ fn visit_map<V>(self, mut map_: V) ->
std::result::Result<SortMergeJoinExecNode, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ let mut left__ = None;
+ let mut right__ = None;
+ let mut on__ = None;
+ let mut join_type__ = None;
+ let mut filter__ = None;
+ let mut sort_options__ = None;
+ let mut null_equality__ = None;
+ // A key seen twice is an error; absent fields fall back to protobuf defaults when the struct is built below.
+ while let Some(k) = map_.next_key()? {
+ match k {
+ GeneratedField::Left => {
+ if left__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("left"));
+ }
+ left__ = map_.next_value()?;
+ }
+ GeneratedField::Right => {
+ if right__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("right"));
+ }
+ right__ = map_.next_value()?;
+ }
+ GeneratedField::On => {
+ if on__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("on"));
+ }
+ on__ = Some(map_.next_value()?);
+ }
+ // Enums are parsed through their own Deserialize impl and stored as the raw i32 the prost struct uses.
+ GeneratedField::JoinType => {
+ if join_type__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("joinType"));
+ }
+ join_type__ =
Some(map_.next_value::<super::datafusion_common::JoinType>()? as i32);
+ }
+ GeneratedField::Filter => {
+ if filter__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("filter"));
+ }
+ filter__ = map_.next_value()?;
+ }
+ GeneratedField::SortOptions => {
+ if sort_options__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("sortOptions"));
+ }
+ sort_options__ = Some(map_.next_value()?);
+ }
+ GeneratedField::NullEquality => {
+ if null_equality__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("nullEquality"));
+ }
+ null_equality__ =
Some(map_.next_value::<super::datafusion_common::NullEquality>()? as i32);
+ }
+ }
+ }
+ Ok(SortMergeJoinExecNode {
+ left: left__,
+ right: right__,
+ on: on__.unwrap_or_default(),
+ join_type: join_type__.unwrap_or_default(),
+ filter: filter__,
+ sort_options: sort_options__.unwrap_or_default(),
+ null_equality: null_equality__.unwrap_or_default(),
+ })
+ }
+ }
+ deserializer.deserialize_struct("datafusion.SortMergeJoinExecNode",
FIELDS, GeneratedVisitor)
+ }
+}
impl serde::Serialize for SortNode {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 8118adf323..65c807e816 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1053,7 +1053,7 @@ pub mod table_reference {
pub struct PhysicalPlanNode {
#[prost(
oneof = "physical_plan_node::PhysicalPlanType",
- tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33"
+ tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34"
)]
pub physical_plan_type:
::core::option::Option<physical_plan_node::PhysicalPlanType>,
}
@@ -1127,6 +1127,8 @@ pub mod physical_plan_node {
Cooperative(::prost::alloc::boxed::Box<super::CooperativeExecNode>),
#[prost(message, tag = "33")]
GenerateSeries(super::GenerateSeriesNode),
+ #[prost(message, tag = "34")]
+
SortMergeJoin(::prost::alloc::boxed::Box<super::SortMergeJoinExecNode>),
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
@@ -2029,6 +2031,23 @@ pub mod generate_series_node {
DateArgs(super::GenerateSeriesArgsDate),
}
}
+/// prost representation of the `SortMergeJoinExecNode` protobuf message.
+/// NOTE(review): appears to be generated from datafusion.proto; prefer editing the .proto.
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct SortMergeJoinExecNode {
+ /// Left input plan. Boxed because PhysicalPlanNode can itself contain this node.
+ #[prost(message, optional, boxed, tag = "1")]
+ pub left:
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+ /// Right input plan (boxed for the same reason as `left`).
+ #[prost(message, optional, boxed, tag = "2")]
+ pub right:
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+ /// Equi-join key pairs.
+ #[prost(message, repeated, tag = "3")]
+ pub on: ::prost::alloc::vec::Vec<JoinOn>,
+ /// Raw `datafusion_common.JoinType` value (prost stores enums as i32).
+ #[prost(enumeration = "super::datafusion_common::JoinType", tag = "4")]
+ pub join_type: i32,
+ /// Optional extra join predicate.
+ #[prost(message, optional, tag = "5")]
+ pub filter: ::core::option::Option<JoinFilter>,
+ /// Per-key sort options; only `asc` / `nulls_first` are populated by the plan serializer.
+ #[prost(message, repeated, tag = "6")]
+ pub sort_options: ::prost::alloc::vec::Vec<SortExprNode>,
+ /// Raw `datafusion_common.NullEquality` value (prost stores enums as i32).
+ #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")]
+ pub null_equality: i32,
+}
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord,
::prost::Enumeration)]
#[repr(i32)]
pub enum WindowFrameUnits {
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index fb86e38055..f1e82841d0 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -34,7 +34,8 @@ use
crate::protobuf::physical_aggregate_expr_node::AggregateFunction;
use crate::protobuf::physical_expr_node::ExprType;
use crate::protobuf::physical_plan_node::PhysicalPlanType;
use crate::protobuf::{
- self, proto_error, window_agg_exec_node, ListUnnest as ProtoListUnnest,
+ self, proto_error, window_agg_exec_node, ListUnnest as ProtoListUnnest,
SortExprNode,
+ SortMergeJoinExecNode,
};
use crate::{convert_required, into_required};
@@ -74,7 +75,8 @@ use datafusion::physical_plan::expressions::PhysicalSortExpr;
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
use datafusion::physical_plan::joins::{
- CrossJoinExec, NestedLoopJoinExec, StreamJoinPartitionMode,
SymmetricHashJoinExec,
+ CrossJoinExec, NestedLoopJoinExec, SortMergeJoinExec,
StreamJoinPartitionMode,
+ SymmetricHashJoinExec,
};
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
@@ -294,6 +296,9 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
PhysicalPlanType::GenerateSeries(generate_series) => {
self.try_into_generate_series_physical_plan(generate_series)
}
+ PhysicalPlanType::SortMergeJoin(sort_join) => {
+ self.try_into_sort_join(sort_join, ctx, runtime,
extension_codec)
+ }
}
}
@@ -363,6 +368,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
);
}
+ if let Some(exec) = plan.downcast_ref::<SortMergeJoinExec>() {
+ return protobuf::PhysicalPlanNode::try_from_sort_merge_join_exec(
+ exec,
+ extension_codec,
+ );
+ }
+
if let Some(exec) = plan.downcast_ref::<CrossJoinExec>() {
return protobuf::PhysicalPlanNode::try_from_cross_join_exec(
exec,
@@ -1752,6 +1764,117 @@ impl protobuf::PhysicalPlanNode {
protobuf::GenerateSeriesName::GsRange => "range",
}
}
+    /// Deserializes a [`SortMergeJoinExecNode`] into a [`SortMergeJoinExec`].
+    ///
+    /// Both inputs are decoded first, so that each side of every `on` key pair
+    /// can be parsed against the schema of its own input. The optional filter
+    /// is parsed against the intermediate schema carried inside the filter
+    /// message.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error in any of these cases:
+    /// - an input plan, a join-key expression, or the filter schema or
+    ///   expression is missing or fails to parse;
+    /// - the join type, the null equality or a filter column's join side is
+    ///   not a known enum value;
+    /// - [`SortMergeJoinExec::try_new`] rejects the combination.
+    fn try_into_sort_join(
+        &self,
+        sort_join: &SortMergeJoinExecNode,
+        ctx: &SessionContext,
+        runtime: &RuntimeEnv,
+        extension_codec: &dyn PhysicalExtensionCodec,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let left = into_physical_plan(&sort_join.left, ctx, runtime, extension_codec)?;
+        let left_schema = left.schema();
+        let right = into_physical_plan(&sort_join.right, ctx, runtime, extension_codec)?;
+        let right_schema = right.schema();
+
+        let filter = sort_join
+            .filter
+            .as_ref()
+            .map(|f| -> Result<JoinFilter> {
+                // The filter is evaluated over its own intermediate schema,
+                // not over either input schema.
+                let schema = f
+                    .schema
+                    .as_ref()
+                    .ok_or_else(|| proto_error("Missing JoinFilter schema"))?
+                    .try_into()?;
+
+                let expression = parse_physical_expr(
+                    f.expression.as_ref().ok_or_else(|| {
+                        proto_error("Unexpected empty filter expression")
+                    })?,
+                    ctx,
+                    &schema,
+                    extension_codec,
+                )?;
+                let column_indices = f
+                    .column_indices
+                    .iter()
+                    .map(|i| {
+                        let side = protobuf::JoinSide::try_from(i.side).map_err(|_| {
+                            proto_error(format!(
+                                "Received a SortMergeJoinExecNode message with unknown JoinSide in Filter {}",
+                                i.side
+                            ))
+                        })?;
+
+                        Ok(ColumnIndex {
+                            index: i.index as usize,
+                            side: side.into(),
+                        })
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                Ok(JoinFilter::new(expression, column_indices, Arc::new(schema)))
+            })
+            .transpose()?;
+
+        let join_type = protobuf::JoinType::try_from(sort_join.join_type).map_err(|_| {
+            proto_error(format!(
+                "Received a SortMergeJoinExecNode message with unknown JoinType {}",
+                sort_join.join_type
+            ))
+        })?;
+
+        let null_equality = protobuf::NullEquality::try_from(sort_join.null_equality)
+            .map_err(|_| {
+                proto_error(format!(
+                    "Received a SortMergeJoinExecNode message with unknown NullEquality {}",
+                    sort_join.null_equality
+                ))
+            })?;
+
+        // Only the ordering flags are carried; `expr` is not used for sort options.
+        let sort_options = sort_join
+            .sort_options
+            .iter()
+            .map(|e| SortOptions {
+                descending: !e.asc,
+                nulls_first: e.nulls_first,
+            })
+            .collect();
+
+        let on = sort_join
+            .on
+            .iter()
+            .map(|col| {
+                // A malformed message must surface as an error, not a panic.
+                let left_key = col.left.as_ref().ok_or_else(|| {
+                    proto_error("Missing left join key expression in SortMergeJoinExecNode")
+                })?;
+                let right_key = col.right.as_ref().ok_or_else(|| {
+                    proto_error("Missing right join key expression in SortMergeJoinExecNode")
+                })?;
+                let left_key = parse_physical_expr(
+                    left_key,
+                    ctx,
+                    left_schema.as_ref(),
+                    extension_codec,
+                )?;
+                let right_key = parse_physical_expr(
+                    right_key,
+                    ctx,
+                    right_schema.as_ref(),
+                    extension_codec,
+                )?;
+                Ok((left_key, right_key))
+            })
+            .collect::<Result<_>>()?;
+
+        Ok(Arc::new(SortMergeJoinExec::try_new(
+            left,
+            right,
+            on,
+            filter,
+            join_type.into(),
+            sort_options,
+            null_equality.into(),
+        )?))
+    }
fn try_into_generate_series_physical_plan(
&self,
@@ -2154,6 +2277,90 @@ impl protobuf::PhysicalPlanNode {
})
}
+    /// Serializes a [`SortMergeJoinExec`] into a `PhysicalPlanNode`.
+    ///
+    /// Encodes both inputs, the `on` key pairs, the optional join filter
+    /// (expression, column indices and intermediate schema), the join type, the
+    /// per-key sort options and the null-equality mode.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if an input plan, a join-key expression, the filter
+    /// expression or the filter schema cannot be serialized. Also returns an
+    /// error if a filter column index does not fit in the `u32` used on the wire.
+    fn try_from_sort_merge_join_exec(
+        exec: &SortMergeJoinExec,
+        extension_codec: &dyn PhysicalExtensionCodec,
+    ) -> Result<Self> {
+        let left = protobuf::PhysicalPlanNode::try_from_physical_plan(
+            exec.left().to_owned(),
+            extension_codec,
+        )?;
+        let right = protobuf::PhysicalPlanNode::try_from_physical_plan(
+            exec.right().to_owned(),
+            extension_codec,
+        )?;
+        let on = exec
+            .on()
+            .iter()
+            .map(|(left_key, right_key)| {
+                let left_key = serialize_physical_expr(left_key, extension_codec)?;
+                let right_key = serialize_physical_expr(right_key, extension_codec)?;
+                Ok::<_, DataFusionError>(protobuf::JoinOn {
+                    left: Some(left_key),
+                    right: Some(right_key),
+                })
+            })
+            .collect::<Result<_>>()?;
+        let join_type: protobuf::JoinType = exec.join_type().to_owned().into();
+        let null_equality: protobuf::NullEquality = exec.null_equality().into();
+        let filter = exec
+            .filter()
+            .as_ref()
+            .map(|f| -> Result<protobuf::JoinFilter> {
+                let expression = serialize_physical_expr(f.expression(), extension_codec)?;
+                let column_indices = f
+                    .column_indices()
+                    .iter()
+                    .map(|i| {
+                        // Indices travel as u32; refuse to truncate silently.
+                        let index = u32::try_from(i.index).map_err(|_| {
+                            proto_error(format!(
+                                "JoinFilter column index {} does not fit in u32",
+                                i.index
+                            ))
+                        })?;
+                        let side: protobuf::JoinSide = i.side.to_owned().into();
+                        Ok(protobuf::ColumnIndex {
+                            index,
+                            side: side.into(),
+                        })
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+                let schema = f.schema().as_ref().try_into()?;
+                Ok(protobuf::JoinFilter {
+                    expression: Some(expression),
+                    column_indices,
+                    schema: Some(schema),
+                })
+            })
+            .transpose()?;
+
+        // Only the ordering flags are encoded; there is no per-key expression to store.
+        let sort_options = exec
+            .sort_options()
+            .iter()
+            .map(|options| SortExprNode {
+                expr: None,
+                asc: !options.descending,
+                nulls_first: options.nulls_first,
+            })
+            .collect();
+
+        Ok(protobuf::PhysicalPlanNode {
+            physical_plan_type: Some(PhysicalPlanType::SortMergeJoin(Box::new(
+                protobuf::SortMergeJoinExecNode {
+                    left: Some(Box::new(left)),
+                    right: Some(Box::new(right)),
+                    on,
+                    join_type: join_type.into(),
+                    null_equality: null_equality.into(),
+                    filter,
+                    sort_options,
+                },
+            ))),
+        })
+    }
+
fn try_from_cross_join_exec(
exec: &CrossJoinExec,
extension_codec: &dyn PhysicalExtensionCodec,
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 1547b7087d..86ad54d3f1 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -75,8 +75,8 @@ use datafusion::physical_plan::expressions::{
};
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::joins::{
- HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode,
- SymmetricHashJoinExec,
+ HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec,
+ StreamJoinPartitionMode, SymmetricHashJoinExec,
};
use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion::physical_plan::placeholder_row::PlaceholderRowExec;
@@ -2125,3 +2125,87 @@ async fn analyze_roundtrip_unoptimized() -> Result<()> {
physical_planner.optimize_physical_plan(unoptimized, &session_state, |_,
_| {})?;
Ok(())
}
+
+/// Round-trips a `SortMergeJoinExec` through protobuf for each listed join
+/// type, once without and once with a non-equi join filter.
+#[test]
+fn roundtrip_sort_merge_join() -> Result<()> {
+    let field_a = Field::new("col_a", DataType::Int64, false);
+    let field_b = Field::new("col_b", DataType::Int64, false);
+    let schema_left = Schema::new(vec![field_a.clone()]);
+    let schema_right = Schema::new(vec![field_b.clone()]);
+    let on = vec![(
+        Arc::new(Column::new("col_a", schema_left.index_of("col_a")?)) as _,
+        Arc::new(Column::new("col_b", schema_right.index_of("col_b")?)) as _,
+    )];
+
+    // The filter expression is evaluated against the intermediate schema passed
+    // as the last argument ([col_a, col_b]); `column_indices` maps each of those
+    // positions back to a join side. Each column's index must therefore point
+    // at its own field in that schema: col_a -> 0, col_b -> 1.
+    let filter = datafusion::physical_plan::joins::utils::JoinFilter::new(
+        Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("col_a", 0)),
+            Operator::Gt,
+            Arc::new(Column::new("col_b", 1)),
+        )),
+        vec![
+            datafusion::physical_plan::joins::utils::ColumnIndex {
+                index: 0,
+                side: datafusion_common::JoinSide::Left,
+            },
+            datafusion::physical_plan::joins::utils::ColumnIndex {
+                index: 0,
+                side: datafusion_common::JoinSide::Right,
+            },
+        ],
+        Arc::new(Schema::new(vec![field_a, field_b])),
+    );
+
+    let schema_left = Arc::new(schema_left);
+    let schema_right = Arc::new(schema_right);
+    for filter in [None, Some(filter)] {
+        for join_type in [
+            JoinType::Inner,
+            JoinType::Left,
+            JoinType::Right,
+            JoinType::Full,
+            JoinType::LeftAnti,
+            JoinType::RightAnti,
+            JoinType::LeftSemi,
+            JoinType::RightSemi,
+        ] {
+            roundtrip_test(Arc::new(SortMergeJoinExec::try_new(
+                Arc::new(EmptyExec::new(schema_left.clone())),
+                Arc::new(EmptyExec::new(schema_right.clone())),
+                on.clone(),
+                filter.clone(),
+                join_type,
+                // A single default sort option, matching the single key pair in `on`.
+                vec![Default::default()],
+                NullEquality::NullEqualsNothing,
+            )?))?;
+        }
+    }
+    Ok(())
+}
+
+/// Plans an equi-join with hash joins disabled and round-trips the resulting
+/// physical plan. It first checks that the plan really contains a
+/// `SortMergeJoinExec`, so the test cannot pass vacuously.
+#[tokio::test]
+async fn roundtrip_logical_plan_sort_merge_join() -> Result<()> {
+    let ctx = SessionContext::new();
+    ctx.register_csv(
+        "t0",
+        "tests/testdata/test.csv",
+        datafusion::prelude::CsvReadOptions::default().has_header(true),
+    )
+    .await?;
+    ctx.register_csv(
+        "t1",
+        "tests/testdata/test.csv",
+        datafusion::prelude::CsvReadOptions::default().has_header(true),
+    )
+    .await?;
+
+    // `collect` rather than `show`: the SET result is empty and need not be printed.
+    ctx.sql("SET datafusion.optimizer.prefer_hash_join = false")
+        .await?
+        .collect()
+        .await?;
+    // NOTE(review): the physical planner appears to choose a sort-merge join only
+    // when more than one target partition is configured. Pin the value so the
+    // test does not depend on the machine's CPU count.
+    ctx.sql("SET datafusion.execution.target_partitions = 4")
+        .await?
+        .collect()
+        .await?;
+
+    let query = "SELECT t1.* FROM t0 join t1 on t0.a = t1.a";
+    let plan = ctx.sql(query).await?.create_physical_plan().await?;
+
+    // Recursively look for a SortMergeJoinExec anywhere in the plan tree.
+    fn contains_sort_merge_join(
+        plan: &Arc<dyn datafusion::physical_plan::ExecutionPlan>,
+    ) -> bool {
+        plan.as_any().is::<SortMergeJoinExec>()
+            || plan.children().into_iter().any(contains_sort_merge_join)
+    }
+    assert!(
+        contains_sort_merge_join(&plan),
+        "expected a SortMergeJoinExec in the physical plan"
+    );
+
+    roundtrip_test(plan)
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]