This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e2a72f6c7 feat: support for `NestedLoopJoinExec` in datafusion-proto (#6902)
4e2a72f6c7 is described below

commit 4e2a72f6c7109d40a4986e3d05360524be078dd4
Author: r.4ntix <[email protected]>
AuthorDate: Tue Jul 11 15:49:48 2023 +0800

    feat: support for `NestedLoopJoinExec` in datafusion-proto (#6902)
---
 .../src/physical_plan/joins/nested_loop_join.rs    |  20 +++
 datafusion/proto/proto/datafusion.proto            |   8 ++
 datafusion/proto/src/generated/pbjson.rs           | 159 +++++++++++++++++++++
 datafusion/proto/src/generated/prost.rs            |  16 ++-
 datafusion/proto/src/physical_plan/mod.rs          | 133 ++++++++++++++++-
 5 files changed, 333 insertions(+), 3 deletions(-)

diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
index fe7d1a7c69..ed63fb8448 100644
--- a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
+++ b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
@@ -118,6 +118,26 @@ impl NestedLoopJoinExec {
             metrics: Default::default(),
         })
     }
+
+    /// Left input plan of the join
+    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.left
+    }
+
+    /// Right input plan of the join
+    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.right
+    }
+
+    /// Filters applied before join output
+    pub fn filter(&self) -> Option<&JoinFilter> {
+        self.filter.as_ref()
+    }
+
+    /// How the join is performed
+    pub fn join_type(&self) -> &JoinType {
+        &self.join_type
+    }
 }
 
 impl DisplayAs for NestedLoopJoinExec {
diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto
index 528c675570..89bca57cf3 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1059,6 +1059,7 @@ message PhysicalPlanNode {
     UnionExecNode union = 19;
     ExplainExecNode explain = 20;
     SortPreservingMergeExecNode sort_preserving_merge = 21;
+    NestedLoopJoinExecNode nested_loop_join = 22;
   }
 }
 
@@ -1380,6 +1381,13 @@ message SortPreservingMergeExecNode {
   int64 fetch = 3;
 }
 
+message NestedLoopJoinExecNode {
+  PhysicalPlanNode left = 1;
+  PhysicalPlanNode right = 2;
+  JoinType join_type = 3;
+  JoinFilter filter = 4;
+}
+
 message CoalesceBatchesExecNode {
   PhysicalPlanNode input = 1;
   uint32 target_batch_size = 2;
diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs
index d6a770159b..590b462ad8 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -12010,6 +12010,151 @@ impl<'de> serde::Deserialize<'de> for NegativeNode {
         deserializer.deserialize_struct("datafusion.NegativeNode", FIELDS, 
GeneratedVisitor)
     }
 }
