This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 14264d2c39 fix: use `JoinSet` to make spawned tasks cancel-safe (#9318)
14264d2c39 is described below

commit 14264d2c3947e432f71bfe0af1a3dbafbb6ee686
Author: Artem Medvedev <[email protected]>
AuthorDate: Tue Feb 27 14:11:59 2024 +0100

    fix: use `JoinSet` to make spawned tasks cancel-safe (#9318)
    
    * fix: use `JoinSet` to make spawned tasks cancel-safe
    
    * feat: drop `AbortOnDropSingle` and `AbortOnDropMany`
    
    * style: doc lint
    
    * fix: ordering of the tasks in `RepartitionExec`
    
    * fix: replace spawn_blocking with JoinSet
    
    * style: disallow spawn methods
    
    * fixes: preserve ordering of tasks
    
    * style: allow spawning in tests
    
    * chore: exclude clippy.toml from rat
    
    * chore: typo
    
    * feat: introduce `SpawnedTask`
    
    * revert outdated comment
    
    * switch to SpawnedTask missed outdated part
    
    * doc: improve reason for disallowed-method
---
 clippy.toml                                        |  4 ++
 datafusion/core/src/dataframe/mod.rs               |  1 +
 .../core/src/datasource/file_format/arrow.rs       |  2 +-
 .../core/src/datasource/file_format/parquet.rs     | 51 ++++++++-------
 .../core/src/datasource/file_format/write/demux.rs | 12 ++--
 .../datasource/file_format/write/orchestration.rs  | 29 ++++-----
 datafusion/core/src/datasource/stream.rs           |  9 ++-
 datafusion/core/src/execution/context/mod.rs       |  1 +
 datafusion/core/tests/fifo.rs                      |  2 +
 .../fuzz_cases/sort_preserving_repartition_fuzz.rs |  1 +
 datafusion/core/tests/fuzz_cases/window_fuzz.rs    |  1 +
 datafusion/physical-plan/src/common.rs             | 73 ++++++++++------------
 datafusion/physical-plan/src/lib.rs                | 10 +--
 datafusion/physical-plan/src/repartition/mod.rs    | 46 +++++++-------
 datafusion/physical-plan/src/sorts/sort.rs         |  7 +--
 datafusion/sqllogictest/bin/sqllogictests.rs       |  1 +
 dev/release/rat_exclude_files.txt                  |  3 +-
 17 files changed, 129 insertions(+), 124 deletions(-)

diff --git a/clippy.toml b/clippy.toml
new file mode 100644
index 0000000000..c6c754e440
--- /dev/null
+++ b/clippy.toml
@@ -0,0 +1,4 @@
+disallowed-methods = [
+    { path = "tokio::task::spawn", reason = "To provide cancel-safety, use `SpawnedTask::spawn` instead (https://github.com/apache/arrow-datafusion/issues/6513)" },
+    { path = "tokio::task::spawn_blocking", reason = "To provide cancel-safety, use `SpawnedTask::spawn_blocking` instead (https://github.com/apache/arrow-datafusion/issues/6513)" },
+]
diff --git a/datafusion/core/src/dataframe/mod.rs 
b/datafusion/core/src/dataframe/mod.rs
index 3a60d57f66..c04247210d 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -2172,6 +2172,7 @@ mod tests {
     }
 
     #[tokio::test]
+    #[allow(clippy::disallowed_methods)]
     async fn sendable() {
         let df = test_table().await.unwrap();
         // dataframes should be sendable between threads/tasks
diff --git a/datafusion/core/src/datasource/file_format/arrow.rs 
b/datafusion/core/src/datasource/file_format/arrow.rs
index ead2db5a10..d5f07d11be 100644
--- a/datafusion/core/src/datasource/file_format/arrow.rs
+++ b/datafusion/core/src/datasource/file_format/arrow.rs
@@ -295,7 +295,7 @@ impl DataSink for ArrowFileSink {
             }
         }
 
-        match demux_task.await {
+        match demux_task.join().await {
             Ok(r) => r?,
             Err(e) => {
                 if e.is_panic() {
diff --git a/datafusion/core/src/datasource/file_format/parquet.rs 
b/datafusion/core/src/datasource/file_format/parquet.rs
index 89ec81630c..7398501153 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -32,7 +32,7 @@ use std::fmt::Debug;
 use std::sync::Arc;
 use tokio::io::{AsyncWrite, AsyncWriteExt};
 use tokio::sync::mpsc::{self, Receiver, Sender};
-use tokio::task::{JoinHandle, JoinSet};
+use tokio::task::JoinSet;
 
 use crate::datasource::file_format::file_compression_type::FileCompressionType;
 use crate::datasource::statistics::{create_max_min_accs, get_col_stats};
@@ -42,6 +42,7 @@ use bytes::{BufMut, BytesMut};
 use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType};
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
+use datafusion_physical_plan::common::SpawnedTask;
 use futures::{StreamExt, TryStreamExt};
 use hashbrown::HashMap;
 use object_store::path::Path;
@@ -728,7 +729,7 @@ impl DataSink for ParquetSink {
             }
         }
 
-        match demux_task.await {
+        match demux_task.join().await {
             Ok(r) => r?,
             Err(e) => {
                 if e.is_panic() {
@@ -738,6 +739,7 @@ impl DataSink for ParquetSink {
                 }
             }
         }
+
         Ok(row_count as u64)
     }
 }
@@ -754,8 +756,9 @@ async fn column_serializer_task(
     Ok(writer)
 }
 
-type ColumnJoinHandle = JoinHandle<Result<ArrowColumnWriter>>;
+type ColumnWriterTask = SpawnedTask<Result<ArrowColumnWriter>>;
 type ColSender = Sender<ArrowLeafColumn>;
+
 /// Spawns a parallel serialization task for each column
 /// Returns join handles for each columns serialization task along with a send 
channel
 /// to send arrow arrays to each serialization task.
@@ -763,23 +766,24 @@ fn spawn_column_parallel_row_group_writer(
     schema: Arc<Schema>,
     parquet_props: Arc<WriterProperties>,
     max_buffer_size: usize,
-) -> Result<(Vec<ColumnJoinHandle>, Vec<ColSender>)> {
+) -> Result<(Vec<ColumnWriterTask>, Vec<ColSender>)> {
     let schema_desc = arrow_to_parquet_schema(&schema)?;
     let col_writers = get_column_writers(&schema_desc, &parquet_props, 
&schema)?;
     let num_columns = col_writers.len();
 
-    let mut col_writer_handles = Vec::with_capacity(num_columns);
+    let mut col_writer_tasks = Vec::with_capacity(num_columns);
     let mut col_array_channels = Vec::with_capacity(num_columns);
     for writer in col_writers.into_iter() {
         // Buffer size of this channel limits the number of arrays queued up 
for column level serialization
         let (send_array, recieve_array) =
             mpsc::channel::<ArrowLeafColumn>(max_buffer_size);
         col_array_channels.push(send_array);
-        col_writer_handles
-            .push(tokio::spawn(column_serializer_task(recieve_array, writer)))
+
+        let task = SpawnedTask::spawn(column_serializer_task(recieve_array, 
writer));
+        col_writer_tasks.push(task);
     }
 
-    Ok((col_writer_handles, col_array_channels))
+    Ok((col_writer_tasks, col_array_channels))
 }
 
 /// Settings related to writing parquet files in parallel
@@ -820,14 +824,14 @@ async fn send_arrays_to_col_writers(
 /// Spawns a tokio task which joins the parallel column writer tasks,
 /// and finalizes the row group
 fn spawn_rg_join_and_finalize_task(
-    column_writer_handles: Vec<JoinHandle<Result<ArrowColumnWriter>>>,
+    column_writer_tasks: Vec<ColumnWriterTask>,
     rg_rows: usize,
-) -> JoinHandle<RBStreamSerializeResult> {
-    tokio::spawn(async move {
-        let num_cols = column_writer_handles.len();
+) -> SpawnedTask<RBStreamSerializeResult> {
+    SpawnedTask::spawn(async move {
+        let num_cols = column_writer_tasks.len();
         let mut finalized_rg = Vec::with_capacity(num_cols);
-        for handle in column_writer_handles.into_iter() {
-            match handle.await {
+        for task in column_writer_tasks.into_iter() {
+            match task.join().await {
                 Ok(r) => {
                     let w = r?;
                     finalized_rg.push(w.close()?);
@@ -856,12 +860,12 @@ fn spawn_rg_join_and_finalize_task(
 /// given by n_columns * num_row_groups.
 fn spawn_parquet_parallel_serialization_task(
     mut data: Receiver<RecordBatch>,
-    serialize_tx: Sender<JoinHandle<RBStreamSerializeResult>>,
+    serialize_tx: Sender<SpawnedTask<RBStreamSerializeResult>>,
     schema: Arc<Schema>,
     writer_props: Arc<WriterProperties>,
     parallel_options: ParallelParquetWriterOptions,
-) -> JoinHandle<Result<(), DataFusionError>> {
-    tokio::spawn(async move {
+) -> SpawnedTask<Result<(), DataFusionError>> {
+    SpawnedTask::spawn(async move {
         let max_buffer_rb = 
parallel_options.max_buffered_record_batches_per_stream;
         let max_row_group_rows = writer_props.max_row_group_size();
         let (mut column_writer_handles, mut col_array_channels) =
@@ -931,7 +935,7 @@ fn spawn_parquet_parallel_serialization_task(
 /// Consume RowGroups serialized by other parallel tasks and concatenate them 
in
 /// to the final parquet file, while flushing finalized bytes to an 
[ObjectStore]
 async fn concatenate_parallel_row_groups(
-    mut serialize_rx: Receiver<JoinHandle<RBStreamSerializeResult>>,
+    mut serialize_rx: Receiver<SpawnedTask<RBStreamSerializeResult>>,
     schema: Arc<Schema>,
     writer_props: Arc<WriterProperties>,
     mut object_store_writer: AbortableWrite<Box<dyn AsyncWrite + Send + 
Unpin>>,
@@ -947,9 +951,8 @@ async fn concatenate_parallel_row_groups(
 
     let mut row_count = 0;
 
-    while let Some(handle) = serialize_rx.recv().await {
-        let join_result = handle.await;
-        match join_result {
+    while let Some(task) = serialize_rx.recv().await {
+        match task.join().await {
             Ok(result) => {
                 let mut rg_out = parquet_writer.next_row_group()?;
                 let (serialized_columns, cnt) = result?;
@@ -999,7 +1002,7 @@ async fn output_single_parquet_file_parallelized(
     let max_rowgroups = parallel_options.max_parallel_row_groups;
     // Buffer size of this channel limits maximum number of RowGroups being 
worked on in parallel
     let (serialize_tx, serialize_rx) =
-        mpsc::channel::<JoinHandle<RBStreamSerializeResult>>(max_rowgroups);
+        mpsc::channel::<SpawnedTask<RBStreamSerializeResult>>(max_rowgroups);
 
     let arc_props = Arc::new(parquet_props.clone());
     let launch_serialization_task = spawn_parquet_parallel_serialization_task(
@@ -1017,7 +1020,7 @@ async fn output_single_parquet_file_parallelized(
     )
     .await?;
 
-    match launch_serialization_task.await {
+    match launch_serialization_task.join().await {
         Ok(Ok(_)) => (),
         Ok(Err(e)) => return Err(e),
         Err(e) => {
@@ -1027,7 +1030,7 @@ async fn output_single_parquet_file_parallelized(
                 unreachable!()
             }
         }
-    };
+    }
 
     Ok(row_count)
 }
diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs 
b/datafusion/core/src/datasource/file_format/write/demux.rs
index 8bccf3d71c..d70b4811da 100644
--- a/datafusion/core/src/datasource/file_format/write/demux.rs
+++ b/datafusion/core/src/datasource/file_format/write/demux.rs
@@ -41,8 +41,8 @@ use object_store::path::Path;
 
 use rand::distributions::DistString;
 
+use datafusion_physical_plan::common::SpawnedTask;
 use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver, 
UnboundedSender};
-use tokio::task::JoinHandle;
 
 type RecordBatchReceiver = Receiver<RecordBatch>;
 type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>;
@@ -76,15 +76,15 @@ pub(crate) fn start_demuxer_task(
     partition_by: Option<Vec<(String, DataType)>>,
     base_output_path: ListingTableUrl,
     file_extension: String,
-) -> (JoinHandle<Result<()>>, DemuxedStreamReceiver) {
-    let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
+) -> (SpawnedTask<Result<()>>, DemuxedStreamReceiver) {
+    let (tx, rx) = mpsc::unbounded_channel();
     let context = context.clone();
     let single_file_output = !base_output_path.is_collection();
-    let task: JoinHandle<std::result::Result<(), DataFusionError>> = match 
partition_by {
+    let task = match partition_by {
         Some(parts) => {
             // There could be an arbitrarily large number of parallel hive 
style partitions being written to, so we cannot
             // bound this channel without risking a deadlock.
-            tokio::spawn(async move {
+            SpawnedTask::spawn(async move {
                 hive_style_partitions_demuxer(
                     tx,
                     input,
@@ -96,7 +96,7 @@ pub(crate) fn start_demuxer_task(
                 .await
             })
         }
-        None => tokio::spawn(async move {
+        None => SpawnedTask::spawn(async move {
             row_count_demuxer(
                 tx,
                 input,
diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs 
b/datafusion/core/src/datasource/file_format/write/orchestration.rs
index 1a3042cbc0..05406d3751 100644
--- a/datafusion/core/src/datasource/file_format/write/orchestration.rs
+++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs
@@ -33,10 +33,11 @@ use datafusion_common::{internal_datafusion_err, 
internal_err, DataFusionError};
 use datafusion_execution::TaskContext;
 
 use bytes::Bytes;
+use datafusion_physical_plan::common::SpawnedTask;
+use futures::try_join;
 use tokio::io::{AsyncWrite, AsyncWriteExt};
 use tokio::sync::mpsc::{self, Receiver};
-use tokio::task::{JoinHandle, JoinSet};
-use tokio::try_join;
+use tokio::task::JoinSet;
 
 type WriterType = AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>;
 type SerializerType = Arc<dyn BatchSerializer>;
@@ -51,14 +52,14 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
     mut writer: AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>,
 ) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> {
     let (tx, mut rx) =
-        mpsc::channel::<JoinHandle<Result<(usize, Bytes), 
DataFusionError>>>(100);
-    let serialize_task = tokio::spawn(async move {
+        mpsc::channel::<SpawnedTask<Result<(usize, Bytes), 
DataFusionError>>>(100);
+    let serialize_task = SpawnedTask::spawn(async move {
         // Some serializers (like CSV) handle the first batch differently than
         // subsequent batches, so we track that here.
         let mut initial = true;
         while let Some(batch) = data_rx.recv().await {
             let serializer_clone = serializer.clone();
-            let handle = tokio::spawn(async move {
+            let task = SpawnedTask::spawn(async move {
                 let num_rows = batch.num_rows();
                 let bytes = serializer_clone.serialize(batch, initial)?;
                 Ok((num_rows, bytes))
@@ -66,7 +67,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
             if initial {
                 initial = false;
             }
-            tx.send(handle).await.map_err(|_| {
+            tx.send(task).await.map_err(|_| {
                 internal_datafusion_err!("Unknown error writing to object 
store")
             })?;
         }
@@ -74,8 +75,8 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
     });
 
     let mut row_count = 0;
-    while let Some(handle) = rx.recv().await {
-        match handle.await {
+    while let Some(task) = rx.recv().await {
+        match task.join().await {
             Ok(Ok((cnt, bytes))) => {
                 match writer.write_all(&bytes).await {
                     Ok(_) => (),
@@ -106,7 +107,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
         }
     }
 
-    match serialize_task.await {
+    match serialize_task.join().await {
         Ok(Ok(_)) => (),
         Ok(Err(e)) => return Err((writer, e)),
         Err(_) => {
@@ -115,7 +116,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
                 internal_datafusion_err!("Unknown error writing to object 
store"),
             ))
         }
-    };
+    }
     Ok((writer, row_count as u64))
 }
 
@@ -241,9 +242,9 @@ pub(crate) async fn stateless_multipart_put(
         .execution
         .max_buffered_batches_per_output_file;
 
-    let (tx_file_bundle, rx_file_bundle) = 
tokio::sync::mpsc::channel(rb_buffer_size / 2);
+    let (tx_file_bundle, rx_file_bundle) = mpsc::channel(rb_buffer_size / 2);
     let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel();
-    let write_coordinater_task = tokio::spawn(async move {
+    let write_coordinator_task = SpawnedTask::spawn(async move {
         stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt).await
     });
     while let Some((location, rb_stream)) = file_stream_rx.recv().await {
@@ -260,10 +261,10 @@ pub(crate) async fn stateless_multipart_put(
             })?;
     }
 
-    // Signal to the write coordinater that no more files are coming
+    // Signal to the write coordinator that no more files are coming
     drop(tx_file_bundle);
 
-    match try_join!(write_coordinater_task, demux_task) {
+    match try_join!(write_coordinator_task.join(), demux_task.join()) {
         Ok((r1, r2)) => {
             r1?;
             r2?;
diff --git a/datafusion/core/src/datasource/stream.rs 
b/datafusion/core/src/datasource/stream.rs
index 830cd7a07e..6dc59e4a5c 100644
--- a/datafusion/core/src/datasource/stream.rs
+++ b/datafusion/core/src/datasource/stream.rs
@@ -29,12 +29,11 @@ use arrow_array::{RecordBatch, RecordBatchReader, 
RecordBatchWriter};
 use arrow_schema::SchemaRef;
 use async_trait::async_trait;
 use futures::StreamExt;
-use tokio::task::spawn_blocking;
 
 use datafusion_common::{plan_err, Constraints, DataFusionError, Result};
 use datafusion_execution::{SendableRecordBatchStream, TaskContext};
 use datafusion_expr::{CreateExternalTable, Expr, TableType};
-use datafusion_physical_plan::common::AbortOnDropSingle;
+use datafusion_physical_plan::common::SpawnedTask;
 use datafusion_physical_plan::insert::{DataSink, FileSinkExec};
 use datafusion_physical_plan::metrics::MetricsSet;
 use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
@@ -344,7 +343,7 @@ impl DataSink for StreamWrite {
         let config = self.0.clone();
         let (sender, mut receiver) = 
tokio::sync::mpsc::channel::<RecordBatch>(2);
         // Note: FIFO Files support poll so this could use AsyncFd
-        let write = AbortOnDropSingle::new(spawn_blocking(move || {
+        let write_task = SpawnedTask::spawn_blocking(move || {
             let mut count = 0_u64;
             let mut writer = config.writer()?;
             while let Some(batch) = receiver.blocking_recv() {
@@ -352,7 +351,7 @@ impl DataSink for StreamWrite {
                 writer.write(&batch)?;
             }
             Ok(count)
-        }));
+        });
 
         while let Some(b) = data.next().await.transpose()? {
             if sender.send(b).await.is_err() {
@@ -360,6 +359,6 @@ impl DataSink for StreamWrite {
             }
         }
         drop(sender);
-        write.await.unwrap()
+        write_task.join().await.unwrap()
     }
 }
diff --git a/datafusion/core/src/execution/context/mod.rs 
b/datafusion/core/src/execution/context/mod.rs
index ffc4a4f717..453a00a1a5 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -2288,6 +2288,7 @@ mod tests {
     }
 
     #[tokio::test]
+    #[allow(clippy::disallowed_methods)]
     async fn send_context_to_threads() -> Result<()> {
         // ensure SessionContexts can be used in a multi-threaded
         // environment. Usecase is for concurrent planing.
diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs
index 93c7f73680..c9ad95a3a0 100644
--- a/datafusion/core/tests/fifo.rs
+++ b/datafusion/core/tests/fifo.rs
@@ -103,6 +103,7 @@ mod unix_test {
         let broken_pipe_timeout = Duration::from_secs(10);
         let sa = file_path.clone();
         // Spawn a new thread to write to the FIFO file
+        #[allow(clippy::disallowed_methods)] // spawn allowed only in tests
         spawn_blocking(move || {
             let file = OpenOptions::new().write(true).open(sa).unwrap();
             // Reference time to use when deciding to fail the test
@@ -357,6 +358,7 @@ mod unix_test {
             (sink_fifo_path.clone(), sink_fifo_path.display());
 
         // Spawn a new thread to read sink EXTERNAL TABLE.
+        #[allow(clippy::disallowed_methods)] // spawn allowed only in tests
         tasks.push(spawn_blocking(move || {
             let file = File::open(sink_fifo_path_thread).unwrap();
             let schema = Arc::new(Schema::new(vec![
diff --git 
a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs 
b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs
index df6499e9b1..6c9c3359eb 100644
--- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs
@@ -302,6 +302,7 @@ mod sp_repartition_fuzz_tests {
                 let mut handles = Vec::new();
 
                 for seed in seed_start..seed_end {
+                    #[allow(clippy::disallowed_methods)] // spawn allowed only 
in tests
                     let job = 
tokio::spawn(run_sort_preserving_repartition_test(
                         make_staggered_batches::<true>(n_row, n_distinct, seed 
as u64),
                         is_first_roundrobin,
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs 
b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index 609d26c9c2..1cab4d5c2f 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -123,6 +123,7 @@ async fn window_bounded_window_random_comparison() -> 
Result<()> {
         for i in 0..n {
             let idx = i % test_cases.len();
             let (pb_cols, ob_cols, search_mode) = test_cases[idx].clone();
+            #[allow(clippy::disallowed_methods)] // spawn allowed only in tests
             let job = tokio::spawn(run_window_test(
                 make_staggered_batches::<true>(1000, n_distinct, i as u64),
                 i as u64,
diff --git a/datafusion/physical-plan/src/common.rs 
b/datafusion/physical-plan/src/common.rs
index e83dc2525b..5172bc9b2a 100644
--- a/datafusion/physical-plan/src/common.rs
+++ b/datafusion/physical-plan/src/common.rs
@@ -21,7 +21,6 @@ use std::fs;
 use std::fs::{metadata, File};
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
-use std::task::{Context, Poll};
 
 use super::SendableRecordBatchStream;
 use crate::stream::RecordBatchReceiverStream;
@@ -39,8 +38,7 @@ use datafusion_physical_expr::{PhysicalExpr, 
PhysicalSortExpr};
 
 use futures::{Future, StreamExt, TryStreamExt};
 use parking_lot::Mutex;
-use pin_project_lite::pin_project;
-use tokio::task::JoinHandle;
+use tokio::task::{JoinError, JoinSet};
 
 /// [`MemoryReservation`] used across query execution streams
 pub(crate) type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;
@@ -174,50 +172,43 @@ pub fn compute_record_batch_statistics(
     }
 }
 
-pin_project! {
-    /// Helper that aborts the given join handle on drop.
-    ///
-    /// Useful to kill background tasks when the consumer is dropped.
-    #[derive(Debug)]
-    pub struct AbortOnDropSingle<T>{
-        #[pin]
-        join_handle: JoinHandle<T>,
-    }
-
-    impl<T> PinnedDrop for AbortOnDropSingle<T> {
-        fn drop(this: Pin<&mut Self>) {
-            this.join_handle.abort();
-        }
-    }
+/// Helper that provides a simple API to spawn a single task and join it.
+/// Guarantees that the task is aborted on `Drop`, which keeps it cancel-safe.
+///
+/// Technically, it's just a wrapper around `JoinSet` (with size=1).
+#[derive(Debug)]
+pub struct SpawnedTask<R> {
+    inner: JoinSet<R>,
 }
 
-impl<T> AbortOnDropSingle<T> {
-    /// Create new abort helper from join handle.
-    pub fn new(join_handle: JoinHandle<T>) -> Self {
-        Self { join_handle }
+impl<R: 'static> SpawnedTask<R> {
+    pub fn spawn<T>(task: T) -> Self
+    where
+        T: Future<Output = R>,
+        T: Send + 'static,
+        R: Send,
+    {
+        let mut inner = JoinSet::new();
+        inner.spawn(task);
+        Self { inner }
     }
-}
 
-impl<T> Future for AbortOnDropSingle<T> {
-    type Output = Result<T, tokio::task::JoinError>;
-
-    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> 
Poll<Self::Output> {
-        let this = self.project();
-        this.join_handle.poll(cx)
+    pub fn spawn_blocking<T>(task: T) -> Self
+    where
+        T: FnOnce() -> R,
+        T: Send + 'static,
+        R: Send,
+    {
+        let mut inner = JoinSet::new();
+        inner.spawn_blocking(task);
+        Self { inner }
     }
-}
-
-/// Helper that aborts the given join handles on drop.
-///
-/// Useful to kill background tasks when the consumer is dropped.
-#[derive(Debug)]
-pub struct AbortOnDropMany<T>(pub Vec<JoinHandle<T>>);
 
-impl<T> Drop for AbortOnDropMany<T> {
-    fn drop(&mut self) {
-        for join_handle in &self.0 {
-            join_handle.abort();
-        }
+    pub async fn join(mut self) -> Result<R, JoinError> {
+        self.inner
+            .join_next()
+            .await
+            .expect("`SpawnedTask` instance always contains exactly 1 task")
     }
 }
 
diff --git a/datafusion/physical-plan/src/lib.rs 
b/datafusion/physical-plan/src/lib.rs
index 1c4a6ac0ec..562e42a7da 100644
--- a/datafusion/physical-plan/src/lib.rs
+++ b/datafusion/physical-plan/src/lib.rs
@@ -298,14 +298,14 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync {
     /// "abort" such tasks, they may continue to consume resources even after
     /// the plan is dropped, generating intermediate results that are never
     /// used.
+    /// Thus, [`spawn`] is disallowed; use [`SpawnedTask`] instead.
     ///
-    /// See [`AbortOnDropSingle`], [`AbortOnDropMany`] and
-    /// [`RecordBatchReceiverStreamBuilder`] for structures to help ensure all
-    /// background tasks are cancelled.
+    /// See [`SpawnedTask`], [`JoinSet`] and [`RecordBatchReceiverStreamBuilder`]
+    /// for structures that help ensure all background tasks are cancelled.
     ///
     /// [`spawn`]: tokio::task::spawn
-    /// [`AbortOnDropSingle`]: crate::common::AbortOnDropSingle
-    /// [`AbortOnDropMany`]: crate::common::AbortOnDropMany
+    /// [`JoinSet`]: tokio::task::JoinSet
+    /// [`SpawnedTask`]: crate::common::SpawnedTask
     /// [`RecordBatchReceiverStreamBuilder`]: 
crate::stream::RecordBatchReceiverStreamBuilder
     ///
     /// # Implementation Examples
diff --git a/datafusion/physical-plan/src/repartition/mod.rs 
b/datafusion/physical-plan/src/repartition/mod.rs
index 07693f747f..a66a929796 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -32,21 +32,20 @@ use futures::{FutureExt, StreamExt};
 use hashbrown::HashMap;
 use log::trace;
 use parking_lot::Mutex;
-use tokio::task::JoinHandle;
 
 use datafusion_common::{arrow_datafusion_err, not_impl_err, DataFusionError, 
Result};
 use datafusion_execution::memory_pool::MemoryConsumer;
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
 
-use crate::common::transpose;
+use crate::common::{transpose, SpawnedTask};
 use crate::hash_utils::create_hashes;
 use crate::metrics::BaselineMetrics;
 use crate::repartition::distributor_channels::{channels, 
partition_aware_channels};
 use crate::sorts::streaming_merge;
 use crate::{DisplayFormatType, ExecutionPlan, Partitioning, Statistics};
 
-use super::common::{AbortOnDropMany, AbortOnDropSingle, 
SharedMemoryReservation};
+use super::common::SharedMemoryReservation;
 use super::expressions::PhysicalSortExpr;
 use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
 use super::{DisplayAs, RecordBatchStream, SendableRecordBatchStream};
@@ -74,7 +73,7 @@ struct RepartitionExecState {
     >,
 
     /// Helper that ensures that that background job is killed once it is no 
longer needed.
-    abort_helper: Arc<AbortOnDropMany<()>>,
+    abort_helper: Arc<Vec<SpawnedTask<()>>>,
 }
 
 /// A utility that can be used to partition batches based on [`Partitioning`]
@@ -522,7 +521,7 @@ impl ExecutionPlan for RepartitionExec {
             }
 
             // launch one async task per *input* partition
-            let mut join_handles = Vec::with_capacity(num_input_partitions);
+            let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
             for i in 0..num_input_partitions {
                 let txs: HashMap<_, _> = state
                     .channels
@@ -534,28 +533,27 @@ impl ExecutionPlan for RepartitionExec {
 
                 let r_metrics = RepartitionMetrics::new(i, partition, 
&self.metrics);
 
-                let input_task: JoinHandle<Result<()>> =
-                    tokio::spawn(Self::pull_from_input(
-                        self.input.clone(),
-                        i,
-                        txs.clone(),
-                        self.partitioning.clone(),
-                        r_metrics,
-                        context.clone(),
-                    ));
+                let input_task = SpawnedTask::spawn(Self::pull_from_input(
+                    self.input.clone(),
+                    i,
+                    txs.clone(),
+                    self.partitioning.clone(),
+                    r_metrics,
+                    context.clone(),
+                ));
 
                 // In a separate task, wait for each input to be done
                 // (and pass along any errors, including panic!s)
-                let join_handle = tokio::spawn(Self::wait_for_task(
-                    AbortOnDropSingle::new(input_task),
+                let wait_for_task = SpawnedTask::spawn(Self::wait_for_task(
+                    input_task,
                     txs.into_iter()
                         .map(|(partition, (tx, _reservation))| (partition, tx))
                         .collect(),
                 ));
-                join_handles.push(join_handle);
+                spawned_tasks.push(wait_for_task);
             }
 
-            state.abort_helper = Arc::new(AbortOnDropMany(join_handles))
+            state.abort_helper = Arc::new(spawned_tasks)
         }
 
         trace!(
@@ -638,7 +636,7 @@ impl RepartitionExec {
             partitioning,
             state: Arc::new(Mutex::new(RepartitionExecState {
                 channels: HashMap::new(),
-                abort_helper: Arc::new(AbortOnDropMany::<()>(vec![])),
+                abort_helper: Arc::new(Vec::new()),
             })),
             metrics: ExecutionPlanMetricsSet::new(),
             preserve_order: false,
@@ -759,12 +757,13 @@ impl RepartitionExec {
     /// complete. Upon error, propagates the errors to all output tx
     /// channels.
     async fn wait_for_task(
-        input_task: AbortOnDropSingle<Result<()>>,
+        input_task: SpawnedTask<Result<()>>,
         txs: HashMap<usize, DistributionSender<MaybeBatch>>,
     ) {
         // wait for completion, and propagate error
         // note we ignore errors on send (.ok) as that means the receiver has 
already shutdown.
-        match input_task.await {
+
+        match input_task.join().await {
             // Error in joining task
             Err(e) => {
                 let e = Arc::new(e);
@@ -813,7 +812,7 @@ struct RepartitionStream {
 
     /// Handle to ensure background tasks are killed when no longer needed.
     #[allow(dead_code)]
-    drop_helper: Arc<AbortOnDropMany<()>>,
+    drop_helper: Arc<Vec<SpawnedTask<()>>>,
 
     /// Memory reservation.
     reservation: SharedMemoryReservation,
@@ -877,7 +876,7 @@ struct PerPartitionStream {
 
     /// Handle to ensure background tasks are killed when no longer needed.
     #[allow(dead_code)]
-    drop_helper: Arc<AbortOnDropMany<()>>,
+    drop_helper: Arc<Vec<SpawnedTask<()>>>,
 
     /// Memory reservation.
     reservation: SharedMemoryReservation,
@@ -1056,6 +1055,7 @@ mod tests {
     }
 
     #[tokio::test]
+    #[allow(clippy::disallowed_methods)]
     async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
         let join_handle: JoinHandle<Result<Vec<Vec<RecordBatch>>>> =
             tokio::spawn(async move {
diff --git a/datafusion/physical-plan/src/sorts/sort.rs 
b/datafusion/physical-plan/src/sorts/sort.rs
index 2d8237011f..84bf3ec415 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -27,7 +27,7 @@ use std::io::BufReader;
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
 
-use crate::common::{spawn_buffered, IPCWriter};
+use crate::common::{spawn_buffered, IPCWriter, SpawnedTask};
 use crate::expressions::PhysicalSortExpr;
 use crate::metrics::{
     BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
@@ -56,7 +56,6 @@ use datafusion_physical_expr::EquivalenceProperties;
 use futures::{StreamExt, TryStreamExt};
 use log::{debug, error, trace};
 use tokio::sync::mpsc::Sender;
-use tokio::task;
 
 struct ExternalSorterMetrics {
     /// metrics
@@ -604,8 +603,8 @@ async fn spill_sorted_batches(
     schema: SchemaRef,
 ) -> Result<()> {
     let path: PathBuf = path.into();
-    let handle = task::spawn_blocking(move || write_sorted(batches, path, 
schema));
-    match handle.await {
+    let task = SpawnedTask::spawn_blocking(move || write_sorted(batches, path, 
schema));
+    match task.join().await {
         Ok(r) => r,
         Err(e) => exec_err!("Error occurred while spilling {e}"),
     }
diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs 
b/datafusion/sqllogictest/bin/sqllogictests.rs
index ffae144eae..41c33deec6 100644
--- a/datafusion/sqllogictest/bin/sqllogictests.rs
+++ b/datafusion/sqllogictest/bin/sqllogictests.rs
@@ -88,6 +88,7 @@ async fn run_tests() -> Result<()> {
     // modifying shared state like `/tmp/`)
     let errors: Vec<_> = futures::stream::iter(read_test_files(&options)?)
         .map(|test_file| {
+            #[allow(clippy::disallowed_methods)] // spawn allowed only in tests
             tokio::task::spawn(async move {
                 println!("Running {:?}", test_file.relative_path);
                 if options.complete {
diff --git a/dev/release/rat_exclude_files.txt 
b/dev/release/rat_exclude_files.txt
index f99d6e15e8..ce5635b6da 100644
--- a/dev/release/rat_exclude_files.txt
+++ b/dev/release/rat_exclude_files.txt
@@ -136,4 +136,5 @@ datafusion/proto/src/generated/prost.rs
 .github/ISSUE_TEMPLATE/feature_request.yml
 .github/workflows/docs.yaml
 **/node_modules/*
-datafusion/wasmtest/pkg/*
\ No newline at end of file
+datafusion/wasmtest/pkg/*
+clippy.toml

Reply via email to