This is an automated email from the ASF dual-hosted git repository.

mneumann pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 02a470f606 Replace AbortOnDrop / AbortDropOnMany with tokio JoinSet (#6750)
02a470f606 is described below

commit 02a470f6061cce8ee8e57f7af8a6a0e0ddc1571b
Author: Armin Primadi <[email protected]>
AuthorDate: Tue Jul 4 20:00:33 2023 +0700

    Replace AbortOnDrop / AbortDropOnMany with tokio JoinSet (#6750)
    
    * Use JoinSet in MemTable
    
    * Fix error handling
    
    * Refactor AbortOnDropSingle in csv physical plan
    
    * Fix csv write physical plan error propagation
    
    * Refactor json write physical plan to use JoinSet
    
    * Refactor parquet write physical plan to use JoinSet
    
    * Refactor collect_partitioned to use JoinSet
    
    * Refactor pull_from_input method to make it easier to read
    
    * Fix typo
---
 datafusion/core/src/datasource/memory.rs           | 39 +++++++++++---------
 .../core/src/datasource/physical_plan/csv.rs       | 32 +++++++++-------
 .../core/src/datasource/physical_plan/json.rs      | 32 +++++++++-------
 .../core/src/datasource/physical_plan/parquet.rs   | 43 ++++++++++++----------
 datafusion/core/src/physical_plan/mod.rs           | 41 +++++++++++++++------
 .../core/src/physical_plan/repartition/mod.rs      | 33 ++++++++---------
 6 files changed, 127 insertions(+), 93 deletions(-)

diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs
index 784aa2aff2..5398bb0903 100644
--- a/datafusion/core/src/datasource/memory.rs
+++ b/datafusion/core/src/datasource/memory.rs
@@ -29,12 +29,12 @@ use async_trait::async_trait;
 use datafusion_common::SchemaExt;
 use datafusion_execution::TaskContext;
 use tokio::sync::RwLock;
+use tokio::task::JoinSet;
 
 use crate::datasource::{TableProvider, TableType};
 use crate::error::{DataFusionError, Result};
 use crate::execution::context::SessionState;
 use crate::logical_expr::Expr;
-use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::insert::{DataSink, InsertExec};
 use crate::physical_plan::memory::MemoryExec;
 use crate::physical_plan::{common, SendableRecordBatchStream};
@@ -89,26 +89,31 @@ impl MemTable {
         let exec = t.scan(state, None, &[], None).await?;
         let partition_count = exec.output_partitioning().partition_count();
 
-        let tasks = (0..partition_count)
-            .map(|part_i| {
-                let task = state.task_ctx();
-                let exec = exec.clone();
-                let task = tokio::spawn(async move {
-                    let stream = exec.execute(part_i, task)?;
-                    common::collect(stream).await
-                });
-
-                AbortOnDropSingle::new(task)
-            })
-            // this collect *is needed* so that the join below can
-            // switch between tasks
-            .collect::<Vec<_>>();
+        let mut join_set = JoinSet::new();
+
+        for part_idx in 0..partition_count {
+            let task = state.task_ctx();
+            let exec = exec.clone();
+            join_set.spawn(async move {
+                let stream = exec.execute(part_idx, task)?;
+                common::collect(stream).await
+            });
+        }
 
         let mut data: Vec<Vec<RecordBatch>> =
             Vec::with_capacity(exec.output_partitioning().partition_count());
 
-        for result in futures::future::join_all(tasks).await {
-            data.push(result.map_err(|e| DataFusionError::External(Box::new(e)))??)
+        while let Some(result) = join_set.join_next().await {
+            match result {
+                Ok(res) => data.push(res?),
+                Err(e) => {
+                    if e.is_panic() {
+                        std::panic::resume_unwind(e.into_panic());
+                    } else {
+                        unreachable!();
+                    }
+                }
+            }
         }
 
         let exec = MemoryExec::try_new(&data, schema.clone(), None)?;
diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs b/datafusion/core/src/datasource/physical_plan/csv.rs
index 027bd1945b..eba51615cd 100644
--- a/datafusion/core/src/datasource/physical_plan/csv.rs
+++ b/datafusion/core/src/datasource/physical_plan/csv.rs
@@ -23,7 +23,6 @@ use crate::datasource::physical_plan::file_stream::{
 };
 use crate::datasource::physical_plan::FileMeta;
 use crate::error::{DataFusionError, Result};
-use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::expressions::PhysicalSortExpr;
 use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
 use crate::physical_plan::{
@@ -46,7 +45,7 @@ use std::fs;
 use std::path::Path;
 use std::sync::Arc;
 use std::task::Poll;
-use tokio::task::{self, JoinHandle};
+use tokio::task::JoinSet;
 
 /// Execution plan for scanning a CSV file
 #[derive(Debug, Clone)]
@@ -331,7 +330,7 @@ pub async fn plan_to_csv(
         )));
     }
 
-    let mut tasks = vec![];
+    let mut join_set = JoinSet::new();
     for i in 0..plan.output_partitioning().partition_count() {
         let plan = plan.clone();
         let filename = format!("part-{i}.csv");
@@ -340,22 +339,29 @@ pub async fn plan_to_csv(
         let mut writer = csv::Writer::new(file);
         let stream = plan.execute(i, task_ctx.clone())?;
 
-        let handle: JoinHandle<Result<()>> = task::spawn(async move {
-            stream
+        join_set.spawn(async move {
+            let result: Result<()> = stream
                 .map(|batch| writer.write(&batch?))
                 .try_collect()
                 .await
-                .map_err(DataFusionError::from)
+                .map_err(DataFusionError::from);
+            result
         });
-        tasks.push(AbortOnDropSingle::new(handle));
     }
 
-    futures::future::join_all(tasks)
-        .await
-        .into_iter()
-        .try_for_each(|result| {
-            result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
-        })?;
+    while let Some(result) = join_set.join_next().await {
+        match result {
+            Ok(res) => res?, // propagate DataFusion error
+            Err(e) => {
+                if e.is_panic() {
+                    std::panic::resume_unwind(e.into_panic());
+                } else {
+                    unreachable!();
+                }
+            }
+        }
+    }
+
     Ok(())
 }
 
diff --git a/datafusion/core/src/datasource/physical_plan/json.rs b/datafusion/core/src/datasource/physical_plan/json.rs
index b736fd7839..64f7077660 100644
--- a/datafusion/core/src/datasource/physical_plan/json.rs
+++ b/datafusion/core/src/datasource/physical_plan/json.rs
@@ -22,7 +22,6 @@ use crate::datasource::physical_plan::file_stream::{
 };
 use crate::datasource::physical_plan::FileMeta;
 use crate::error::{DataFusionError, Result};
-use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::expressions::PhysicalSortExpr;
 use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
 use crate::physical_plan::{
@@ -44,7 +43,7 @@ use std::io::BufReader;
 use std::path::Path;
 use std::sync::Arc;
 use std::task::Poll;
-use tokio::task::{self, JoinHandle};
+use tokio::task::JoinSet;
 
 use super::FileScanConfig;
 
@@ -266,7 +265,7 @@ pub async fn plan_to_json(
         )));
     }
 
-    let mut tasks = vec![];
+    let mut join_set = JoinSet::new();
     for i in 0..plan.output_partitioning().partition_count() {
         let plan = plan.clone();
         let filename = format!("part-{i}.json");
@@ -274,22 +273,29 @@ pub async fn plan_to_json(
         let file = fs::File::create(path)?;
         let mut writer = json::LineDelimitedWriter::new(file);
         let stream = plan.execute(i, task_ctx.clone())?;
-        let handle: JoinHandle<Result<()>> = task::spawn(async move {
-            stream
+        join_set.spawn(async move {
+            let result: Result<()> = stream
                 .map(|batch| writer.write(&batch?))
                 .try_collect()
                 .await
-                .map_err(DataFusionError::from)
+                .map_err(DataFusionError::from);
+            result
         });
-        tasks.push(AbortOnDropSingle::new(handle));
     }
 
-    futures::future::join_all(tasks)
-        .await
-        .into_iter()
-        .try_for_each(|result| {
-            result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
-        })?;
+    while let Some(result) = join_set.join_next().await {
+        match result {
+            Ok(res) => res?, // propagate DataFusion error
+            Err(e) => {
+                if e.is_panic() {
+                    std::panic::resume_unwind(e.into_panic());
+                } else {
+                    unreachable!();
+                }
+            }
+        }
+    }
+
     Ok(())
 }
 
diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs b/datafusion/core/src/datasource/physical_plan/parquet.rs
index f538255bc2..96e5ce9fa0 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet.rs
@@ -31,7 +31,6 @@ use crate::{
     execution::context::TaskContext,
     physical_optimizer::pruning::PruningPredicate,
     physical_plan::{
-        common::AbortOnDropSingle,
         metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet},
         ordering_equivalence_properties_helper, DisplayFormatType, ExecutionPlan,
         Partitioning, SendableRecordBatchStream, Statistics,
@@ -64,6 +63,7 @@ use parquet::arrow::{ArrowWriter, ParquetRecordBatchStreamBuilder, ProjectionMas
 use parquet::basic::{ConvertedType, LogicalType};
 use parquet::file::{metadata::ParquetMetaData, properties::WriterProperties};
 use parquet::schema::types::ColumnDescriptor;
+use tokio::task::JoinSet;
 
 mod metrics;
 pub mod page_filter;
@@ -701,7 +701,7 @@ pub async fn plan_to_parquet(
         )));
     }
 
-    let mut tasks = vec![];
+    let mut join_set = JoinSet::new();
     for i in 0..plan.output_partitioning().partition_count() {
         let plan = plan.clone();
         let filename = format!("part-{i}.parquet");
@@ -710,27 +710,30 @@ pub async fn plan_to_parquet(
         let mut writer =
             ArrowWriter::try_new(file, plan.schema(), writer_properties.clone())?;
         let stream = plan.execute(i, task_ctx.clone())?;
-        let handle: tokio::task::JoinHandle<Result<()>> =
-            tokio::task::spawn(async move {
-                stream
-                    .map(|batch| {
-                        writer.write(&batch?).map_err(DataFusionError::ParquetError)
-                    })
-                    .try_collect()
-                    .await
-                    .map_err(DataFusionError::from)?;
+        join_set.spawn(async move {
+            stream
+                .map(|batch| writer.write(&batch?).map_err(DataFusionError::ParquetError))
+                .try_collect()
+                .await
+                .map_err(DataFusionError::from)?;
+
+            writer.close().map_err(DataFusionError::from).map(|_| ())
+        });
+    }
 
-                writer.close().map_err(DataFusionError::from).map(|_| ())
-            });
-        tasks.push(AbortOnDropSingle::new(handle));
+    while let Some(result) = join_set.join_next().await {
+        match result {
+            Ok(res) => res?,
+            Err(e) => {
+                if e.is_panic() {
+                    std::panic::resume_unwind(e.into_panic());
+                } else {
+                    unreachable!();
+                }
+            }
+        }
     }
 
-    futures::future::join_all(tasks)
-        .await
-        .into_iter()
-        .try_for_each(|result| {
-            result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
-        })?;
     Ok(())
 }
 
diff --git a/datafusion/core/src/physical_plan/mod.rs b/datafusion/core/src/physical_plan/mod.rs
index 5abecf6b16..7efd5a19ee 100644
--- a/datafusion/core/src/physical_plan/mod.rs
+++ b/datafusion/core/src/physical_plan/mod.rs
@@ -38,6 +38,7 @@ pub use display::{DefaultDisplay, DisplayAs, DisplayFormatType, VerboseDisplay};
 use futures::stream::{Stream, TryStreamExt};
 use std::fmt;
 use std::fmt::Debug;
+use tokio::task::JoinSet;
 
 use datafusion_common::tree_node::Transformed;
 use datafusion_common::DataFusionError;
@@ -445,20 +446,37 @@ pub async fn collect_partitioned(
 ) -> Result<Vec<Vec<RecordBatch>>> {
     let streams = execute_stream_partitioned(plan, context)?;
 
+    let mut join_set = JoinSet::new();
     // Execute the plan and collect the results into batches.
-    let handles = streams
-        .into_iter()
-        .enumerate()
-        .map(|(idx, stream)| async move {
-            let handle = tokio::task::spawn(stream.try_collect());
-            AbortOnDropSingle::new(handle).await.map_err(|e| {
-                DataFusionError::Execution(format!(
-                    "collect_partitioned partition {idx} panicked: {e}"
-                ))
-            })?
+    streams.into_iter().enumerate().for_each(|(idx, stream)| {
+        join_set.spawn(async move {
+            let result: Result<Vec<RecordBatch>> = stream.try_collect().await;
+            (idx, result)
         });
+    });
+
+    let mut batches = vec![];
+    // Note that currently this doesn't identify the thread that panicked
+    //
+    // TODO: Replace with [join_next_with_id](https://docs.rs/tokio/latest/tokio/task/struct.JoinSet.html#method.join_next_with_id
+    // once it is stable
+    while let Some(result) = join_set.join_next().await {
+        match result {
+            Ok((idx, res)) => batches.push((idx, res?)),
+            Err(e) => {
+                if e.is_panic() {
+                    std::panic::resume_unwind(e.into_panic());
+                } else {
+                    unreachable!();
+                }
+            }
+        }
+    }
+
+    batches.sort_by_key(|(idx, _)| *idx);
+    let batches = batches.into_iter().map(|(_, batch)| batch).collect();
 
-    futures::future::try_join_all(handles).await
+    Ok(batches)
 }
 
 /// Execute the [ExecutionPlan] and return a vec with one stream per output partition
@@ -713,7 +731,6 @@ pub mod unnest;
 pub mod values;
 pub mod windows;
 
-use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::repartition::RepartitionExec;
 use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
 use datafusion_execution::TaskContext;
diff --git a/datafusion/core/src/physical_plan/repartition/mod.rs b/datafusion/core/src/physical_plan/repartition/mod.rs
index 85225eb471..3c689e97ab 100644
--- a/datafusion/core/src/physical_plan/repartition/mod.rs
+++ b/datafusion/core/src/physical_plan/repartition/mod.rs
@@ -263,7 +263,7 @@ struct RepartitionMetrics {
     /// Time in nanos to execute child operator and fetch batches
     fetch_time: metrics::Time,
     /// Time in nanos to perform repartitioning
-    repart_time: metrics::Time,
+    repartition_time: metrics::Time,
     /// Time in nanos for sending resulting batches to channels
     send_time: metrics::Time,
 }
@@ -293,7 +293,7 @@ impl RepartitionMetrics {
 
         Self {
             fetch_time,
-            repart_time,
+            repartition_time: repart_time,
             send_time,
         }
     }
@@ -407,7 +407,7 @@ impl ExecutionPlan for RepartitionExec {
                 // note we use a custom channel that ensures there is always data for each receiver
                 // but limits the amount of buffering if required.
                 let (txs, rxs) = channels(num_output_partitions);
-                // Clone sender for ech input partitions
+                // Clone sender for each input partitions
                 let txs = txs
                     .into_iter()
                     .map(|item| vec![item; num_input_partitions])
@@ -565,34 +565,31 @@ impl RepartitionExec {
     /// Pulls data from the specified input plan, feeding it to the
     /// output partitions based on the desired partitioning
     ///
-    /// i is the input partition index
-    ///
     /// txs hold the output sending channels for each output partition
     async fn pull_from_input(
         input: Arc<dyn ExecutionPlan>,
-        i: usize,
-        mut txs: HashMap<
+        partition: usize,
+        mut output_channels: HashMap<
             usize,
             (DistributionSender<MaybeBatch>, SharedMemoryReservation),
         >,
         partitioning: Partitioning,
-        r_metrics: RepartitionMetrics,
+        metrics: RepartitionMetrics,
         context: Arc<TaskContext>,
     ) -> Result<()> {
         let mut partitioner =
-            BatchPartitioner::try_new(partitioning, r_metrics.repart_time.clone())?;
+            BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;
 
         // execute the child operator
-        let timer = r_metrics.fetch_time.timer();
-        let mut stream = input.execute(i, context)?;
+        let timer = metrics.fetch_time.timer();
+        let mut stream = input.execute(partition, context)?;
         timer.done();
 
-        // While there are still outputs to send to, keep
-        // pulling inputs
+        // While there are still outputs to send to, keep pulling inputs
         let mut batches_until_yield = partitioner.num_partitions();
-        while !txs.is_empty() {
+        while !output_channels.is_empty() {
             // fetch the next batch
-            let timer = r_metrics.fetch_time.timer();
+            let timer = metrics.fetch_time.timer();
             let result = stream.next().await;
             timer.done();
 
@@ -606,15 +603,15 @@ impl RepartitionExec {
                 let (partition, batch) = res?;
                 let size = batch.get_array_memory_size();
 
-                let timer = r_metrics.send_time.timer();
+                let timer = metrics.send_time.timer();
                 // if there is still a receiver, send to it
-                if let Some((tx, reservation)) = txs.get_mut(&partition) {
+                if let Some((tx, reservation)) = output_channels.get_mut(&partition) {
                     reservation.lock().try_grow(size)?;
 
                     if tx.send(Some(Ok(batch))).await.is_err() {
                         // If the other end has hung up, it was an early shutdown (e.g. LIMIT)
                         reservation.lock().shrink(size);
-                        txs.remove(&partition);
+                        output_channels.remove(&partition);
                     }
                 }
                 timer.done();

Reply via email to