+impl serde::Serialize for NestedLoopJoinExecNode {
+    #[allow(deprecated)]
+    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
+    where
+        S: serde::Serializer,
+    {
+        use serde::ser::SerializeStruct;
+        let mut len = 0;
+        if self.left.is_some() {
+            len += 1;
+        }
+        if self.right.is_some() {
+            len += 1;
+        }
+        if self.join_type != 0 {
+            len += 1;
+        }
+        if self.filter.is_some() {
+            len += 1;
+        }
+        let mut struct_ser = 
serializer.serialize_struct("datafusion.NestedLoopJoinExecNode", len)?;
+        if let Some(v) = self.left.as_ref() {
+            struct_ser.serialize_field("left", v)?;
+        }
+        if let Some(v) = self.right.as_ref() {
+            struct_ser.serialize_field("right", v)?;
+        }
+        if self.join_type != 0 {
+            let v = JoinType::from_i32(self.join_type)
+                .ok_or_else(|| serde::ser::Error::custom(format!("Invalid 
variant {}", self.join_type)))?;
+            struct_ser.serialize_field("joinType", &v)?;
+        }
+        if let Some(v) = self.filter.as_ref() {
+            struct_ser.serialize_field("filter", v)?;
+        }
+        struct_ser.end()
+    }
+}
+impl<'de> serde::Deserialize<'de> for NestedLoopJoinExecNode {
+    #[allow(deprecated)]
+    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        const FIELDS: &[&str] = &[
+            "left",
+            "right",
+            "join_type",
+            "joinType",
+            "filter",
+        ];
+
+        #[allow(clippy::enum_variant_names)]
+        enum GeneratedField {
+            Left,
+            Right,
+            JoinType,
+            Filter,
+        }
+        impl<'de> serde::Deserialize<'de> for GeneratedField {
+            fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
+            where
+                D: serde::Deserializer<'de>,
+            {
+                struct GeneratedVisitor;
+
+                impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+                    type Value = GeneratedField;
+
+                    fn expecting(&self, formatter: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result {
+                        write!(formatter, "expected one of: {:?}", &FIELDS)
+                    }
+
+                    #[allow(unused_variables)]
+                    fn visit_str<E>(self, value: &str) -> 
std::result::Result<GeneratedField, E>
+                    where
+                        E: serde::de::Error,
+                    {
+                        match value {
+                            "left" => Ok(GeneratedField::Left),
+                            "right" => Ok(GeneratedField::Right),
+                            "joinType" | "join_type" => 
Ok(GeneratedField::JoinType),
+                            "filter" => Ok(GeneratedField::Filter),
+                            _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
+                        }
+                    }
+                }
+                deserializer.deserialize_identifier(GeneratedVisitor)
+            }
+        }
+        struct GeneratedVisitor;
+        impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+            type Value = NestedLoopJoinExecNode;
+
+            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result {
+                formatter.write_str("struct datafusion.NestedLoopJoinExecNode")
+            }
+
+            fn visit_map<V>(self, mut map: V) -> 
std::result::Result<NestedLoopJoinExecNode, V::Error>
+                where
+                    V: serde::de::MapAccess<'de>,
+            {
+                let mut left__ = None;
+                let mut right__ = None;
+                let mut join_type__ = None;
+                let mut filter__ = None;
+                while let Some(k) = map.next_key()? {
+                    match k {
+                        GeneratedField::Left => {
+                            if left__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("left"));
+                            }
+                            left__ = map.next_value()?;
+                        }
+                        GeneratedField::Right => {
+                            if right__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("right"));
+                            }
+                            right__ = map.next_value()?;
+                        }
+                        GeneratedField::JoinType => {
+                            if join_type__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("joinType"));
+                            }
+                            join_type__ = Some(map.next_value::<JoinType>()? 
as i32);
+                        }
+                        GeneratedField::Filter => {
+                            if filter__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("filter"));
+                            }
+                            filter__ = map.next_value()?;
+                        }
+                    }
+                }
+                Ok(NestedLoopJoinExecNode {
+                    left: left__,
+                    right: right__,
+                    join_type: join_type__.unwrap_or_default(),
+                    filter: filter__,
+                })
+            }
+        }
+        deserializer.deserialize_struct("datafusion.NestedLoopJoinExecNode", 
FIELDS, GeneratedVisitor)
+    }
+}
 impl serde::Serialize for Not {
     #[allow(deprecated)]
     fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, 
S::Error>
@@ -15335,6 +15480,9 @@ impl serde::Serialize for PhysicalPlanNode {
                 physical_plan_node::PhysicalPlanType::SortPreservingMerge(v) 
=> {
                     struct_ser.serialize_field("sortPreservingMerge", v)?;
                 }
+                physical_plan_node::PhysicalPlanType::NestedLoopJoin(v) => {
+                    struct_ser.serialize_field("nestedLoopJoin", v)?;
+                }
             }
         }
         struct_ser.end()
@@ -15376,6 +15524,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
             "explain",
             "sort_preserving_merge",
             "sortPreservingMerge",
+            "nested_loop_join",
+            "nestedLoopJoin",
         ];
 
         #[allow(clippy::enum_variant_names)]
@@ -15400,6 +15550,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
             Union,
             Explain,
             SortPreservingMerge,
+            NestedLoopJoin,
         }
         impl<'de> serde::Deserialize<'de> for GeneratedField {
             fn deserialize<D>(deserializer: D) -> 
std::result::Result<GeneratedField, D::Error>
@@ -15441,6 +15592,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
                             "union" => Ok(GeneratedField::Union),
                             "explain" => Ok(GeneratedField::Explain),
                             "sortPreservingMerge" | "sort_preserving_merge" => 
Ok(GeneratedField::SortPreservingMerge),
+                            "nestedLoopJoin" | "nested_loop_join" => 
Ok(GeneratedField::NestedLoopJoin),
                             _ => Err(serde::de::Error::unknown_field(value, 
FIELDS)),
                         }
                     }
