This is an automated email from the ASF dual-hosted git repository.
dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 4e2a72f6c7 feat: support for `NestedLoopJoinExec` in datafusion-proto (#6902)
4e2a72f6c7 is described below
commit 4e2a72f6c7109d40a4986e3d05360524be078dd4
Author: r.4ntix <[email protected]>
AuthorDate: Tue Jul 11 15:49:48 2023 +0800
feat: support for `NestedLoopJoinExec` in datafusion-proto (#6902)
---
.../src/physical_plan/joins/nested_loop_join.rs | 20 +++
datafusion/proto/proto/datafusion.proto | 8 ++
datafusion/proto/src/generated/pbjson.rs | 159 +++++++++++++++++++++
datafusion/proto/src/generated/prost.rs | 16 ++-
datafusion/proto/src/physical_plan/mod.rs | 133 ++++++++++++++++-
5 files changed, 333 insertions(+), 3 deletions(-)
diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
index fe7d1a7c69..ed63fb8448 100644
--- a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
+++ b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
@@ -118,6 +118,26 @@ impl NestedLoopJoinExec {
metrics: Default::default(),
})
}
+
+    /// Left input plan of the nested-loop join.
+    pub fn left(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.left
+    }
+
+    /// Right input plan of the nested-loop join.
+    pub fn right(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.right
+    }
+
+    /// Optional join filter applied before rows are output;
+    /// `None` when the join has no filter.
+    pub fn filter(&self) -> Option<&JoinFilter> {
+        self.filter.as_ref()
+    }
+
+    /// The type of join (e.g. inner, left, right, full, semi, anti).
+    pub fn join_type(&self) -> &JoinType {
+        &self.join_type
+    }
}
impl DisplayAs for NestedLoopJoinExec {
diff --git a/datafusion/proto/proto/datafusion.proto
b/datafusion/proto/proto/datafusion.proto
index 528c675570..89bca57cf3 100644
--- a/datafusion/proto/proto/datafusion.proto
+++ b/datafusion/proto/proto/datafusion.proto
@@ -1059,6 +1059,7 @@ message PhysicalPlanNode {
UnionExecNode union = 19;
ExplainExecNode explain = 20;
SortPreservingMergeExecNode sort_preserving_merge = 21;
+ NestedLoopJoinExecNode nested_loop_join = 22;
}
}
@@ -1380,6 +1381,13 @@ message SortPreservingMergeExecNode {
int64 fetch = 3;
}
+// Serialized form of a `NestedLoopJoinExec` physical plan node.
+message NestedLoopJoinExecNode {
+ // Left input plan of the join.
+ PhysicalPlanNode left = 1;
+ // Right input plan of the join.
+ PhysicalPlanNode right = 2;
+ JoinType join_type = 3;
+ // Optional join filter; left unset when the join has no filter.
+ JoinFilter filter = 4;
+}
+
message CoalesceBatchesExecNode {
PhysicalPlanNode input = 1;
uint32 target_batch_size = 2;
diff --git a/datafusion/proto/src/generated/pbjson.rs
b/datafusion/proto/src/generated/pbjson.rs
index d6a770159b..590b462ad8 100644
--- a/datafusion/proto/src/generated/pbjson.rs
+++ b/datafusion/proto/src/generated/pbjson.rs
@@ -12010,6 +12010,151 @@ impl<'de> serde::Deserialize<'de> for NegativeNode {
deserializer.deserialize_struct("datafusion.NegativeNode", FIELDS,
GeneratedVisitor)
}
}
+// serde serialization for `NestedLoopJoinExecNode`.
+// NOTE(review): this file lives under src/generated/, so it is presumably
+// regenerated from datafusion.proto rather than edited by hand — confirm.
+impl serde::Serialize for NestedLoopJoinExecNode {
+ #[allow(deprecated)]
+ fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
+ where
+ S: serde::Serializer,
+ {
+ use serde::ser::SerializeStruct;
+ // Count only the fields that will actually be written: unset messages
+ // and a zero `join_type` are omitted from the output.
+ let mut len = 0;
+ if self.left.is_some() {
+ len += 1;
+ }
+ if self.right.is_some() {
+ len += 1;
+ }
+ if self.join_type != 0 {
+ len += 1;
+ }
+ if self.filter.is_some() {
+ len += 1;
+ }
+ let mut struct_ser =
serializer.serialize_struct("datafusion.NestedLoopJoinExecNode", len)?;
+ if let Some(v) = self.left.as_ref() {
+ struct_ser.serialize_field("left", v)?;
+ }
+ if let Some(v) = self.right.as_ref() {
+ struct_ser.serialize_field("right", v)?;
+ }
+ // `join_type` is stored as a raw i32. A non-zero value is converted back
+ // to `JoinType` before writing, and an unknown value fails serialization.
+ if self.join_type != 0 {
+ let v = JoinType::from_i32(self.join_type)
+ .ok_or_else(|| serde::ser::Error::custom(format!("Invalid
variant {}", self.join_type)))?;
+ struct_ser.serialize_field("joinType", &v)?;
+ }
+ if let Some(v) = self.filter.as_ref() {
+ struct_ser.serialize_field("filter", v)?;
+ }
+ struct_ser.end()
+ }
+}
+// serde deserialization for `NestedLoopJoinExecNode`.
+// NOTE(review): generated file (src/generated/); presumably regenerated from
+// datafusion.proto rather than edited by hand — confirm.
+impl<'de> serde::Deserialize<'de> for NestedLoopJoinExecNode {
+ #[allow(deprecated)]
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ // Accepted field names: each proto field name plus its camelCase form.
+ const FIELDS: &[&str] = &[
+ "left",
+ "right",
+ "join_type",
+ "joinType",
+ "filter",
+ ];
+
+ #[allow(clippy::enum_variant_names)]
+ enum GeneratedField {
+ Left,
+ Right,
+ JoinType,
+ Filter,
+ }
+ impl<'de> serde::Deserialize<'de> for GeneratedField {
+ fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
+ where
+ D: serde::Deserializer<'de>,
+ {
+ struct GeneratedVisitor;
+
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = GeneratedField;
+
+ fn expecting(&self, formatter: &mut
std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(formatter, "expected one of: {:?}", &FIELDS)
+ }
+
+ #[allow(unused_variables)]
+ fn visit_str<E>(self, value: &str) ->
std::result::Result<GeneratedField, E>
+ where
+ E: serde::de::Error,
+ {
+ // Any name not listed in FIELDS is reported as an unknown field.
+ match value {
+ "left" => Ok(GeneratedField::Left),
+ "right" => Ok(GeneratedField::Right),
+ "joinType" | "join_type" =>
Ok(GeneratedField::JoinType),
+ "filter" => Ok(GeneratedField::Filter),
+ _ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
+ }
+ }
+ }
+ deserializer.deserialize_identifier(GeneratedVisitor)
+ }
+ }
+ struct GeneratedVisitor;
+ impl<'de> serde::de::Visitor<'de> for GeneratedVisitor {
+ type Value = NestedLoopJoinExecNode;
+
+ fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) ->
std::fmt::Result {
+ formatter.write_str("struct datafusion.NestedLoopJoinExecNode")
+ }
+
+ fn visit_map<V>(self, mut map: V) ->
std::result::Result<NestedLoopJoinExecNode, V::Error>
+ where
+ V: serde::de::MapAccess<'de>,
+ {
+ let mut left__ = None;
+ let mut right__ = None;
+ let mut join_type__ = None;
+ let mut filter__ = None;
+ // A field that appears twice is an error; absent fields stay `None`.
+ while let Some(k) = map.next_key()? {
+ match k {
+ GeneratedField::Left => {
+ if left__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("left"));
+ }
+ left__ = map.next_value()?;
+ }
+ GeneratedField::Right => {
+ if right__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("right"));
+ }
+ right__ = map.next_value()?;
+ }
+ GeneratedField::JoinType => {
+ if join_type__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("joinType"));
+ }
+ // Parsed as the `JoinType` enum, then stored as its raw i32.
+ join_type__ = Some(map.next_value::<JoinType>()?
as i32);
+ }
+ GeneratedField::Filter => {
+ if filter__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("filter"));
+ }
+ filter__ = map.next_value()?;
+ }
+ }
+ }
+ Ok(NestedLoopJoinExecNode {
+ left: left__,
+ right: right__,
+ // A missing `join_type` falls back to 0 (the i32 default).
+ join_type: join_type__.unwrap_or_default(),
+ filter: filter__,
+ })
+ }
+ }
+ deserializer.deserialize_struct("datafusion.NestedLoopJoinExecNode",
FIELDS, GeneratedVisitor)
+ }
+}
impl serde::Serialize for Not {
#[allow(deprecated)]
fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok,
S::Error>
@@ -15335,6 +15480,9 @@ impl serde::Serialize for PhysicalPlanNode {
physical_plan_node::PhysicalPlanType::SortPreservingMerge(v)
=> {
struct_ser.serialize_field("sortPreservingMerge", v)?;
}
+ physical_plan_node::PhysicalPlanType::NestedLoopJoin(v) => {
+ struct_ser.serialize_field("nestedLoopJoin", v)?;
+ }
}
}
struct_ser.end()
@@ -15376,6 +15524,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
"explain",
"sort_preserving_merge",
"sortPreservingMerge",
+ "nested_loop_join",
+ "nestedLoopJoin",
];
#[allow(clippy::enum_variant_names)]
@@ -15400,6 +15550,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
Union,
Explain,
SortPreservingMerge,
+ NestedLoopJoin,
}
impl<'de> serde::Deserialize<'de> for GeneratedField {
fn deserialize<D>(deserializer: D) ->
std::result::Result<GeneratedField, D::Error>
@@ -15441,6 +15592,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode {
"union" => Ok(GeneratedField::Union),
"explain" => Ok(GeneratedField::Explain),
"sortPreservingMerge" | "sort_preserving_merge" =>
Ok(GeneratedField::SortPreservingMerge),
+ "nestedLoopJoin" | "nested_loop_join" =>
Ok(GeneratedField::NestedLoopJoin),
_ => Err(serde::de::Error::unknown_field(value,
FIELDS)),
}
}
@@ -15601,6 +15753,13 @@ impl<'de> serde::Deserialize<'de> for PhysicalPlanNode
{
return
Err(serde::de::Error::duplicate_field("sortPreservingMerge"));
}
physical_plan_type__ =
map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::SortPreservingMerge)
+;
+ }
+ GeneratedField::NestedLoopJoin => {
+ if physical_plan_type__.is_some() {
+ return
Err(serde::de::Error::duplicate_field("nestedLoopJoin"));
+ }
+ physical_plan_type__ =
map.next_value::<::std::option::Option<_>>()?.map(physical_plan_node::PhysicalPlanType::NestedLoopJoin)
;
}
}
diff --git a/datafusion/proto/src/generated/prost.rs
b/datafusion/proto/src/generated/prost.rs
index 4e91fbab19..251760f090 100644
--- a/datafusion/proto/src/generated/prost.rs
+++ b/datafusion/proto/src/generated/prost.rs
@@ -1394,7 +1394,7 @@ pub mod owned_table_reference {
pub struct PhysicalPlanNode {
#[prost(
oneof = "physical_plan_node::PhysicalPlanType",
- tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21"
+ tags = "1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
19, 20, 21, 22"
)]
pub physical_plan_type:
::core::option::Option<physical_plan_node::PhysicalPlanType>,
}
@@ -1445,6 +1445,8 @@ pub mod physical_plan_node {
SortPreservingMerge(
::prost::alloc::boxed::Box<super::SortPreservingMergeExecNode>,
),
+ #[prost(message, tag = "22")]
+
NestedLoopJoin(::prost::alloc::boxed::Box<super::NestedLoopJoinExecNode>),
}
}
#[allow(clippy::derive_partial_eq_without_eq)]
@@ -1950,6 +1952,18 @@ pub struct SortPreservingMergeExecNode {
}
#[allow(clippy::derive_partial_eq_without_eq)]
#[derive(Clone, PartialEq, ::prost::Message)]
+/// prost message for a serialized `NestedLoopJoinExec` (see `datafusion.proto`).
+pub struct NestedLoopJoinExecNode {
+ /// Left input plan of the join.
+ #[prost(message, optional, boxed, tag = "1")]
+ pub left:
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+ /// Right input plan of the join.
+ #[prost(message, optional, boxed, tag = "2")]
+ pub right:
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
+ /// `JoinType` stored as its raw i32 value, as prost does for enumeration fields.
+ #[prost(enumeration = "JoinType", tag = "3")]
+ pub join_type: i32,
+ /// Optional join filter; `None` when the join has no filter.
+ #[prost(message, optional, tag = "4")]
+ pub filter: ::core::option::Option<JoinFilter>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
pub struct CoalesceBatchesExecNode {
#[prost(message, optional, boxed, tag = "1")]
pub input:
::core::option::Option<::prost::alloc::boxed::Box<PhysicalPlanNode>>,
diff --git a/datafusion/proto/src/physical_plan/mod.rs
b/datafusion/proto/src/physical_plan/mod.rs
index 7bbbe13568..b5b4aeb2da 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -35,7 +35,7 @@ use datafusion::physical_plan::explain::ExplainExec;
use datafusion::physical_plan::expressions::{Column, PhysicalSortExpr};
use datafusion::physical_plan::filter::FilterExec;
use datafusion::physical_plan::joins::utils::{ColumnIndex, JoinFilter};
-use datafusion::physical_plan::joins::CrossJoinExec;
+use datafusion::physical_plan::joins::{CrossJoinExec, NestedLoopJoinExec};
use datafusion::physical_plan::joins::{HashJoinExec, PartitionMode};
use datafusion::physical_plan::limit::{GlobalLimitExec, LocalLimitExec};
use datafusion::physical_plan::projection::ProjectionExec;
@@ -716,6 +716,61 @@ impl AsExecutionPlan for PhysicalPlanNode {
Ok(extension_node)
}
+ PhysicalPlanType::NestedLoopJoin(join) => {
+ let left: Arc<dyn ExecutionPlan> =
+ into_physical_plan(&join.left, registry, runtime,
extension_codec)?;
+ let right: Arc<dyn ExecutionPlan> =
+ into_physical_plan(&join.right, registry, runtime,
extension_codec)?;
+ let join_type =
+ protobuf::JoinType::from_i32(join.join_type).ok_or_else(||
{
+ proto_error(format!(
+ "Received a NestedLoopJoinExecNode message with
unknown JoinType {}",
+ join.join_type
+ ))
+ })?;
+ let filter = join
+ .filter
+ .as_ref()
+ .map(|f| {
+ let schema = f
+ .schema
+ .as_ref()
+ .ok_or_else(|| proto_error("Missing JoinFilter
schema"))?
+ .try_into()?;
+
+ let expression = parse_physical_expr(
+ f.expression.as_ref().ok_or_else(|| {
+ proto_error("Unexpected empty filter
expression")
+ })?,
+ registry, &schema
+ )?;
+ let column_indices = f.column_indices
+ .iter()
+ .map(|i| {
+ let side = protobuf::JoinSide::from_i32(i.side)
+ .ok_or_else(|| proto_error(format!(
+ "Received a NestedLoopJoinExecNode
message with JoinSide in Filter {}",
+ i.side))
+ )?;
+
+ Ok(ColumnIndex{
+ index: i.index as usize,
+ side: side.into(),
+ })
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ Ok(JoinFilter::new(expression, column_indices, schema))
+ })
+ .map_or(Ok(None), |v: Result<JoinFilter>| v.map(Some))?;
+
+ Ok(Arc::new(NestedLoopJoinExec::try_new(
+ left,
+ right,
+ filter,
+ &join_type.into(),
+ )?))
+ }
}
}
@@ -1155,6 +1210,52 @@ impl AsExecutionPlan for PhysicalPlanNode {
}),
)),
})
+ } else if let Some(exec) = plan.downcast_ref::<NestedLoopJoinExec>() {
+ let left = protobuf::PhysicalPlanNode::try_from_physical_plan(
+ exec.left().to_owned(),
+ extension_codec,
+ )?;
+ let right = protobuf::PhysicalPlanNode::try_from_physical_plan(
+ exec.right().to_owned(),
+ extension_codec,
+ )?;
+
+ let join_type: protobuf::JoinType =
exec.join_type().to_owned().into();
+ let filter = exec
+ .filter()
+ .as_ref()
+ .map(|f| {
+ let expression = f.expression().to_owned().try_into()?;
+ let column_indices = f
+ .column_indices()
+ .iter()
+ .map(|i| {
+ let side: protobuf::JoinSide =
i.side.to_owned().into();
+ protobuf::ColumnIndex {
+ index: i.index as u32,
+ side: side.into(),
+ }
+ })
+ .collect();
+ let schema = f.schema().try_into()?;
+ Ok(protobuf::JoinFilter {
+ expression: Some(expression),
+ column_indices,
+ schema: Some(schema),
+ })
+ })
+ .map_or(Ok(None), |v: Result<protobuf::JoinFilter>|
v.map(Some))?;
+
+ Ok(protobuf::PhysicalPlanNode {
+ physical_plan_type:
Some(PhysicalPlanType::NestedLoopJoin(Box::new(
+ protobuf::NestedLoopJoinExecNode {
+ left: Some(Box::new(left)),
+ right: Some(Box::new(right)),
+ join_type: join_type.into(),
+ filter,
+ },
+ ))),
+ })
} else {
let mut buf: Vec<u8> = vec![];
match extension_codec.try_encode(plan_clone.clone(), &mut buf) {
@@ -1297,7 +1398,7 @@ mod roundtrip_tests {
expressions::{binary, col, lit, NotExpr},
expressions::{Avg, Column, DistinctCount, PhysicalSortExpr},
filter::FilterExec,
- joins::{HashJoinExec, PartitionMode},
+ joins::{HashJoinExec, NestedLoopJoinExec, PartitionMode},
limit::{GlobalLimitExec, LocalLimitExec},
sorts::sort::SortExec,
AggregateExpr, ExecutionPlan, PhysicalExpr, Statistics,
@@ -1433,6 +1534,34 @@ mod roundtrip_tests {
Ok(())
}
+    #[test]
+    fn roundtrip_nested_loop_join() -> Result<()> {
+        // Both join inputs are empty plans over the same single Int64 column.
+        let column = Field::new("col", DataType::Int64, false);
+        let left_schema = Arc::new(Schema::new(vec![column.clone()]));
+        let right_schema = Arc::new(Schema::new(vec![column]));
+
+        // Every join type must survive a serialize/deserialize round trip
+        // (no join filter is exercised here).
+        let join_types = [
+            JoinType::Inner,
+            JoinType::Left,
+            JoinType::Right,
+            JoinType::Full,
+            JoinType::LeftAnti,
+            JoinType::RightAnti,
+            JoinType::LeftSemi,
+            JoinType::RightSemi,
+        ];
+        for join_type in join_types.iter() {
+            let left: Arc<dyn ExecutionPlan> =
+                Arc::new(EmptyExec::new(false, Arc::clone(&left_schema)));
+            let right: Arc<dyn ExecutionPlan> =
+                Arc::new(EmptyExec::new(false, Arc::clone(&right_schema)));
+            let join = NestedLoopJoinExec::try_new(left, right, None, join_type)?;
+            roundtrip_test(Arc::new(join))?;
+        }
+        Ok(())
+    }
+
#[test]
fn rountrip_aggregate() -> Result<()> {
let field_a = Field::new("a", DataType::Int64, false);