This is an automated email from the ASF dual-hosted git repository.

jeffreyvo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 2c9f42bae7 feat: Support SortMergeJoin proto serde (#17296)
2c9f42bae7 is described below

commit 2c9f42bae7b76484b3a49077c0a5c70d36a32817
Author: Marko Milenković <[email protected]>
AuthorDate: Sun Aug 24 15:48:36 2025 +0100

    feat: Support SortMergeJoin proto serde (#17296)
    
    * Implement ser/de part of SortMergeJoin
    
    * add round trip tests for sort merge join
    
    * add filter test to roundtrip
---
 datafusion/proto/proto/datafusion.proto            |  11 ++
 datafusion/proto/src/generated/pbjson.rs           | 214 +++++++++++++++++++++
 datafusion/proto/src/generated/prost.rs            |  21 +-
 datafusion/proto/src/physical_plan/mod.rs          | 211 +++++++++++++++++++-
 .../proto/tests/cases/roundtrip_physical_plan.rs   |  88 ++++++++-
 5 files changed, 540 insertions(+), 5 deletions(-)

diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 81d5016810..e5f209783a 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -729,6 +729,7 @@ message PhysicalPlanNode {
     JsonScanExecNode json_scan = 31;
     CooperativeExecNode cooperative = 32;
     GenerateSeriesNode generate_series = 33;
+    SortMergeJoinExecNode sort_merge_join = 34;
   }
 }
 
@@ -1343,3 +1344,13 @@ message GenerateSeriesNode {
         GenerateSeriesArgsDate date_args = 6;
     }
 }
+
+message SortMergeJoinExecNode {
+  PhysicalPlanNode left = 1;
+  PhysicalPlanNode right = 2;
+  repeated JoinOn on = 3;
+  datafusion_common.JoinType join_type = 4;
+  JoinFilter filter = 5;
+  repeated SortExprNode sort_options = 6;
+  datafusion_common.NullEquality null_equality = 7;
+}
\ No newline at end of file
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index d274d5c518..5fbfdd5f54 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -16799,6 +16799,9 @@ impl serde::Serialize for PhysicalPlanNode {
                 physical_plan_node::PhysicalPlanType::GenerateSeries(v) => {
                     struct_ser.serialize_field("generateSeries", v)?;
                 }
+                physical_plan_node::PhysicalPlanType::SortMergeJoin(v) => {
+                    struct_ser.serialize_field("sortMergeJoin", v)?;
+                }
             }
         }
         struct_ser.end()
@@ -16860,6 +16863,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
             "cooperative",
             "generate_series",
             "generateSeries",
+            "sort_merge_join",
+            "sortMergeJoin",
         ];
 
         #[allow(clippy::enum_variant_names)]
@@ -16896,6 +16901,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
             JsonScan,
             Cooperative,
             GenerateSeries,
+            SortMergeJoin,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
@@ -16949,6 +16955,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
                             "jsonScan" | "json_scan" => 
Ok(GeneratedField::JsonScan),
                             "cooperative" => Ok(GeneratedField::Cooperative),
                             "generateSeries" | "generate_series" => 
Ok(GeneratedField::GenerateSeries),
+                            "sortMergeJoin" | "sort_merge_join" => 
Ok(GeneratedField::SortMergeJoin),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
                         }
                     }
@@ -17193,6 +17200,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
                                 return 
Err(serde::de::Error::duplicate_field("generateSeries"));
                             }
                             physical_plan_type__ = 
map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::GenerateSeries)
+;
+                        }
+                        GeneratedField::SortMergeJoin => {
+                            if physical_plan_type__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("sortMergeJoin"));
+                            }
+                            physical_plan_type__ = 
map_.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SortMergeJoin)
 ;
                         }
                     }
@@ -20535,6 +20549,206 @@ impl<'de> serde::Deserialize<'de> for SortExprNodeCollection {
         deserializer.deserialize_struct("datafusion.SortExprNodeCollection", 
FIELDS, GeneratedVisitor)
     }
 }