@@ -15601,6 +15753,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
                                 return 
Err(serde::de::Error::duplicate_field("sortPreservingMerge"));
                             }
                             physical_plan_type__ = 
map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SortPreservingMerge)
+;
+                        }
+                        GeneratedField::NestedLoopJoin => {
+                            if physical_plan_type__.is_some() {
+                                return 
Err(serde::de::Error::duplicate_field("nestedLoopJoin"));
+                            }
+                            physical_plan_type__ = 
map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::NestedLoopJoin)
 ;
                         }
                     }
diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs
index 4e91fbab19..251760f090 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1394,7 +1394,7 @@ pub mod owned_table_reference {
 pub struct PhysicalPlanNode {
     #[prost(
         oneof = "physical_plan_node::PhysicalPlanType",
-        tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 
19, 20, 21"
+        tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 
19, 20, 21, 22"
     )]
     pub physical_plan_type: 
::core::option::Option<physical_plan_node::PhysicalPlanType>,
 }
@@ -1445,6 +1445,8 @@ pub mod physical_plan_node {
         SortPreservingMerge(
             ::prost::alloc::boxed::Box<super::SortPreservingMergeExecNode>,
         ),
+        #[prost(message, tag = "22")]
+        
NestedLoopJoin(::prost::alloc::boxed::Box<super::NestedLoopJoinExecNode>),
     }
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
@@ -1950,6 +1952,18 @@ pub struct SortPreservingMergeExecNode {
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
+pub struct NestedLoopJoinExecNode {
+    #[prost(message, optional, boxed, tag = "1")]
+    pub left: 
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+    #[prost(message, optional, boxed, tag = "2")]
+    pub right: 
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+    #[prost(enumeration = "JoinType", tag = "3")]
+    pub join_type: i32,
+    #[prost(message, optional, tag = "4")]
+    pub filter: ::core::option::Option<JoinFilter>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
 pub struct CoalesceBatchesExecNode {
     #[prost(message, optional, boxed, tag = "1")]
     pub input: 
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 7bbbe13568..b5b4aeb2da 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -35,7 +35,7 @@ use datafusion::physical_plan::explain::ExplainExec;
 use datafusion::physical_plan::expressions::{Column, PhysicalSortExpr};
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
-use datafusion::physical_plan::joins::CrossJoinExec;
+use datafusion::physical_plan::joins::{CrossJoinExec, NestedLoopJoinExec};
 use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
 use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
 use datafusion::physical_plan::projection::ProjectionExec;
@@ -716,6 +716,61 @@ impl AsExecutionPlan for PhysicalPlanNode {
 
                 Ok(extension_node)
             }
+            PhysicalPlanType::NestedLoopJoin(join) => {
+                let left: Arc<dyn ExecutionPlan> =
+                    into_physical_plan(&join.left, registry, runtime, 
extension_codec)?;
+                let right: Arc<dyn ExecutionPlan> =
+                    into_physical_plan(&join.right, registry, runtime, 
extension_codec)?;
+                let join_type =
+                    protobuf::JoinType::from_i32(join.join_type).ok_or_else(|| 
{
+                        proto_error(format!(
+                            "Received a NestedLoopJoinExecNode message with 
unknown JoinType {}",
+                            join.join_type
+                        ))
+                    })?;
+                let filter = join
+                    .filter
+                    .as_ref()
+                    .map(|f| {
+                        let schema = f
+                            .schema
+                            .as_ref()
+                            .ok_or_else(|| proto_error("Missing JoinFilter 
schema"))?
+                            .try_into()?;
+
+                        let expression = parse_physical_expr(
+                            f.expression.as_ref().ok_or_else(|| {
+                                proto_error("Unexpected empty filter 
expression")
+                            })?,
+                            registry, &schema
+                        )?;
+                        let column_indices = f.column_indices
+                            .iter()
+                            .map(|i| {
+                                let side = protobuf::JoinSide::from_i32(i.side)
+                                    .ok_or_else(|| proto_error(format!(
+                                        "Received a NestedLoopJoinExecNode 
message with JoinSide in Filter {}",
+                                        i.side))
+                                    )?;
+
+                                Ok(ColumnIndex{
+                                    index: i.index as usize,
+                                    side: side.into(),
+                                })
+                            })
+                            .collect::<Result<Vec<_>>>()?;
+
+                        Ok(JoinFilter::new(expression, column_indices, schema))
+                    })
+                    .map_or(Ok(None), |v: Result<JoinFilter>| v.map(Some))?;
+
+                Ok(Arc::new(NestedLoopJoinExec::try_new(
+                    left,
+                    right,
+                    filter,
+                    &join_type.into(),
+                )?))
+            }
         }
     }
 
@@ -1155,6 +1210,52 @@ impl AsExecutionPlan for PhysicalPlanNode {
                     }),
                 )),
             })
