This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new f8405859 add partitioning scheme for unresolved shuffle and shuffle reader exec (#1144)
f8405859 is described below

commit f840585948c31e841d0464b9bf14d1276b2f27b6
Author: Onur Satici <[email protected]>
AuthorDate: Thu Dec 19 14:57:06 2024 +0000

    add partitioning scheme for unresolved shuffle and shuffle reader exec (#1144)
    
    * add partitioning scheme for unresolved shuffle and shuffle reader exec
    
    * make default_codec a property of BallistaPhysicalExtensionCodec
    
    * tests
---
 ballista/core/proto/ballista.proto                 |   3 +-
 .../core/src/execution_plans/shuffle_reader.rs     |   7 +-
 .../core/src/execution_plans/unresolved_shuffle.rs |  18 ++-
 ballista/core/src/serde/generated/ballista.rs      |   6 +-
 ballista/core/src/serde/mod.rs                     | 146 ++++++++++++++++++---
 ballista/scheduler/src/planner.rs                  |  31 +++--
 6 files changed, 171 insertions(+), 40 deletions(-)
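
For context, here is a minimal sketch of how the changed constructor is called
after this commit. This is hypothetical caller code, not part of the patch:
the schema, column name, stage id, and partition count are illustrative only.

    // Hypothetical usage sketch (not part of the commit). Assumes the
    // ballista-core and datafusion APIs shown in the diff below.
    use std::sync::Arc;

    use ballista_core::execution_plans::UnresolvedShuffleExec;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::physical_plan::expressions::Column;
    use datafusion::physical_plan::{ExecutionPlan, Partitioning};

    fn main() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "l_returnflag",
            DataType::Utf8,
            false,
        )]));

        // Previously the constructor took only an output partition count and
        // the plan reported Partitioning::UnknownPartitioning(count); now the
        // full partitioning scheme is carried through the plan.
        let exec = UnresolvedShuffleExec::new(
            1, // stage_id (illustrative)
            schema,
            Partitioning::Hash(vec![Arc::new(Column::new("l_returnflag", 0))], 2),
        );
        assert_eq!(exec.properties().output_partitioning().partition_count(), 2);
    }

The same Partitioning value now round-trips through the protobuf codec, which
the new tests in ballista/core/src/serde/mod.rs exercise.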

diff --git a/ballista/core/proto/ballista.proto b/ballista/core/proto/ballista.proto
index cb3c148b..70e8bba5 100644
--- a/ballista/core/proto/ballista.proto
+++ b/ballista/core/proto/ballista.proto
@@ -50,7 +50,7 @@ message ShuffleWriterExecNode {
 message UnresolvedShuffleExecNode {
   uint32 stage_id = 1;
   datafusion_common.Schema schema = 2;
-  uint32 output_partition_count = 4;
+  datafusion.Partitioning partitioning = 5;
 }
 
 message ShuffleReaderExecNode {
@@ -58,6 +58,7 @@ message ShuffleReaderExecNode {
   datafusion_common.Schema schema = 2;
   // The stage to read from
   uint32 stage_id = 3;
+  datafusion.Partitioning partitioning = 4;
 }
 
 message ShuffleReaderPartition {
diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs
index 4a2c25b8..f50d6a29 100644
--- a/ballista/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/core/src/execution_plans/shuffle_reader.rs
@@ -74,12 +74,11 @@ impl ShuffleReaderExec {
         stage_id: usize,
         partition: Vec<Vec<PartitionLocation>>,
         schema: SchemaRef,
+        partitioning: Partitioning,
     ) -> Result<Self> {
         let properties = PlanProperties::new(
             datafusion::physical_expr::EquivalenceProperties::new(schema.clone()),
-            // TODO partitioning may be known and could be populated here
-            // see https://github.com/apache/arrow-datafusion/issues/758
-            Partitioning::UnknownPartitioning(partition.len()),
+            partitioning,
             datafusion::physical_plan::ExecutionMode::Bounded,
         );
         Ok(Self {
@@ -134,6 +133,7 @@ impl ExecutionPlan for ShuffleReaderExec {
             self.stage_id,
             self.partition.clone(),
             self.schema.clone(),
+            self.properties().output_partitioning().clone(),
         )?))
     }
 
@@ -553,6 +553,7 @@ mod tests {
             input_stage_id,
             vec![partitions],
             Arc::new(schema),
+            Partitioning::UnknownPartitioning(4),
         )?;
         let mut stream = shuffle_reader_exec.execute(0, task_ctx)?;
         let batches = utils::collect_stream(&mut stream).await;
diff --git a/ballista/core/src/execution_plans/unresolved_shuffle.rs b/ballista/core/src/execution_plans/unresolved_shuffle.rs
index e227e2ac..9d4d3077 100644
--- a/ballista/core/src/execution_plans/unresolved_shuffle.rs
+++ b/ballista/core/src/execution_plans/unresolved_shuffle.rs
@@ -46,22 +46,16 @@ pub struct UnresolvedShuffleExec {
 
 impl UnresolvedShuffleExec {
     /// Create a new UnresolvedShuffleExec
-    pub fn new(
-        stage_id: usize,
-        schema: SchemaRef,
-        output_partition_count: usize,
-    ) -> Self {
+    pub fn new(stage_id: usize, schema: SchemaRef, partitioning: Partitioning) -> Self {
         let properties = PlanProperties::new(
             datafusion::physical_expr::EquivalenceProperties::new(schema.clone()),
-            // TODO the output partition is known and should be populated here!
-            // see https://github.com/apache/arrow-datafusion/issues/758
-            Partitioning::UnknownPartitioning(output_partition_count),
+            partitioning,
             datafusion::physical_plan::ExecutionMode::Bounded,
         );
         Self {
             stage_id,
             schema,
-            output_partition_count,
+            output_partition_count: properties.partitioning.partition_count(),
             properties,
         }
     }
@@ -75,7 +69,11 @@ impl DisplayAs for UnresolvedShuffleExec {
     ) -> std::fmt::Result {
         match t {
             DisplayFormatType::Default | DisplayFormatType::Verbose => {
-                write!(f, "UnresolvedShuffleExec")
+                write!(
+                    f,
+                    "UnresolvedShuffleExec: {:?}",
+                    self.properties().output_partitioning()
+                )
             }
         }
     }
diff --git a/ballista/core/src/serde/generated/ballista.rs b/ballista/core/src/serde/generated/ballista.rs
index d4faef82..ed73ab40 100644
--- a/ballista/core/src/serde/generated/ballista.rs
+++ b/ballista/core/src/serde/generated/ballista.rs
@@ -42,8 +42,8 @@ pub struct UnresolvedShuffleExecNode {
     pub stage_id: u32,
     #[prost(message, optional, tag = "2")]
     pub schema: ::core::option::Option<::datafusion_proto_common::Schema>,
-    #[prost(uint32, tag = "4")]
-    pub output_partition_count: u32,
+    #[prost(message, optional, tag = "5")]
+    pub partitioning: ::core::option::Option<::datafusion_proto::protobuf::Partitioning>,
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct ShuffleReaderExecNode {
@@ -54,6 +54,8 @@ pub struct ShuffleReaderExecNode {
     /// The stage to read from
     #[prost(uint32, tag = "3")]
     pub stage_id: u32,
+    #[prost(message, optional, tag = "4")]
+    pub partitioning: ::core::option::Option<::datafusion_proto::protobuf::Partitioning>,
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct ShuffleReaderPartition {
diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs
index d7d6474f..84cf8068 100644
--- a/ballista/core/src/serde/mod.rs
+++ b/ballista/core/src/serde/mod.rs
@@ -21,6 +21,7 @@
 use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction};
 
 use arrow_flight::sql::ProstMessageExt;
+use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::common::{DataFusionError, Result};
 use datafusion::execution::FunctionRegistry;
 use datafusion::physical_plan::{ExecutionPlan, Partitioning};
@@ -29,6 +30,9 @@ use datafusion_proto::logical_plan::file_formats::{
     JsonLogicalExtensionCodec, ParquetLogicalExtensionCodec,
 };
 use datafusion_proto::physical_plan::from_proto::parse_protobuf_hash_partitioning;
+use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning;
+use datafusion_proto::physical_plan::to_proto::serialize_partitioning;
+use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
 use datafusion_proto::protobuf::proto_error;
 use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
 use datafusion_proto::{
@@ -244,8 +248,18 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec {
     }
 }
 
-#[derive(Debug, Default)]
-pub struct BallistaPhysicalExtensionCodec {}
+#[derive(Debug)]
+pub struct BallistaPhysicalExtensionCodec {
+    default_codec: Arc<dyn PhysicalExtensionCodec>,
+}
+
+impl Default for BallistaPhysicalExtensionCodec {
+    fn default() -> Self {
+        Self {
+            default_codec: Arc::new(DefaultPhysicalExtensionCodec {}),
+        }
+    }
+}
 
 impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
     fn try_decode(
@@ -272,14 +286,11 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
             PhysicalPlanType::ShuffleWriter(shuffle_writer) => {
                 let input = inputs[0].clone();
 
-                let default_codec =
-                    datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec {};
-
                 let shuffle_output_partitioning = parse_protobuf_hash_partitioning(
                     shuffle_writer.output_partitioning.as_ref(),
                     registry,
                     input.schema().as_ref(),
-                    &default_codec,
+                    self.default_codec.as_ref(),
                 )?;
 
                 Ok(Arc::new(ShuffleWriterExec::try_new(
@@ -292,7 +303,8 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
             }
             PhysicalPlanType::ShuffleReader(shuffle_reader) => {
                 let stage_id = shuffle_reader.stage_id as usize;
-                let schema = Arc::new(convert_required!(shuffle_reader.schema)?);
+                let schema: SchemaRef =
+                    Arc::new(convert_required!(shuffle_reader.schema)?);
                 let partition_location: Vec<Vec<PartitionLocation>> = shuffle_reader
                     .partition
                     .iter()
@@ -309,16 +321,37 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
                             .collect::<Result<Vec<_>, _>>()
                     })
                     .collect::<Result<Vec<_>, DataFusionError>>()?;
-                let shuffle_reader =
-                    ShuffleReaderExec::try_new(stage_id, partition_location, schema)?;
+                let partitioning = parse_protobuf_partitioning(
+                    shuffle_reader.partitioning.as_ref(),
+                    registry,
+                    schema.as_ref(),
+                    self.default_codec.as_ref(),
+                )?;
+                let partitioning = partitioning
+                    .ok_or_else(|| proto_error("missing required partitioning field"))?;
+                let shuffle_reader = ShuffleReaderExec::try_new(
+                    stage_id,
+                    partition_location,
+                    schema,
+                    partitioning,
+                )?;
                 Ok(Arc::new(shuffle_reader))
             }
             PhysicalPlanType::UnresolvedShuffle(unresolved_shuffle) => {
-                let schema = Arc::new(convert_required!(unresolved_shuffle.schema)?);
+                let schema: SchemaRef =
+                    Arc::new(convert_required!(unresolved_shuffle.schema)?);
+                let partitioning = parse_protobuf_partitioning(
+                    unresolved_shuffle.partitioning.as_ref(),
+                    registry,
+                    schema.as_ref(),
+                    self.default_codec.as_ref(),
+                )?;
+                let partitioning = partitioning
+                    .ok_or_else(|| proto_error("missing required partitioning field"))?;
                 Ok(Arc::new(UnresolvedShuffleExec::new(
                     unresolved_shuffle.stage_id as usize,
                     schema,
-                    unresolved_shuffle.output_partition_count as usize,
+                    partitioning,
                 )))
             }
         }
@@ -334,12 +367,10 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
             // to get the true output partitioning
             let output_partitioning = match exec.shuffle_output_partitioning() {
                 Some(Partitioning::Hash(exprs, partition_count)) => {
-                    let default_codec =
-                        datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec {};
                     Some(datafusion_proto::protobuf::PhysicalHashRepartition {
                         hash_expr: exprs
                             .iter()
-                            .map(|expr|datafusion_proto::physical_plan::to_proto::serialize_physical_expr(&expr.clone(), &default_codec))
+                            .map(|expr|datafusion_proto::physical_plan::to_proto::serialize_physical_expr(&expr.clone(), self.default_codec.as_ref()))
                             .collect::<Result<Vec<_>, DataFusionError>>()?,
                         partition_count: *partition_count as u64,
                     })
@@ -387,12 +418,17 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
                         .collect::<Result<Vec<_>, _>>()?,
                 });
             }
+            let partitioning = serialize_partitioning(
+                &exec.properties().partitioning,
+                self.default_codec.as_ref(),
+            )?;
             let proto = protobuf::BallistaPhysicalPlanNode {
                 physical_plan_type: Some(PhysicalPlanType::ShuffleReader(
                     protobuf::ShuffleReaderExecNode {
                         stage_id,
                         partition,
                         schema: Some(exec.schema().as_ref().try_into()?),
+                        partitioning: Some(partitioning),
                     },
                 )),
             };
@@ -404,12 +440,16 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec {
 
             Ok(())
         } else if let Some(exec) = node.as_any().downcast_ref::<UnresolvedShuffleExec>() {
+            let partitioning = serialize_partitioning(
+                &exec.properties().partitioning,
+                self.default_codec.as_ref(),
+            )?;
             let proto = protobuf::BallistaPhysicalPlanNode {
                 physical_plan_type: Some(PhysicalPlanType::UnresolvedShuffle(
                     protobuf::UnresolvedShuffleExecNode {
                         stage_id: exec.stage_id as u32,
                         schema: Some(exec.schema().as_ref().try_into()?),
-                        output_partition_count: exec.output_partition_count as u32,
+                        partitioning: Some(partitioning),
                     },
                 )),
             };
@@ -449,6 +489,11 @@ struct FileFormatProto {
 
 #[cfg(test)]
 mod test {
+    use super::*;
+    use datafusion::arrow::datatypes::{DataType, Field, Schema};
+    use datafusion::execution::registry::MemoryFunctionRegistry;
+    use datafusion::physical_plan::expressions::col;
+    use datafusion::physical_plan::Partitioning;
     use datafusion::{
         common::DFSchema,
         datasource::file_format::{parquet::ParquetFormatFactory, DefaultFileType},
@@ -493,4 +538,75 @@ mod test {
         assert_eq!(o.to_string(), d.to_string())
         //logical_plan.
     }
+
+    fn create_test_schema() -> SchemaRef {
+        Arc::new(Schema::new(vec![
+            Field::new("id", DataType::Int32, false),
+            Field::new("name", DataType::Utf8, false),
+        ]))
+    }
+
+    #[tokio::test]
+    async fn test_unresolved_shuffle_exec_roundtrip() {
+        let schema = create_test_schema();
+        let partitioning =
+            Partitioning::Hash(vec![col("id", schema.as_ref()).unwrap()], 4);
+
+        let original_exec = UnresolvedShuffleExec::new(
+            1, // stage_id
+            schema.clone(),
+            partitioning.clone(),
+        );
+
+        let codec = BallistaPhysicalExtensionCodec::default();
+        let mut buf: Vec<u8> = vec![];
+        codec
+            .try_encode(Arc::new(original_exec.clone()), &mut buf)
+            .unwrap();
+
+        let registry = MemoryFunctionRegistry::new();
+        let decoded_plan = codec.try_decode(&buf, &[], &registry).unwrap();
+
+        let decoded_exec = decoded_plan
+            .as_any()
+            .downcast_ref::<UnresolvedShuffleExec>()
+            .expect("Expected UnresolvedShuffleExec");
+
+        assert_eq!(decoded_exec.stage_id, 1);
+        assert_eq!(decoded_exec.schema().as_ref(), schema.as_ref());
+        assert_eq!(&decoded_exec.properties().partitioning, &partitioning);
+    }
+
+    #[tokio::test]
+    async fn test_shuffle_reader_exec_roundtrip() {
+        let schema = create_test_schema();
+        let partitioning =
+            Partitioning::Hash(vec![col("id", schema.as_ref()).unwrap()], 4);
+
+        let original_exec = ShuffleReaderExec::try_new(
+            1, // stage_id
+            Vec::new(),
+            schema.clone(),
+            partitioning.clone(),
+        )
+        .unwrap();
+
+        let codec = BallistaPhysicalExtensionCodec::default();
+        let mut buf: Vec<u8> = vec![];
+        codec
+            .try_encode(Arc::new(original_exec.clone()), &mut buf)
+            .unwrap();
+
+        let registry = MemoryFunctionRegistry::new();
+        let decoded_plan = codec.try_decode(&buf, &[], &registry).unwrap();
+
+        let decoded_exec = decoded_plan
+            .as_any()
+            .downcast_ref::<ShuffleReaderExec>()
+            .expect("Expected ShuffleReaderExec");
+
+        assert_eq!(decoded_exec.stage_id, 1);
+        assert_eq!(decoded_exec.schema().as_ref(), schema.as_ref());
+        assert_eq!(&decoded_exec.properties().partitioning, &partitioning);
+    }
 }
diff --git a/ballista/scheduler/src/planner.rs b/ballista/scheduler/src/planner.rs
index 47500ac1..fc32262e 100644
--- a/ballista/scheduler/src/planner.rs
+++ b/ballista/scheduler/src/planner.rs
@@ -168,10 +168,7 @@ fn create_unresolved_shuffle(
     Arc::new(UnresolvedShuffleExec::new(
         shuffle_writer.stage_id(),
         shuffle_writer.schema(),
-        shuffle_writer
-            .properties()
-            .output_partitioning()
-            .partition_count(),
+        shuffle_writer.properties().output_partitioning().clone(),
     ))
 }
 
@@ -239,6 +236,10 @@ pub fn remove_unresolved_shuffles(
                 unresolved_shuffle.stage_id,
                 relevant_locations,
                 unresolved_shuffle.schema().clone(),
+                unresolved_shuffle
+                    .properties()
+                    .output_partitioning()
+                    .clone(),
             )?))
         } else {
             new_children.push(remove_unresolved_shuffles(
@@ -259,16 +260,12 @@ pub fn rollback_resolved_shuffles(
     let mut new_children: Vec<Arc<dyn ExecutionPlan>> = vec![];
     for child in stage.children() {
         if let Some(shuffle_reader) = child.as_any().downcast_ref::<ShuffleReaderExec>() {
-            let output_partition_count = shuffle_reader
-                .properties()
-                .output_partitioning()
-                .partition_count();
             let stage_id = shuffle_reader.stage_id;
 
             let unresolved_shuffle = Arc::new(UnresolvedShuffleExec::new(
                 stage_id,
                 shuffle_reader.schema(),
-                output_partition_count,
+                shuffle_reader.properties().partitioning.clone(),
             ));
             new_children.push(unresolved_shuffle);
         } else {
@@ -396,6 +393,10 @@ mod test {
             downcast_exec!(unresolved_shuffle, UnresolvedShuffleExec);
         assert_eq!(unresolved_shuffle.stage_id, 1);
         assert_eq!(unresolved_shuffle.output_partition_count, 2);
+        assert_eq!(
+            unresolved_shuffle.properties().partitioning,
+            Partitioning::Hash(vec![Arc::new(Column::new("l_returnflag", 0))], 
2)
+        );
 
         // verify stage 2
         let stage2 = stages[2].children()[0].clone();
@@ -405,6 +406,10 @@ mod test {
             downcast_exec!(unresolved_shuffle, UnresolvedShuffleExec);
         assert_eq!(unresolved_shuffle.stage_id, 2);
         assert_eq!(unresolved_shuffle.output_partition_count, 2);
+        assert_eq!(
+            unresolved_shuffle.properties().partitioning,
+            Partitioning::Hash(vec![Arc::new(Column::new("l_returnflag", 0))], 2)
+        );
 
         Ok(())
     }
@@ -559,6 +564,10 @@ order by
         let unresolved_shuffle_reader_1 =
             downcast_exec!(join_input_1, UnresolvedShuffleExec);
         assert_eq!(unresolved_shuffle_reader_1.output_partition_count, 2);
+        assert_eq!(
+            unresolved_shuffle_reader_1.properties().partitioning,
+            Partitioning::Hash(vec![Arc::new(Column::new("l_orderkey", 0))], 2)
+        );
 
         let join_input_2 = join.children()[1].clone();
         // skip CoalesceBatches
@@ -566,6 +575,10 @@ order by
         let unresolved_shuffle_reader_2 =
             downcast_exec!(join_input_2, UnresolvedShuffleExec);
         assert_eq!(unresolved_shuffle_reader_2.output_partition_count, 2);
+        assert_eq!(
+            unresolved_shuffle_reader_2.properties().partitioning,
+            Partitioning::Hash(vec![Arc::new(Column::new("o_orderkey", 0))], 2)
+        );
 
         // final partitioned hash aggregate
         assert_eq!(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