+impl serde::Serialize for SortMergeJoinExecNode {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut len = 0;
+        if self.left.is_some() {
+            len += 1;
+        }
+        if self.right.is_some() {
+            len += 1;
+        }
+        if !self.on.is_empty() {
+            len += 1;
+        }
+        if self.join_type != 0 {
+            len += 1;
+        }
+        if self.filter.is_some() {
+            len += 1;
+        }
+        if !self.sort_options.is_empty() {
+            len += 1;
+        }
+        if self.null_equality != 0 {
+            len += 1;
+        }
+        let mut struct_ser = 
serializer.serialize_struct("datafusion.SortMergeJoinExecNode", len)?;
+        if let Some(v) = self.left.as_ref() {
+            struct_ser.serialize_field("left", v)?;
+        }
+        if let Some(v) = self.right.as_ref() {
+            struct_ser.serialize_field("right", v)?;
+        }
+        if !self.on.is_empty() {
+            struct_ser.serialize_field("on", &self.on)?;
+        }
+        if self.join_type != 0 {
+            let v = 
super::datafusion_common::JoinType::try_from(self.join_type)
+                .map_err(|_| serde::ser::Error::custom(format!("Invalid 
variant {}", self.join_type)))?;
+            struct_ser.serialize_field("joinType", &v)?;
+        }
+        if let Some(v) = self.filter.as_ref() {
+            struct_ser.serialize_field("filter", v)?;
+        }
+        if !self.sort_options.is_empty() {
+            struct_ser.serialize_field("sortOptions", &self.sort_options)?;
+        }
+        if self.null_equality != 0 {
+            let v = 
super::datafusion_common::NullEquality::try_from(self.null_equality)
+                .map_err(|_| serde::ser::Error::custom(format!("Invalid 
variant {}", self.null_equality)))?;
+            struct_ser.serialize_field("nullEquality", &v)?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for SortMergeJoinExecNode {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "left",
+            "right",
+            "on",
+            "join_type",
+            "joinType",
+            "filter",
+            "sort_options",
+            "sortOptions",
+            "null_equality",
+            "nullEquality",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            Left,
+            Right,
+            On,
+            JoinType,
+            Filter,
+            SortOptions,
+            NullEquality,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> 
std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "left" => Ok(GeneratedField::Left),
+                            "right" => Ok(GeneratedField::Right),
+                            "on" => Ok(GeneratedField::On),
+                            "joinType" | "join_type" => 
Ok(GeneratedField::JoinType),
+                            "filter" => Ok(GeneratedField::Filter),
+                            "sortOptions" | "sort_options" => 
Ok(GeneratedField::SortOptions),
+                            "nullEquality" | "null_equality" => 
Ok(GeneratedField::NullEquality),
+                            _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = SortMergeJoinExecNode;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result {
+                formatter.write_str("struct datafusion.SortMergeJoinExecNode")
+            }
+
+            fn visit_map<V>(self, mut map_: V) -> 
std::result::Result<SortMergeJoinExecNode, V::Error>
+                where
+                    V: serde::de::MapAccess<'de>,
+            {
+                let mut left__ = None;
+                let mut right__ = None;
+                let mut on__ = None;
+                let mut join_type__ = None;
+                let mut filter__ = None;
+                let mut sort_options__ = None;
+                let mut null_equality__ = None;
+                while let Some(k) = map_.next_key()? {
+                    match k {
+                        GeneratedField::Left => {
+                            if left__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("left"));
+                            }
+                            left__ = map_.next_value()?;
+                        }
+                        GeneratedField::Right => {
+                            if right__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("right"));
+                            }
+                            right__ = map_.next_value()?;
+                        }
+                        GeneratedField::On => {
+                            if on__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("on"));
+                            }
+                            on__ = Some(map_.next_value()?);
+                        }
+                        GeneratedField::JoinType => {
+                            if join_type__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("joinType"));
+                            }
+                            join_type__ = 
Some(map_.next_value::<super::datafusion_common::JoinType>()? as i32);
+                        }
+                        GeneratedField::Filter => {
+                            if filter__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("filter"));
+                            }
+                            filter__ = map_.next_value()?;
+                        }
+                        GeneratedField::SortOptions => {
+                            if sort_options__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("sortOptions"));
+                            }
+                            sort_options__ = Some(map_.next_value()?);
+                        }
+                        GeneratedField::NullEquality => {
+                            if null_equality__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("nullEquality"));
+                            }
+                            null_equality__ = 
Some(map_.next_value::<super::datafusion_common::NullEquality>()? as i32);
+                        }
+                    }
+                }
+                Ok(SortMergeJoinExecNode {
+                    left: left__,
+                    right: right__,
+                    on: on__.unwrap_or_default(),
+                    join_type: join_type__.unwrap_or_default(),
+                    filter: filter__,
+                    sort_options: sort_options__.unwrap_or_default(),
+                    null_equality: null_equality__.unwrap_or_default(),
+                })
+            }
+        }
+        deserializer.deserialize_struct("datafusion.SortMergeJoinExecNode", 
FIELDS, GeneratedVisitor)
+    }
+}
 impl serde::Serialize for SortNode {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 8118adf323..65c807e816 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1053,7 +1053,7 @@ pub mod table_reference {
 pub struct PhysicalPlanNode {
     #[prost(
         oneof = "physical_plan_node::PhysicalPlanType",
-        tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33"
+        tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34"
     )]
     pub physical_plan_type: 
::core::option::Option<physical_plan_node::PhysicalPlanType>,
 }