+        } else if let Some(exec) = plan.downcast_ref::<NestedLoopJoinExec>() {
+            let left = protobuf::PhysicalPlanNode::try_from_physical_plan(
+                exec.left().to_owned(),
+                extension_codec,
+            )?;
+            let right = protobuf::PhysicalPlanNode::try_from_physical_plan(
+                exec.right().to_owned(),
+                extension_codec,
+            )?;
+
+            let join_type: protobuf::JoinType = 
exec.join_type().to_owned().into();
+            let filter = exec
+                .filter()
+                .as_ref()
+                .map(|f| {
+                    let expression = f.expression().to_owned().try_into()?;
+                    let column_indices = f
+                        .column_indices()
+                        .iter()
+                        .map(|i| {
+                            let side: protobuf::JoinSide = 
i.side.to_owned().into();
+                            protobuf::ColumnIndex {
+                                index: i.index as u32,
+                                side: side.into(),
+                            }
+                        })
+                        .collect();
+                    let schema = f.schema().try_into()?;
+                    Ok(protobuf::JoinFilter {
+                        expression: Some(expression),
+                        column_indices,
+                        schema: Some(schema),
+                    })
+                })
+                .map_or(Ok(None), |v: Result<protobuf::JoinFilter>| 
v.map(Some))?;
+
+            Ok(protobuf::PhysicalPlanNode {
+                physical_plan_type: 
Some(PhysicalPlanType::NestedLoopJoin(Box::new(
+                    protobuf::NestedLoopJoinExecNode {
+                        left: Some(Box::new(left)),
+                        right: Some(Box::new(right)),
+                        join_type: join_type.into(),
+                        filter,
+                    },
+                ))),
+            })
         } else {
             let mut buf: Vec<u8> = vec![];
             match extension_codec.try_encode(plan_clone.clone(), &mut buf) {
@@ -1297,7 +1398,7 @@ mod roundtrip_tests {
             expressions::{binary, col, lit, NotExpr},
             expressions::{Avg, Column, DistinctCount, PhysicalSortExpr},
             filter::FilterExec,
-            joins::{HashJoinExec, PartitionMode},
+            joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode},
             limit::{GlobalLimitExec, LocalLimitExec},
             sorts::sort::SortExec,
             AggregateExpr, ExecutionPlan, PhysicalExpr, Statistics,
@@ -1433,6 +1534,34 @@ mod roundtrip_tests {
         Ok(())
     }
 
+    #[test]
+    fn roundtrip_nested_loop_join() -> Result<()> {
+        let field_a = Field::new("col", DataType::Int64, false);
+        let schema_left = Schema::new(vec![field_a.clone()]);
+        let schema_right = Schema::new(vec![field_a]);
+
+        let schema_left = Arc::new(schema_left);
+        let schema_right = Arc::new(schema_right);
+        for join_type in &[
+            JoinType::Inner,
+            JoinType::Left,
+            JoinType::Right,
+            JoinType::Full,
+            JoinType::LeftAnti,
+            JoinType::RightAnti,
+            JoinType::LeftSemi,
+            JoinType::RightSemi,
+        ] {
+            roundtrip_test(Arc::new(NestedLoopJoinExec::try_new(
+                Arc::new(EmptyExec::new(false, schema_left.clone())),
+                Arc::new(EmptyExec::new(false, schema_right.clone())),
+                None,
+                join_type,
+            )?))?;
+        }
+        Ok(())
+    }
+
     #[test]
     fn rountrip_aggregate() -> Result<()> {
         let field_a = Field::new("a", DataType::Int64, false);

Reply via email to