This is an automated email from the ASF dual-hosted git repository.
mneumann pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 02a470f606 Replace AbortOnDrop / AbortDropOnMany with tokio JoinSet
(#6750)
02a470f606 is described below
commit 02a470f6061cce8ee8e57f7af8a6a0e0ddc1571b
Author: Armin Primadi <[email protected]>
AuthorDate: Tue Jul 4 20:00:33 2023 +0700
Replace AbortOnDrop / AbortDropOnMany with tokio JoinSet (#6750)
* Use JoinSet in MemTable
* Fix error handling
* Refactor AbortOnDropSingle in csv physical plan
* Fix csv write physical plan error propagation
* Refactor json write physical plan to use JoinSet
* Refactor parquet write physical plan to use JoinSet
* Refactor collect_partitioned to use JoinSet
* Refactor pull_from_input method to make it easier to read
* Fix typo
---
datafusion/core/src/datasource/memory.rs | 39 +++++++++++---------
.../core/src/datasource/physical_plan/csv.rs | 32 +++++++++-------
.../core/src/datasource/physical_plan/json.rs | 32 +++++++++-------
.../core/src/datasource/physical_plan/parquet.rs | 43 ++++++++++++----------
datafusion/core/src/physical_plan/mod.rs | 41 +++++++++++++++------
.../core/src/physical_plan/repartition/mod.rs | 33 ++++++++---------
6 files changed, 127 insertions(+), 93 deletions(-)
diff --git a/datafusion/core/src/datasource/memory.rs
b/datafusion/core/src/datasource/memory.rs
index 784aa2aff2..5398bb0903 100644
--- a/datafusion/core/src/datasource/memory.rs
+++ b/datafusion/core/src/datasource/memory.rs
@@ -29,12 +29,12 @@ use async_trait::async_trait;
use datafusion_common::SchemaExt;
use datafusion_execution::TaskContext;
use tokio::sync::RwLock;
+use tokio::task::JoinSet;
use crate::datasource::{TableProvider, TableType};
use crate::error::{DataFusionError, Result};
use crate::execution::context::SessionState;
use crate::logical_expr::Expr;
-use crate::physical_plan::common::AbortOnDropSingle;
use crate::physical_plan::insert::{DataSink, InsertExec};
use crate::physical_plan::memory::MemoryExec;
use crate::physical_plan::{common, SendableRecordBatchStream};
@@ -89,26 +89,31 @@ impl MemTable {
let exec = t.scan(state, None, &[], None).await?;
let partition_count = exec.output_partitioning().partition_count();
- let tasks = (0..partition_count)
- .map(|part_i| {
- let task = state.task_ctx();
- let exec = exec.clone();
- let task = tokio::spawn(async move {
- let stream = exec.execute(part_i, task)?;
- common::collect(stream).await
- });
-
- AbortOnDropSingle::new(task)
- })
- // this collect *is needed* so that the join below can
- // switch between tasks
- .collect::<Vec<_>>();
+ let mut join_set = JoinSet::new();
+
+ for part_idx in 0..partition_count {
+ let task = state.task_ctx();
+ let exec = exec.clone();
+ join_set.spawn(async move {
+ let stream = exec.execute(part_idx, task)?;
+ common::collect(stream).await
+ });
+ }
let mut data: Vec<Vec<RecordBatch>> =
Vec::with_capacity(exec.output_partitioning().partition_count());
- for result in futures::future::join_all(tasks).await {
- data.push(result.map_err(|e|
DataFusionError::External(Box::new(e)))??)
+ while let Some(result) = join_set.join_next().await {
+ match result {
+ Ok(res) => data.push(res?),
+ Err(e) => {
+ if e.is_panic() {
+ std::panic::resume_unwind(e.into_panic());
+ } else {
+ unreachable!();
+ }
+ }
+ }
}
let exec = MemoryExec::try_new(&data, schema.clone(), None)?;
diff --git a/datafusion/core/src/datasource/physical_plan/csv.rs
b/datafusion/core/src/datasource/physical_plan/csv.rs
index 027bd1945b..eba51615cd 100644
--- a/datafusion/core/src/datasource/physical_plan/csv.rs
+++ b/datafusion/core/src/datasource/physical_plan/csv.rs
@@ -23,7 +23,6 @@ use crate::datasource::physical_plan::file_stream::{
};
use crate::datasource::physical_plan::FileMeta;
use crate::error::{DataFusionError, Result};
-use crate::physical_plan::common::AbortOnDropSingle;
use crate::physical_plan::expressions::PhysicalSortExpr;
use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::physical_plan::{
@@ -46,7 +45,7 @@ use std::fs;
use std::path::Path;
use std::sync::Arc;
use std::task::Poll;
-use tokio::task::{self, JoinHandle};
+use tokio::task::JoinSet;
/// Execution plan for scanning a CSV file
#[derive(Debug, Clone)]
@@ -331,7 +330,7 @@ pub async fn plan_to_csv(
)));
}
- let mut tasks = vec![];
+ let mut join_set = JoinSet::new();
for i in 0..plan.output_partitioning().partition_count() {
let plan = plan.clone();
let filename = format!("part-{i}.csv");
@@ -340,22 +339,29 @@ pub async fn plan_to_csv(
let mut writer = csv::Writer::new(file);
let stream = plan.execute(i, task_ctx.clone())?;
- let handle: JoinHandle<Result<()>> = task::spawn(async move {
- stream
+ join_set.spawn(async move {
+ let result: Result<()> = stream
.map(|batch| writer.write(&batch?))
.try_collect()
.await
- .map_err(DataFusionError::from)
+ .map_err(DataFusionError::from);
+ result
});
- tasks.push(AbortOnDropSingle::new(handle));
}
- futures::future::join_all(tasks)
- .await
- .into_iter()
- .try_for_each(|result| {
- result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
- })?;
+ while let Some(result) = join_set.join_next().await {
+ match result {
+ Ok(res) => res?, // propagate DataFusion error
+ Err(e) => {
+ if e.is_panic() {
+ std::panic::resume_unwind(e.into_panic());
+ } else {
+ unreachable!();
+ }
+ }
+ }
+ }
+
Ok(())
}
diff --git a/datafusion/core/src/datasource/physical_plan/json.rs
b/datafusion/core/src/datasource/physical_plan/json.rs
index b736fd7839..64f7077660 100644
--- a/datafusion/core/src/datasource/physical_plan/json.rs
+++ b/datafusion/core/src/datasource/physical_plan/json.rs
@@ -22,7 +22,6 @@ use crate::datasource::physical_plan::file_stream::{
};
use crate::datasource::physical_plan::FileMeta;
use crate::error::{DataFusionError, Result};
-use crate::physical_plan::common::AbortOnDropSingle;
use crate::physical_plan::expressions::PhysicalSortExpr;
use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::physical_plan::{
@@ -44,7 +43,7 @@ use std::io::BufReader;
use std::path::Path;
use std::sync::Arc;
use std::task::Poll;
-use tokio::task::{self, JoinHandle};
+use tokio::task::JoinSet;
use super::FileScanConfig;
@@ -266,7 +265,7 @@ pub async fn plan_to_json(
)));
}
- let mut tasks = vec![];
+ let mut join_set = JoinSet::new();
for i in 0..plan.output_partitioning().partition_count() {
let plan = plan.clone();
let filename = format!("part-{i}.json");
@@ -274,22 +273,29 @@ pub async fn plan_to_json(
let file = fs::File::create(path)?;
let mut writer = json::LineDelimitedWriter::new(file);
let stream = plan.execute(i, task_ctx.clone())?;
- let handle: JoinHandle<Result<()>> = task::spawn(async move {
- stream
+ join_set.spawn(async move {
+ let result: Result<()> = stream
.map(|batch| writer.write(&batch?))
.try_collect()
.await
- .map_err(DataFusionError::from)
+ .map_err(DataFusionError::from);
+ result
});
- tasks.push(AbortOnDropSingle::new(handle));
}
- futures::future::join_all(tasks)
- .await
- .into_iter()
- .try_for_each(|result| {
- result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
- })?;
+ while let Some(result) = join_set.join_next().await {
+ match result {
+ Ok(res) => res?, // propagate DataFusion error
+ Err(e) => {
+ if e.is_panic() {
+ std::panic::resume_unwind(e.into_panic());
+ } else {
+ unreachable!();
+ }
+ }
+ }
+ }
+
Ok(())
}
diff --git a/datafusion/core/src/datasource/physical_plan/parquet.rs
b/datafusion/core/src/datasource/physical_plan/parquet.rs
index f538255bc2..96e5ce9fa0 100644
--- a/datafusion/core/src/datasource/physical_plan/parquet.rs
+++ b/datafusion/core/src/datasource/physical_plan/parquet.rs
@@ -31,7 +31,6 @@ use crate::{
execution::context::TaskContext,
physical_optimizer::pruning::PruningPredicate,
physical_plan::{
- common::AbortOnDropSingle,
metrics::{ExecutionPlanMetricsSet, MetricBuilder, MetricsSet},
ordering_equivalence_properties_helper, DisplayFormatType,
ExecutionPlan,
Partitioning, SendableRecordBatchStream, Statistics,
@@ -64,6 +63,7 @@ use parquet::arrow::{ArrowWriter,
ParquetRecordBatchStreamBuilder, ProjectionMas
use parquet::basic::{ConvertedType, LogicalType};
use parquet::file::{metadata::ParquetMetaData, properties::WriterProperties};
use parquet::schema::types::ColumnDescriptor;
+use tokio::task::JoinSet;
mod metrics;
pub mod page_filter;
@@ -701,7 +701,7 @@ pub async fn plan_to_parquet(
)));
}
- let mut tasks = vec![];
+ let mut join_set = JoinSet::new();
for i in 0..plan.output_partitioning().partition_count() {
let plan = plan.clone();
let filename = format!("part-{i}.parquet");
@@ -710,27 +710,30 @@ pub async fn plan_to_parquet(
let mut writer =
ArrowWriter::try_new(file, plan.schema(),
writer_properties.clone())?;
let stream = plan.execute(i, task_ctx.clone())?;
- let handle: tokio::task::JoinHandle<Result<()>> =
- tokio::task::spawn(async move {
- stream
- .map(|batch| {
-
writer.write(&batch?).map_err(DataFusionError::ParquetError)
- })
- .try_collect()
- .await
- .map_err(DataFusionError::from)?;
+ join_set.spawn(async move {
+ stream
+ .map(|batch|
writer.write(&batch?).map_err(DataFusionError::ParquetError))
+ .try_collect()
+ .await
+ .map_err(DataFusionError::from)?;
+
+ writer.close().map_err(DataFusionError::from).map(|_| ())
+ });
+ }
- writer.close().map_err(DataFusionError::from).map(|_| ())
- });
- tasks.push(AbortOnDropSingle::new(handle));
+ while let Some(result) = join_set.join_next().await {
+ match result {
+ Ok(res) => res?,
+ Err(e) => {
+ if e.is_panic() {
+ std::panic::resume_unwind(e.into_panic());
+ } else {
+ unreachable!();
+ }
+ }
+ }
}
- futures::future::join_all(tasks)
- .await
- .into_iter()
- .try_for_each(|result| {
- result.map_err(|e| DataFusionError::Execution(format!("{e}")))?
- })?;
Ok(())
}
diff --git a/datafusion/core/src/physical_plan/mod.rs
b/datafusion/core/src/physical_plan/mod.rs
index 5abecf6b16..7efd5a19ee 100644
--- a/datafusion/core/src/physical_plan/mod.rs
+++ b/datafusion/core/src/physical_plan/mod.rs
@@ -38,6 +38,7 @@ pub use display::{DefaultDisplay, DisplayAs,
DisplayFormatType, VerboseDisplay};
use futures::stream::{Stream, TryStreamExt};
use std::fmt;
use std::fmt::Debug;
+use tokio::task::JoinSet;
use datafusion_common::tree_node::Transformed;
use datafusion_common::DataFusionError;
@@ -445,20 +446,37 @@ pub async fn collect_partitioned(
) -> Result<Vec<Vec<RecordBatch>>> {
let streams = execute_stream_partitioned(plan, context)?;
+ let mut join_set = JoinSet::new();
// Execute the plan and collect the results into batches.
- let handles = streams
- .into_iter()
- .enumerate()
- .map(|(idx, stream)| async move {
- let handle = tokio::task::spawn(stream.try_collect());
- AbortOnDropSingle::new(handle).await.map_err(|e| {
- DataFusionError::Execution(format!(
- "collect_partitioned partition {idx} panicked: {e}"
- ))
- })?
+ streams.into_iter().enumerate().for_each(|(idx, stream)| {
+ join_set.spawn(async move {
+ let result: Result<Vec<RecordBatch>> = stream.try_collect().await;
+ (idx, result)
});
+ });
+
+ let mut batches = vec![];
+ // Note that currently this doesn't identify the task that panicked
+ //
+ // TODO: Replace with
[join_next_with_id](https://docs.rs/tokio/latest/tokio/task/struct.JoinSet.html#method.join_next_with_id)
+ // once it is stable
+ while let Some(result) = join_set.join_next().await {
+ match result {
+ Ok((idx, res)) => batches.push((idx, res?)),
+ Err(e) => {
+ if e.is_panic() {
+ std::panic::resume_unwind(e.into_panic());
+ } else {
+ unreachable!();
+ }
+ }
+ }
+ }
+
+ batches.sort_by_key(|(idx, _)| *idx);
+ let batches = batches.into_iter().map(|(_, batch)| batch).collect();
- futures::future::try_join_all(handles).await
+ Ok(batches)
}
/// Execute the [ExecutionPlan] and return a vec with one stream per output
partition
@@ -713,7 +731,6 @@ pub mod unnest;
pub mod values;
pub mod windows;
-use crate::physical_plan::common::AbortOnDropSingle;
use crate::physical_plan::repartition::RepartitionExec;
use
crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec;
use datafusion_execution::TaskContext;
diff --git a/datafusion/core/src/physical_plan/repartition/mod.rs
b/datafusion/core/src/physical_plan/repartition/mod.rs
index 85225eb471..3c689e97ab 100644
--- a/datafusion/core/src/physical_plan/repartition/mod.rs
+++ b/datafusion/core/src/physical_plan/repartition/mod.rs
@@ -263,7 +263,7 @@ struct RepartitionMetrics {
/// Time in nanos to execute child operator and fetch batches
fetch_time: metrics::Time,
/// Time in nanos to perform repartitioning
- repart_time: metrics::Time,
+ repartition_time: metrics::Time,
/// Time in nanos for sending resulting batches to channels
send_time: metrics::Time,
}
@@ -293,7 +293,7 @@ impl RepartitionMetrics {
Self {
fetch_time,
- repart_time,
+ repartition_time: repart_time,
send_time,
}
}
@@ -407,7 +407,7 @@ impl ExecutionPlan for RepartitionExec {
// note we use a custom channel that ensures there is always
data for each receiver
// but limits the amount of buffering if required.
let (txs, rxs) = channels(num_output_partitions);
- // Clone sender for ech input partitions
+ // Clone sender for each input partition
let txs = txs
.into_iter()
.map(|item| vec![item; num_input_partitions])
@@ -565,34 +565,31 @@ impl RepartitionExec {
/// Pulls data from the specified input plan, feeding it to the
/// output partitions based on the desired partitioning
///
- /// i is the input partition index
- ///
/// txs hold the output sending channels for each output partition
async fn pull_from_input(
input: Arc<dyn ExecutionPlan>,
- i: usize,
- mut txs: HashMap<
+ partition: usize,
+ mut output_channels: HashMap<
usize,
(DistributionSender<MaybeBatch>, SharedMemoryReservation),
>,
partitioning: Partitioning,
- r_metrics: RepartitionMetrics,
+ metrics: RepartitionMetrics,
context: Arc<TaskContext>,
) -> Result<()> {
let mut partitioner =
- BatchPartitioner::try_new(partitioning,
r_metrics.repart_time.clone())?;
+ BatchPartitioner::try_new(partitioning,
metrics.repartition_time.clone())?;
// execute the child operator
- let timer = r_metrics.fetch_time.timer();
- let mut stream = input.execute(i, context)?;
+ let timer = metrics.fetch_time.timer();
+ let mut stream = input.execute(partition, context)?;
timer.done();
- // While there are still outputs to send to, keep
- // pulling inputs
+ // While there are still outputs to send to, keep pulling inputs
let mut batches_until_yield = partitioner.num_partitions();
- while !txs.is_empty() {
+ while !output_channels.is_empty() {
// fetch the next batch
- let timer = r_metrics.fetch_time.timer();
+ let timer = metrics.fetch_time.timer();
let result = stream.next().await;
timer.done();
@@ -606,15 +603,15 @@ impl RepartitionExec {
let (partition, batch) = res?;
let size = batch.get_array_memory_size();
- let timer = r_metrics.send_time.timer();
+ let timer = metrics.send_time.timer();
// if there is still a receiver, send to it
- if let Some((tx, reservation)) = txs.get_mut(&partition) {
+ if let Some((tx, reservation)) =
output_channels.get_mut(&partition) {
reservation.lock().try_grow(size)?;
if tx.send(Some(Ok(batch))).await.is_err() {
// If the other end has hung up, it was an early
shutdown (e.g. LIMIT)
reservation.lock().shrink(size);
- txs.remove(&partition);
+ output_channels.remove(&partition);
}
}
timer.done();