@@ -1127,6 +1127,8 @@ pub mod physical_plan_node {
         Cooperative(::prost::alloc::boxed::Box<super::CooperativeExecNode>),
         #[prost(message, tag = "33")]
         GenerateSeries(super::GenerateSeriesNode),
+        #[prost(message, tag = "34")]
+        
SortMergeJoin(::prost::alloc::boxed::Box<super::SortMergeJoinExecNode>),
     }
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
@@ -2029,6 +2031,23 @@ pub mod generate_series_node {
         DateArgs(super::GenerateSeriesArgsDate),
     }
 }
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct SortMergeJoinExecNode {
+    #[prost(message, optional, boxed, tag = "1")]
+    pub left: 
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+    #[prost(message, optional, boxed, tag = "2")]
+    pub right: 
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+    #[prost(message, repeated, tag = "3")]
+    pub on: ::prost::alloc::vec::Vec<JoinOn>,
+    #[prost(enumeration = "super::datafusion_common::JoinType", tag = "4")]
+    pub join_type: i32,
+    #[prost(message, optional, tag = "5")]
+    pub filter: ::core::option::Option<JoinFilter>,
+    #[prost(message, repeated, tag = "6")]
+    pub sort_options: ::prost::alloc::vec::Vec<SortExprNode>,
+    #[prost(enumeration = "super::datafusion_common::NullEquality", tag = "7")]
+    pub null_equality: i32,
+}
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, 
::prost::Enumeration)]
 #[repr(i32)]
 pub enum WindowFrameUnits {
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index fb86e38055..f1e82841d0 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -34,7 +34,8 @@ use crate::protobuf::physical_aggregate_expr_node::AggregateFunction;
 use crate::protobuf::physical_expr_node::ExprType;
 use crate::protobuf::physical_plan_node::PhysicalPlanType;
 use crate::protobuf::{
-    self, proto_error, window_agg_exec_node, ListUnnest as ProtoListUnnest,
+    self, proto_error, window_agg_exec_node, ListUnnest as ProtoListUnnest, 
SortExprNode,
+    SortMergeJoinExecNode,
 };
 use crate::{convert_required, into_required};
 
@@ -74,7 +75,8 @@ use datafusion::physical_plan::expressions::PhysicalSortExpr;
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
 use datafusion::physical_plan::joins::{
-    CrossJoinExec, NestedLoopJoinExec, StreamJoinPartitionMode, 
SymmetricHashJoinExec,
+    CrossJoinExec, NestedLoopJoinExec, SortMergeJoinExec, 
StreamJoinPartitionMode,
+    SymmetricHashJoinExec,
 };
 use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
 use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
@@ -294,6 +296,9 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
             PhysicalPlanType::GenerateSeries(generate_series) => {
                 self.try_into_generate_series_physical_plan(generate_series)
             }
+            PhysicalPlanType::SortMergeJoin(sort_join) => {
+                self.try_into_sort_join(sort_join, ctx, runtime, 
extension_codec)
+            }
         }
     }
 
