This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new fa2bb6c4e8 Extract ReceiverStreamBuilder (#7817)
fa2bb6c4e8 is described below

commit fa2bb6c4e80d80e3d1d26ce85f7c21232f036dd5
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Mon Oct 16 10:02:05 2023 +0100

    Extract ReceiverStreamBuilder (#7817)
    
    * Extract ReceiverStreamBuilder
    
    * Docs and format
    
    * Update datafusion/physical-plan/src/stream.rs
    
    * fmt
    
    * Undo changes to testing pin
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/physical-plan/src/stream.rs | 232 +++++++++++++++++++--------------
 1 file changed, 132 insertions(+), 100 deletions(-)

diff --git a/datafusion/physical-plan/src/stream.rs b/datafusion/physical-plan/src/stream.rs
index a3fb856c32..fdf32620ca 100644
--- a/datafusion/physical-plan/src/stream.rs
+++ b/datafusion/physical-plan/src/stream.rs
@@ -38,6 +38,124 @@ use tokio::task::JoinSet;
 use super::metrics::BaselineMetrics;
 use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
 
+/// Creates a stream from a collection of producing tasks, routing panics to the stream.
+///
+/// Note that this is similar to  [`ReceiverStream` from tokio-stream], with the differences being:
+///
+/// 1. Methods to bound and "detach"  tasks (`spawn()` and `spawn_blocking()`).
+///
+/// 2. Propagates panics, whereas the `tokio` version doesn't propagate panics to the receiver.
+///
+/// 3. Automatically cancels any outstanding tasks when the receiver stream is dropped.
+///
+/// [`ReceiverStream` from tokio-stream]: https://docs.rs/tokio-stream/latest/tokio_stream/wrappers/struct.ReceiverStream.html
+
+pub(crate) struct ReceiverStreamBuilder<O> {
+    tx: Sender<Result<O>>,
+    rx: Receiver<Result<O>>,
+    join_set: JoinSet<Result<()>>,
+}
+
+impl<O: Send + 'static> ReceiverStreamBuilder<O> {
+    /// create new channels with the specified buffer size
+    pub fn new(capacity: usize) -> Self {
+        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
+
+        Self {
+            tx,
+            rx,
+            join_set: JoinSet::new(),
+        }
+    }
+
+    /// Get a handle for sending data to the output
+    pub fn tx(&self) -> Sender<Result<O>> {
+        self.tx.clone()
+    }
+
+    /// Spawn task that will be aborted if this builder (or the stream
+    /// built from it) are dropped
+    pub fn spawn<F>(&mut self, task: F)
+    where
+        F: Future<Output = Result<()>>,
+        F: Send + 'static,
+    {
+        self.join_set.spawn(task);
+    }
+
+    /// Spawn a blocking task that will be aborted if this builder (or the stream
+    /// built from it) are dropped
+    ///
+    /// this is often used to spawn tasks that write to the sender
+    /// retrieved from `Self::tx`
+    pub fn spawn_blocking<F>(&mut self, f: F)
+    where
+        F: FnOnce() -> Result<()>,
+        F: Send + 'static,
+    {
+        self.join_set.spawn_blocking(f);
+    }
+
+    /// Create a stream of all data written to `tx`
+    pub fn build(self) -> BoxStream<'static, Result<O>> {
+        let Self {
+            tx,
+            rx,
+            mut join_set,
+        } = self;
+
+        // don't need tx
+        drop(tx);
+
+        // future that checks the result of the join set, and propagates panic if seen
+        let check = async move {
+            while let Some(result) = join_set.join_next().await {
+                match result {
+                    Ok(task_result) => {
+                        match task_result {
+                            // nothing to report
+                            Ok(_) => continue,
+                            // This means a blocking task error
+                            Err(e) => {
+                                return Some(exec_err!("Spawned Task error: 
{e}"));
+                            }
+                        }
+                    }
+                    // This means a tokio task error, likely a panic
+                    Err(e) => {
+                        if e.is_panic() {
+                            // resume on the main thread
+                            std::panic::resume_unwind(e.into_panic());
+                        } else {
+                            // This should only occur if the task is
+                            // cancelled, which would only occur if
+                            // the JoinSet were aborted, which in turn
+                            // would imply that the receiver has been
+                            // dropped and this code is not running
+                            return Some(internal_err!("Non Panic Task error: 
{e}"));
+                        }
+                    }
+                }
+            }
+            None
+        };
+
+        let check_stream = futures::stream::once(check)
+            // unwrap Option / only return the error
+            .filter_map(|item| async move { item });
+
+        // Convert the receiver into a stream
+        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
+            let next_item = rx.recv().await;
+            next_item.map(|next_item| (next_item, rx))
+        });
+
+        // Merge the streams together so whichever is ready first
+        // produces the batch
+        futures::stream::select(rx_stream, check_stream).boxed()
+    }
+}
+
 /// Builder for [`RecordBatchReceiverStream`] that propagates errors
 /// and panic's correctly.
 ///
@@ -47,28 +165,22 @@ use super::{ExecutionPlan, RecordBatchStream, SendableRecordBatchStream};
 ///
 /// This also handles propagating panic`s and canceling the tasks.
 pub struct RecordBatchReceiverStreamBuilder {
-    tx: Sender<Result<RecordBatch>>,
-    rx: Receiver<Result<RecordBatch>>,
     schema: SchemaRef,
-    join_set: JoinSet<Result<()>>,
+    inner: ReceiverStreamBuilder<RecordBatch>,
 }
 
 impl RecordBatchReceiverStreamBuilder {
     /// create new channels with the specified buffer size
     pub fn new(schema: SchemaRef, capacity: usize) -> Self {
-        let (tx, rx) = tokio::sync::mpsc::channel(capacity);
-
         Self {
-            tx,
-            rx,
             schema,
-            join_set: JoinSet::new(),
+            inner: ReceiverStreamBuilder::new(capacity),
         }
     }
 
-    /// Get a handle for sending [`RecordBatch`]es to the output
+    /// Get a handle for sending [`RecordBatch`] to the output
     pub fn tx(&self) -> Sender<Result<RecordBatch>> {
-        self.tx.clone()
+        self.inner.tx()
     }
 
     /// Spawn task that will be aborted if this builder (or the stream
@@ -81,7 +193,7 @@ impl RecordBatchReceiverStreamBuilder {
         F: Future<Output = Result<()>>,
         F: Send + 'static,
     {
-        self.join_set.spawn(task);
+        self.inner.spawn(task)
     }
 
     /// Spawn a blocking task that will be aborted if this builder (or the stream
@@ -94,7 +206,7 @@ impl RecordBatchReceiverStreamBuilder {
         F: FnOnce() -> Result<()>,
         F: Send + 'static,
     {
-        self.join_set.spawn_blocking(f);
+        self.inner.spawn_blocking(f)
     }
 
     /// runs the input_partition of the `input` ExecutionPlan on the
@@ -110,7 +222,7 @@ impl RecordBatchReceiverStreamBuilder {
     ) {
         let output = self.tx();
 
-        self.spawn(async move {
+        self.inner.spawn(async move {
             let mut stream = match input.execute(partition, context) {
                 Err(e) => {
                     // If send fails, the plan being torn down, there
@@ -155,80 +267,17 @@ impl RecordBatchReceiverStreamBuilder {
         });
     }
 
-    /// Create a stream of all `RecordBatch`es written to `tx`
+    /// Create a stream of all [`RecordBatch`] written to `tx`
     pub fn build(self) -> SendableRecordBatchStream {
-        let Self {
-            tx,
-            rx,
-            schema,
-            mut join_set,
-        } = self;
-
-        // don't need tx
-        drop(tx);
-
-        // future that checks the result of the join set, and propagates panic if seen
-        let check = async move {
-            while let Some(result) = join_set.join_next().await {
-                match result {
-                    Ok(task_result) => {
-                        match task_result {
-                            // nothing to report
-                            Ok(_) => continue,
-                            // This means a blocking task error
-                            Err(e) => {
-                                return Some(exec_err!("Spawned Task error: 
{e}"));
-                            }
-                        }
-                    }
-                    // This means a tokio task error, likely a panic
-                    Err(e) => {
-                        if e.is_panic() {
-                            // resume on the main thread
-                            std::panic::resume_unwind(e.into_panic());
-                        } else {
-                            // This should only occur if the task is
-                            // cancelled, which would only occur if
-                            // the JoinSet were aborted, which in turn
-                            // would imply that the receiver has been
-                            // dropped and this code is not running
-                            return Some(internal_err!("Non Panic Task error: 
{e}"));
-                        }
-                    }
-                }
-            }
-            None
-        };
-
-        let check_stream = futures::stream::once(check)
-            // unwrap Option / only return the error
-            .filter_map(|item| async move { item });
-
-        // Convert the receiver into a stream
-        let rx_stream = futures::stream::unfold(rx, |mut rx| async move {
-            let next_item = rx.recv().await;
-            next_item.map(|next_item| (next_item, rx))
-        });
-
-        // Merge the streams together so whichever is ready first
-        // produces the batch
-        let inner = futures::stream::select(rx_stream, check_stream).boxed();
-
-        Box::pin(RecordBatchReceiverStream { schema, inner })
+        Box::pin(RecordBatchStreamAdapter::new(
+            self.schema,
+            self.inner.build(),
+        ))
     }
 }
 
-/// A [`SendableRecordBatchStream`] that combines [`RecordBatch`]es from multiple inputs,
-/// on new tokio Tasks,  increasing the potential parallelism.
-///
-/// This structure also handles propagating panics and cancelling the
-/// underlying tasks correctly.
-///
-/// Use [`Self::builder`] to construct one.
-pub struct RecordBatchReceiverStream {
-    schema: SchemaRef,
-    inner: BoxStream<'static, Result<RecordBatch>>,
-}
+#[doc(hidden)]
+pub struct RecordBatchReceiverStream {}
 
 impl RecordBatchReceiverStream {
     /// Create a builder with an internal buffer of capacity batches.
@@ -240,23 +289,6 @@ impl RecordBatchReceiverStream {
     }
 }
 
-impl Stream for RecordBatchReceiverStream {
-    type Item = Result<RecordBatch>;
-
-    fn poll_next(
-        mut self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Option<Self::Item>> {
-        self.inner.poll_next_unpin(cx)
-    }
-}
-
-impl RecordBatchStream for RecordBatchReceiverStream {
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
-}
-
 pin_project! {
     /// Combines a [`Stream`] with a [`SchemaRef`] implementing
     /// [`RecordBatchStream`] for the combination

Reply via email to