@@ -363,6 +368,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode {
             );
         }
 
+        if let Some(exec) = plan.downcast_ref::<SortMergeJoinExec>() {
+            return protobuf::PhysicalPlanNode::try_from_sort_merge_join_exec(
+                exec,
+                extension_codec,
+            );
+        }
+
         if let Some(exec) = plan.downcast_ref::<CrossJoinExec>() {
             return protobuf::PhysicalPlanNode::try_from_cross_join_exec(
                 exec,
@@ -1752,6 +1764,117 @@ impl protobuf::PhysicalPlanNode {
             protobuf::GenerateSeriesName::GsRange => "range",
         }
     }
+    fn try_into_sort_join(
+        &self,
+        sort_join: &SortMergeJoinExecNode,
+        ctx: &SessionContext,
+        runtime: &RuntimeEnv,
+        extension_codec: &dyn PhysicalExtensionCodec,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let left = into_physical_plan(&sort_join.left, ctx, runtime, 
extension_codec)?;
+        let left_schema = left.schema();
+        let right = into_physical_plan(&sort_join.right, ctx, runtime, 
extension_codec)?;
+        let right_schema = right.schema();
+
+        let filter = sort_join
+            .filter
+            .as_ref()
+            .map(|f| {
+                let schema = f
+                    .schema
+                    .as_ref()
+                    .ok_or_else(|| proto_error("Missing JoinFilter schema"))?
+                    .try_into()?;
+
+                let expression = parse_physical_expr(
+                    f.expression.as_ref().ok_or_else(|| {
+                        proto_error("Unexpected empty filter expression")
+                    })?,
+                    ctx,
+                    &schema,
+                    extension_codec,
+                )?;
+                let column_indices = f
+                    .column_indices
+                    .iter()
+                    .map(|i| {
+                        let side =
+                            protobuf::JoinSide::try_from(i.side).map_err(|_| {
+                                proto_error(format!(
+                                    "Received a SortMergeJoinExecNode message with JoinSide in Filter {}",
+                                    i.side
+                                ))
+                            })?;
+
+                        Ok(ColumnIndex {
+                            index: i.index as usize,
+                            side: side.into(),
+                        })
+                    })
+                    .collect::<Result<Vec<_>>>()?;
+
+                Ok(JoinFilter::new(
+                    expression,
+                    column_indices,
+                    Arc::new(schema),
+                ))
+            })
+            .map_or(Ok(None), |v: Result<JoinFilter>| v.map(Some))?;
+
+        let join_type =
+            protobuf::JoinType::try_from(sort_join.join_type).map_err(|_| {
+                proto_error(format!(
+                    "Received a SortMergeJoinExecNode message with unknown JoinType {}",
+                    sort_join.join_type
+                ))
+            })?;
+
+        let null_equality = 
protobuf::NullEquality::try_from(sort_join.null_equality)
+            .map_err(|_| {
+                proto_error(format!(
+                    "Received a SortMergeJoinExecNode message with unknown NullEquality {}",
+                    sort_join.null_equality
+                ))
+            })?;
+
+        let sort_options = sort_join
+            .sort_options
+            .iter()
+            .map(|e| SortOptions {
+                descending: !e.asc,
+                nulls_first: e.nulls_first,
+            })
+            .collect();
+        let on = sort_join
+            .on
+            .iter()
+            .map(|col| {
+                let left = parse_physical_expr(
+                    &col.left.clone().unwrap(),
+                    ctx,
+                    left_schema.as_ref(),
+                    extension_codec,
+                )?;
+                let right = parse_physical_expr(
+                    &col.right.clone().unwrap(),
+                    ctx,
+                    right_schema.as_ref(),
+                    extension_codec,
+                )?;
+                Ok((left, right))
+            })
+            .collect::<Result<_>>()?;
+
+        Ok(Arc::new(SortMergeJoinExec::try_new(
+            left,
+            right,
+            on,
+            filter,
+            join_type.into(),
+            sort_options,
+            null_equality.into(),
+        )?))
+    }
 
     fn try_into_generate_series_physical_plan(
         &self,
@@ -2154,6 +2277,90 @@ impl protobuf::PhysicalPlanNode {
         })
     }
 
+    fn try_from_sort_merge_join_exec(
+        exec: &SortMergeJoinExec,
+        extension_codec: &dyn PhysicalExtensionCodec,
+    ) -> Result<Self> {
+        let left = protobuf::PhysicalPlanNode::try_from_physical_plan(
+            exec.left().to_owned(),
+            extension_codec,
+        )?;
+        let right = protobuf::PhysicalPlanNode::try_from_physical_plan(
+            exec.right().to_owned(),
+            extension_codec,
+        )?;
+        let on = exec
+            .on()
+            .iter()
+            .map(|tuple| {
+                let l = serialize_physical_expr(&tuple.0, extension_codec)?;
+                let r = serialize_physical_expr(&tuple.1, extension_codec)?;
+                Ok::<_, DataFusionError>(protobuf::JoinOn {
+                    left: Some(l),
+                    right: Some(r),
+                })
+            })
+            .collect::<Result<_>>()?;
+        let join_type: protobuf::JoinType = exec.join_type().to_owned().into();
+        let null_equality: protobuf::NullEquality = 
exec.null_equality().into();
+        let filter = exec
+            .filter()
+            .as_ref()
+            .map(|f| {
+                let expression =
+                    serialize_physical_expr(f.expression(), extension_codec)?;
+                let column_indices = f
+                    .column_indices()
+                    .iter()
+                    .map(|i| {
+                        let side: protobuf::JoinSide = 
i.side.to_owned().into();
+                        protobuf::ColumnIndex {
+                            index: i.index as u32,
+                            side: side.into(),
+                        }
+                    })
+                    .collect();
+                let schema = f.schema().as_ref().try_into()?;
+                Ok(protobuf::JoinFilter {
+                    expression: Some(expression),
+                    column_indices,
+                    schema: Some(schema),
+                })
+            })
+            .map_or(Ok(None), |v: Result<protobuf::JoinFilter>| v.map(Some))?;
+
+        let sort_options = exec
+            .sort_options()
+            .iter()
+            .map(
+                |SortOptions {
+                     descending,
+                     nulls_first,
+                 }| {
+                    SortExprNode {
+                        expr: None,
+                        asc: !*descending,
+                        nulls_first: *nulls_first,
+                    }
+                },
+            )
+            .collect();
+
+        Ok(protobuf::PhysicalPlanNode {
+            physical_plan_type: Some(PhysicalPlanType::SortMergeJoin(Box::new(
+                protobuf::SortMergeJoinExecNode {
+                    left: Some(Box::new(left)),
+                    right: Some(Box::new(right)),
+                    on,
+                    join_type: join_type.into(),
+                    null_equality: null_equality.into(),
+                    filter,
+                    sort_options,
+                },
+            ))),
+        })
+    }
+
     fn try_from_cross_join_exec(
         exec: &CrossJoinExec,
         extension_codec: &dyn PhysicalExtensionCodec,
diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
index 1547b7087d..86ad54d3f1 100644
--- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
+++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs
@@ -75,8 +75,8 @@ use datafusion::physical_plan::expressions::{
 };
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::joins::{
-    HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode,
-    SymmetricHashJoinExec,
+    HashJoinExec, NestedLoopJoinExec, PartitionMode, SortMergeJoinExec,
+    StreamJoinPartitionMode, SymmetricHashJoinExec,
 };
 use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
 use datafusion::physical_plan::placeholder_row::PlaceholderRowExec;
@@ -2125,3 +2125,87 @@ async fn analyze_roundtrip_unoptimized() -> Result<()> {
     physical_planner.optimize_physical_plan(unoptimized, &session_state, |_, 
_| {})?;
     Ok(())
 }
+
+#[test]
+fn roundtrip_sort_merge_join() -> Result<()> {
+    let field_a = Field::new("col_a", DataType::Int64, false);
+    let field_b = Field::new("col_b", DataType::Int64, false);
+    let schema_left = Schema::new(vec![field_a.clone()]);
+    let schema_right = Schema::new(vec![field_b.clone()]);
+    let on = vec![(
+        Arc::new(Column::new("col_a", schema_left.index_of("col_a")?)) as _,
+        Arc::new(Column::new("col_b", schema_right.index_of("col_b")?)) as _,
+    )];
+
+    let filter = datafusion::physical_plan::joins::utils::JoinFilter::new(
+        Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("col_a", 1)),
+            Operator::Gt,
+            Arc::new(Column::new("col_b", 0)),
+        )),
+        vec![
+            datafusion::physical_plan::joins::utils::ColumnIndex {
+                index: 0,
+                side: datafusion_common::JoinSide::Left,
+            },
+            datafusion::physical_plan::joins::utils::ColumnIndex {
+                index: 0,
+                side: datafusion_common::JoinSide::Right,
+            },
+        ],
+        Arc::new(Schema::new(vec![field_a, field_b])),
+    );
+
+    let schema_left = Arc::new(schema_left);
+    let schema_right = Arc::new(schema_right);
+    for filter in [None, Some(filter)] {
+        for join_type in [
+            JoinType::Inner,
+            JoinType::Left,
+            JoinType::Right,
+            JoinType::Full,
+            JoinType::LeftAnti,
+            JoinType::RightAnti,
+            JoinType::LeftSemi,
+            JoinType::RightSemi,
+        ] {
+            roundtrip_test(Arc::new(SortMergeJoinExec::try_new(
+                Arc::new(EmptyExec::new(schema_left.clone())),
+                Arc::new(EmptyExec::new(schema_right.clone())),
+                on.clone(),
+                filter.clone(),
+                join_type,
+                vec![Default::default()],
+                NullEquality::NullEqualsNothing,
+            )?))?;
+        }
+    }
+    Ok(())
+}
+
+#[tokio::test]
+async fn roundtrip_logical_plan_sort_merge_join() -> Result<()> {
+    let ctx = SessionContext::new();
+    ctx.register_csv(
+        "t0",
+        "tests/testdata/test.csv",
+        datafusion::prelude::CsvReadOptions::default().has_header(true),
+    )
+    .await?;
+    ctx.register_csv(
+        "t1",
+        "tests/testdata/test.csv",
+        datafusion::prelude::CsvReadOptions::default().has_header(true),
+    )
+    .await?;
+
+    ctx.sql("SET datafusion.optimizer.prefer_hash_join = false")
+        .await?
+        .show()
+        .await?;
+
+    let query = "SELECT t1.* FROM t0 join t1 on t0.a = t1.a";
+    let plan = ctx.sql(query).await?.create_physical_plan().await?;
+
+    roundtrip_test(plan)
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to