This is an automated email from the ASF dual-hosted git repository.

timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new dc0d35a2 Add Interruptible Query Execution in Jupyter via KeyboardInterrupt Support (#1141)
dc0d35a2 is described below

commit dc0d35a2cdb47f96a04c98d5cf1e7e79b0ba1467
Author: kosiew <[email protected]>
AuthorDate: Mon Jun 16 19:53:40 2025 +0800

    Add Interruptible Query Execution in Jupyter via KeyboardInterrupt Support (#1141)
    
    * fix: enhance error handling in async wait_for_future function
    
    * feat: implement async execution for execution plans in PySessionContext
    
    * fix: improve error message for execution failures in PySessionContext
    
    * fix: enhance error handling and improve execution plan retrieval in 
PyDataFrame
    
    * fix: ensure 'static lifetime for futures in wait_for_future function
    
    * fix: handle potential errors when caching DataFrame and retrieving 
execution plan
    
    * fix: flatten batches in PyDataFrame to ensure proper schema conversion
    
    * fix: correct error handling in batch processing for schema conversion
    
    * fix: flatten nested structure in PyDataFrame to ensure proper RecordBatch 
iteration
    
    * fix: improve error handling in PyDataFrame stream execution
    
    * fix: add utility to get Tokio Runtime with time enabled and update 
wait_for_future to use it
    
    * fix: store result of converting RecordBatches to PyArrow for debugging
    
    * fix: handle error from wait_for_future in PyDataFrame collect method
    
    * fix: propagate error from wait_for_future in collect_partitioned method
    
    * fix: enable IO in Tokio runtime with time support
    
    * main  register_listing_table
    
    * Revert "main  register_listing_table"
    
    This reverts commit 52a5efe2001455a3ad881968d468e5c7538e1ced.
    
    * fix: propagate error correctly from wait_for_future in PySessionContext 
methods
    
    * fix: simplify error handling in PySessionContext by unwrapping 
wait_for_future result
    
    * test: add interruption handling test for long-running queries in 
DataFusion
    
    * test: move test_collect_interrupted to test_dataframe.py
    
    * fix: add const for interval in wait_for_future utility
    
    * fix: use get_tokio_runtime instead of the custom  get_runtime
    
    * Revert "fix: use get_tokio_runtime instead of the custom  get_runtime"
    
    This reverts commit ca2d89289d0a702bbb38f34e88fb78ad61d20647.
    
    * fix: use get_tokio_runtime instead of the custom  get_runtime
    
    * .
    
    * Revert "."
    
    This reverts commit b8ce3e446b74aac7a76f1cc8ce6501b453d4f13c.
    
    * fix: improve query interruption handling in test_collect_interrupted
    
    * fix: ensure proper handling of query interruption in 
test_collect_interrupted
    
    * fix: improve error handling in database table retrieval
    
    * refactor: add helper for async move
    
    * Revert "refactor: add helper for async move"
    
    This reverts commit faabf6dd90ac505934e7cb6dc3b69fddbe89e661.
    
    * move py_err_to_datafusion_err to errors.rs
    
    * add create_csv_read_options
    
    * fix
    
    * create_csv_read_options -> PyDataFusionResult
    
    * revert to before create_csv_read_options
    
    * refactor: simplify file compression type parsing in PySessionContext
    
    * fix: parse_compression_type once only
    
    * add create_ndjson_read_options
    
    * refactor comment for clarity in wait_for_future function
    
    * refactor wait_for_future to avoid spawn
    
    * remove unused py_err_to_datafusion_err function
    
    * add comment to clarify error handling in next method of 
PyRecordBatchStream
    
    * handle error from wait_for_future in PySubstraitSerializer
    
    * clarify comment on future pinning in wait_for_future function
    
    * refactor wait_for_future to use Duration for signal check interval
    
    * handle error from wait_for_future in count method of PyDataFrame
    
    * fix ruff errors
    
    * fix clippy errors
    
    * remove unused get_and_enter_tokio_runtime function and simplify 
wait_for_future
    
    * Refactor async handling in PySessionContext and PyDataFrame
    
    - Simplified async handling by removing unnecessary cloning of strings and 
context in various methods.
    - Streamlined the use of `wait_for_future` to directly handle futures 
without intermediate variables.
    - Improved error handling by directly propagating results from futures.
    - Enhanced readability by reducing boilerplate code in methods related to 
reading and writing data.
    - Updated the `wait_for_future` function to improve signal checking and 
future handling.
    
    * Organize imports in utils.rs for improved readability
    
    * map_err instead of panic
    
    * Fix error handling in async stream execution for PySessionContext and 
PyDataFrame
---
 python/tests/test_dataframe.py | 121 +++++++++++++++++++++++++++++++++++++++++
 src/catalog.rs                 |   2 +-
 src/context.rs                 |  52 +++++++++++-------
 src/dataframe.rs               |  43 +++++++--------
 src/record_batch.rs            |   2 +-
 src/substrait.rs               |  11 ++--
 src/utils.rs                   |  46 +++++++++++-----
 7 files changed, 211 insertions(+), 66 deletions(-)

diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index dd5f962b..64220ce9 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -14,9 +14,12 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import ctypes
 import datetime
 import os
 import re
+import threading
+import time
 from typing import Any
 
 import pyarrow as pa
@@ -2060,3 +2063,121 @@ def test_fill_null_all_null_column(ctx):
     # Check that all nulls were filled
     result = filled_df.collect()[0]
     assert result.column(1).to_pylist() == ["filled", "filled", "filled"]
+
+
+def test_collect_interrupted():
+    """Test that a long-running query can be interrupted with Ctrl-C.
+
+    This test simulates a Ctrl-C keyboard interrupt by raising a KeyboardInterrupt
+    exception in the main thread during a long-running query execution.
+    """
+    # Create a context and a DataFrame with a query that will run for a while
+    ctx = SessionContext()
+
+    # Create a recursive computation that will run for some time
+    batches = []
+    for i in range(10):
+        batch = pa.RecordBatch.from_arrays(
+            [
+                pa.array(list(range(i * 1000, (i + 1) * 1000))),
+                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
+            ],
+            names=["a", "b"],
+        )
+        batches.append(batch)
+
+    # Register tables
+    ctx.register_record_batches("t1", [batches])
+    ctx.register_record_batches("t2", [batches])
+
+    # Create a large join operation that will take time to process
+    df = ctx.sql("""
+        WITH t1_expanded AS (
+            SELECT
+                a,
+                b,
+                CAST(a AS DOUBLE) / 1.5 AS c,
+                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
+            FROM t1
+            CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
+        ),
+        t2_expanded AS (
+            SELECT
+                a,
+                b,
+                CAST(a AS DOUBLE) * 2.5 AS e,
+                CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
+            FROM t2
+            CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
+        )
+        SELECT
+            t1.a, t1.b, t1.c, t1.d,
+            t2.a AS a2, t2.b AS b2, t2.e, t2.f
+        FROM t1_expanded t1
+        JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
+        WHERE t1.a > 100 AND t2.a > 100
+    """)
+
+    # Flag to track if the query was interrupted
+    interrupted = False
+    interrupt_error = None
+    main_thread = threading.main_thread()
+
+    # Shared flag to indicate query execution has started
+    query_started = threading.Event()
+    max_wait_time = 5.0  # Maximum wait time in seconds
+
+    # This function will be run in a separate thread and will raise
+    # KeyboardInterrupt in the main thread
+    def trigger_interrupt():
+        """Poll for query start, then raise KeyboardInterrupt in the main thread"""
+        # Poll for query to start with small sleep intervals
+        start_time = time.time()
+        while not query_started.is_set():
+            time.sleep(0.1)  # Small sleep between checks
+            if time.time() - start_time > max_wait_time:
+                msg = f"Query did not start within {max_wait_time} seconds"
+                raise RuntimeError(msg)
+
+        # Check if thread ID is available
+        thread_id = main_thread.ident
+        if thread_id is None:
+            msg = "Cannot get main thread ID"
+            raise RuntimeError(msg)
+
+        # Use ctypes to raise exception in main thread
+        exception = ctypes.py_object(KeyboardInterrupt)
+        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
+            ctypes.c_long(thread_id), exception
+        )
+        if res != 1:
+            # If res is 0, the thread ID was invalid
+            # If res > 1, we modified multiple threads
+            ctypes.pythonapi.PyThreadState_SetAsyncExc(
+                ctypes.c_long(thread_id), ctypes.py_object(0)
+            )
+            msg = "Failed to raise KeyboardInterrupt in main thread"
+            raise RuntimeError(msg)
+
+    # Start a thread to trigger the interrupt
+    interrupt_thread = threading.Thread(target=trigger_interrupt)
+    # we mark as daemon so the test process can exit even if this thread doesn't finish
+    interrupt_thread.daemon = True
+    interrupt_thread.start()
+
+    # Execute the query and expect it to be interrupted
+    try:
+        # Signal that we're about to start the query
+        query_started.set()
+        df.collect()
+    except KeyboardInterrupt:
+        interrupted = True
+    except Exception as e:
+        interrupt_error = e
+
+    # Assert that the query was interrupted properly
+    if not interrupted:
+        pytest.fail(f"Query was not interrupted; got error: {interrupt_error}")
+
+    # Make sure the interrupt thread has finished
+    interrupt_thread.join(timeout=1.0)
diff --git a/src/catalog.rs b/src/catalog.rs
index 1e189a5a..83f8d08c 100644
--- a/src/catalog.rs
+++ b/src/catalog.rs
@@ -97,7 +97,7 @@ impl PyDatabase {
     }
 
     fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
-        if let Some(table) = wait_for_future(py, self.database.table(name))? {
+        if let Some(table) = wait_for_future(py, self.database.table(name))?? {
             Ok(PyTable::new(table))
         } else {
             Err(PyDataFusionError::Common(format!(
diff --git a/src/context.rs b/src/context.rs
index cc3d8e8e..b0af566e 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -34,7 +34,7 @@ use pyo3::prelude::*;
 use crate::catalog::{PyCatalog, PyTable};
 use crate::dataframe::PyDataFrame;
 use crate::dataset::Dataset;
-use crate::errors::{py_datafusion_err, PyDataFusionResult};
+use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
 use crate::expr::sort_expr::PySortExpr;
 use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::PyRecordBatchStream;
@@ -375,7 +375,7 @@ impl PySessionContext {
             None => {
                 let state = self.ctx.state();
                 let schema = options.infer_schema(&state, &table_path);
-                wait_for_future(py, schema)?
+                wait_for_future(py, schema)??
             }
         };
         let config = ListingTableConfig::new(table_path)
@@ -400,7 +400,7 @@ impl PySessionContext {
     /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
     pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
         let result = self.ctx.sql(query);
-        let df = wait_for_future(py, result)?;
+        let df = wait_for_future(py, result)??;
         Ok(PyDataFrame::new(df))
     }
 
@@ -417,7 +417,7 @@ impl PySessionContext {
             SQLOptions::new()
         };
         let result = self.ctx.sql_with_options(query, options);
-        let df = wait_for_future(py, result)?;
+        let df = wait_for_future(py, result)??;
         Ok(PyDataFrame::new(df))
     }
 
@@ -451,7 +451,7 @@ impl PySessionContext {
 
         self.ctx.register_table(&*table_name, Arc::new(table))?;
 
-        let table = wait_for_future(py, self._table(&table_name))?;
+        let table = wait_for_future(py, self._table(&table_name))??;
 
         let df = PyDataFrame::new(table);
         Ok(df)
@@ -650,7 +650,7 @@ impl PySessionContext {
             .collect();
 
         let result = self.ctx.register_parquet(name, path, options);
-        wait_for_future(py, result)?;
+        wait_for_future(py, result)??;
         Ok(())
     }
 
@@ -693,11 +693,11 @@ impl PySessionContext {
         if path.is_instance_of::<PyList>() {
             let paths = path.extract::<Vec<String>>()?;
             let result = self.register_csv_from_multiple_paths(name, paths, options);
-            wait_for_future(py, result)?;
+            wait_for_future(py, result)??;
         } else {
             let path = path.extract::<String>()?;
             let result = self.ctx.register_csv(name, &path, options);
-            wait_for_future(py, result)?;
+            wait_for_future(py, result)??;
         }
 
         Ok(())
@@ -734,7 +734,7 @@ impl PySessionContext {
         options.schema = schema.as_ref().map(|x| &x.0);
 
         let result = self.ctx.register_json(name, path, options);
-        wait_for_future(py, result)?;
+        wait_for_future(py, result)??;
 
         Ok(())
     }
@@ -764,7 +764,7 @@ impl PySessionContext {
         options.schema = schema.as_ref().map(|x| &x.0);
 
         let result = self.ctx.register_avro(name, path, options);
-        wait_for_future(py, result)?;
+        wait_for_future(py, result)??;
 
         Ok(())
     }
@@ -825,9 +825,19 @@ impl PySessionContext {
     }
 
     pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
-        let x = wait_for_future(py, self.ctx.table(name))
+        let res = wait_for_future(py, self.ctx.table(name))
             .map_err(|e| PyKeyError::new_err(e.to_string()))?;
-        Ok(PyDataFrame::new(x))
+        match res {
+            Ok(df) => Ok(PyDataFrame::new(df)),
+            Err(e) => {
+                if let datafusion::error::DataFusionError::Plan(msg) = &e {
+                    if msg.contains("No table named") {
+                        return Err(PyKeyError::new_err(msg.to_string()));
+                    }
+                }
+                Err(py_datafusion_err(e))
+            }
+        }
     }
 
     pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
@@ -865,10 +875,10 @@ impl PySessionContext {
         let df = if let Some(schema) = schema {
             options.schema = Some(&schema.0);
             let result = self.ctx.read_json(path, options);
-            wait_for_future(py, result)?
+            wait_for_future(py, result)??
         } else {
             let result = self.ctx.read_json(path, options);
-            wait_for_future(py, result)?
+            wait_for_future(py, result)??
         };
         Ok(PyDataFrame::new(df))
     }
@@ -915,12 +925,12 @@ impl PySessionContext {
             let paths = path.extract::<Vec<String>>()?;
             let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
             let result = self.ctx.read_csv(paths, options);
-            let df = PyDataFrame::new(wait_for_future(py, result)?);
+            let df = PyDataFrame::new(wait_for_future(py, result)??);
             Ok(df)
         } else {
             let path = path.extract::<String>()?;
             let result = self.ctx.read_csv(path, options);
-            let df = PyDataFrame::new(wait_for_future(py, result)?);
+            let df = PyDataFrame::new(wait_for_future(py, result)??);
             Ok(df)
         }
     }
@@ -958,7 +968,7 @@ impl PySessionContext {
             .collect();
 
         let result = self.ctx.read_parquet(path, options);
-        let df = PyDataFrame::new(wait_for_future(py, result)?);
+        let df = PyDataFrame::new(wait_for_future(py, result)??);
         Ok(df)
     }
 
@@ -978,10 +988,10 @@ impl PySessionContext {
         let df = if let Some(schema) = schema {
             options.schema = Some(&schema.0);
             let read_future = self.ctx.read_avro(path, options);
-            wait_for_future(py, read_future)?
+            wait_for_future(py, read_future)??
         } else {
             let read_future = self.ctx.read_avro(path, options);
-            wait_for_future(py, read_future)?
+            wait_for_future(py, read_future)??
         };
         Ok(PyDataFrame::new(df))
     }
@@ -1021,8 +1031,8 @@ impl PySessionContext {
         let plan = plan.plan.clone();
         let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
             rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
-        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
-        Ok(PyRecordBatchStream::new(stream?))
+        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
+        Ok(PyRecordBatchStream::new(stream))
     }
 }
 
diff --git a/src/dataframe.rs b/src/dataframe.rs
index ece8c4e0..7711a078 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -43,7 +43,7 @@ use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
 use tokio::task::JoinHandle;
 
 use crate::catalog::PyTable;
-use crate::errors::{py_datafusion_err, PyDataFusionError};
+use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError};
 use crate::expr::sort_expr::to_sort_expressions;
 use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::PyRecordBatchStream;
@@ -233,7 +233,7 @@ impl PyDataFrame {
         let (batches, has_more) = wait_for_future(
             py,
             collect_record_batches_to_display(self.df.as_ref().clone(), config),
-        )?;
+        )??;
         if batches.is_empty() {
             // This should not be reached, but do it for safety since we index into the vector below
             return Ok("No data to display".to_string());
@@ -256,7 +256,7 @@ impl PyDataFrame {
         let (batches, has_more) = wait_for_future(
             py,
             collect_record_batches_to_display(self.df.as_ref().clone(), config),
-        )?;
+        )??;
         if batches.is_empty() {
             // This should not be reached, but do it for safety since we index into the vector below
             return Ok("No data to display".to_string());
@@ -288,7 +288,7 @@ impl PyDataFrame {
     /// Calculate summary statistics for a DataFrame
     fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
         let df = self.df.as_ref().clone();
-        let stat_df = wait_for_future(py, df.describe())?;
+        let stat_df = wait_for_future(py, df.describe())??;
         Ok(Self::new(stat_df))
     }
 
@@ -391,7 +391,7 @@ impl PyDataFrame {
     /// Unless some order is specified in the plan, there is no
     /// guarantee of the order of the result.
     fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
-        let batches = wait_for_future(py, self.df.as_ref().clone().collect())
+        let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
             .map_err(PyDataFusionError::from)?;
         // cannot use PyResult<Vec<RecordBatch>> return type due to
         // https://github.com/PyO3/pyo3/issues/1813
@@ -400,14 +400,14 @@ impl PyDataFrame {
 
     /// Cache DataFrame.
     fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
-        let df = wait_for_future(py, self.df.as_ref().clone().cache())?;
+        let df = wait_for_future(py, self.df.as_ref().clone().cache())??;
         Ok(Self::new(df))
     }
 
     /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
     /// maintaining the input partitioning.
     fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
-        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())
+        let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
             .map_err(PyDataFusionError::from)?;
 
         batches
@@ -511,7 +511,7 @@ impl PyDataFrame {
 
     /// Get the execution plan for this `DataFrame`
     fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
-        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())?;
+        let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??;
         Ok(plan.into())
     }
 
@@ -624,7 +624,7 @@ impl PyDataFrame {
                 DataFrameWriteOptions::new(),
                 Some(csv_options),
             ),
-        )?;
+        )??;
         Ok(())
     }
 
@@ -685,7 +685,7 @@ impl PyDataFrame {
                 DataFrameWriteOptions::new(),
                 Option::from(options),
             ),
-        )?;
+        )??;
         Ok(())
     }
 
@@ -697,7 +697,7 @@ impl PyDataFrame {
                 .as_ref()
                 .clone()
                 .write_json(path, DataFrameWriteOptions::new(), None),
-        )?;
+        )??;
         Ok(())
     }
 
@@ -720,7 +720,7 @@ impl PyDataFrame {
         py: Python<'py>,
         requested_schema: Option<Bound<'py, PyCapsule>>,
     ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
-        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())?;
+        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())??;
         let mut schema: Schema = self.df.schema().to_owned().into();
 
         if let Some(schema_capsule) = requested_schema {
@@ -753,8 +753,8 @@ impl PyDataFrame {
         let df = self.df.as_ref().clone();
         let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
             rt.spawn(async move { df.execute_stream().await });
-        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
-        Ok(PyRecordBatchStream::new(stream?))
+        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
+        Ok(PyRecordBatchStream::new(stream))
     }
 
     fn execute_stream_partitioned(&self, py: Python) -> PyResult<Vec<PyRecordBatchStream>> {
@@ -763,14 +763,11 @@ impl PyDataFrame {
         let df = self.df.as_ref().clone();
         let fut: JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
             rt.spawn(async move { df.execute_stream_partitioned().await });
-        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
+        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })?
+            .map_err(py_datafusion_err)?
+            .map_err(py_datafusion_err)?;
 
-        match stream {
-            Ok(batches) => Ok(batches.into_iter().map(PyRecordBatchStream::new).collect()),
-            _ => Err(PyValueError::new_err(
-                "Unable to execute stream partitioned",
-            )),
-        }
+        Ok(stream.into_iter().map(PyRecordBatchStream::new).collect())
     }
 
     /// Convert to pandas dataframe with pyarrow
@@ -815,7 +812,7 @@ impl PyDataFrame {
 
     // Executes this DataFrame to get the total number of rows.
     fn count(&self, py: Python) -> PyDataFusionResult<usize> {
-        Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
+        Ok(wait_for_future(py, self.df.as_ref().clone().count())??)
     }
 
     /// Fill null values with a specified value for specific columns
@@ -841,7 +838,7 @@ impl PyDataFrame {
 /// Print DataFrame
 fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
     // Get string representation of record batches
-    let batches = wait_for_future(py, df.collect())?;
+    let batches = wait_for_future(py, df.collect())??;
     let batches_as_string = pretty::pretty_format_batches(&batches);
     let result = match batches_as_string {
         Ok(batch) => format!("DataFrame()\n{batch}"),
diff --git a/src/record_batch.rs b/src/record_batch.rs
index ec61c263..a85f0542 100644
--- a/src/record_batch.rs
+++ b/src/record_batch.rs
@@ -63,7 +63,7 @@ impl PyRecordBatchStream {
 impl PyRecordBatchStream {
     fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> {
         let stream = self.stream.clone();
-        wait_for_future(py, next_stream(stream, true))
+        wait_for_future(py, next_stream(stream, true))?
     }
 
     fn __next__(&mut self, py: Python) -> PyResult<PyRecordBatch> {
diff --git a/src/substrait.rs b/src/substrait.rs
index 1fefc0bb..4da3738f 100644
--- a/src/substrait.rs
+++ b/src/substrait.rs
@@ -72,7 +72,7 @@ impl PySubstraitSerializer {
         path: &str,
         py: Python,
     ) -> PyDataFusionResult<()> {
-        wait_for_future(py, serializer::serialize(sql, &ctx.ctx, path))?;
+        wait_for_future(py, serializer::serialize(sql, &ctx.ctx, path))??;
         Ok(())
     }
 
@@ -94,19 +94,20 @@ impl PySubstraitSerializer {
         ctx: PySessionContext,
         py: Python,
     ) -> PyDataFusionResult<PyObject> {
-        let proto_bytes: Vec<u8> = wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))?;
+        let proto_bytes: Vec<u8> =
+            wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))??;
         Ok(PyBytes::new(py, &proto_bytes).into())
     }
 
     #[staticmethod]
     pub fn deserialize(path: &str, py: Python) -> PyDataFusionResult<PyPlan> {
-        let plan = wait_for_future(py, serializer::deserialize(path))?;
+        let plan = wait_for_future(py, serializer::deserialize(path))??;
         Ok(PyPlan { plan: *plan })
     }
 
     #[staticmethod]
     pub fn deserialize_bytes(proto_bytes: Vec<u8>, py: Python) -> PyDataFusionResult<PyPlan> {
-        let plan = wait_for_future(py, serializer::deserialize_bytes(proto_bytes))?;
+        let plan = wait_for_future(py, serializer::deserialize_bytes(proto_bytes))??;
         Ok(PyPlan { plan: *plan })
     }
 }
@@ -143,7 +144,7 @@ impl PySubstraitConsumer {
     ) -> PyDataFusionResult<PyLogicalPlan> {
         let session_state = ctx.ctx.state();
         let result = consumer::from_substrait_plan(&session_state, &plan.plan);
-        let logical_plan = wait_for_future(py, result)?;
+        let logical_plan = wait_for_future(py, result)??;
         Ok(PyLogicalPlan::new(logical_plan))
     }
 }
diff --git a/src/utils.rs b/src/utils.rs
index 0a24ab25..90d65438 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -15,19 +15,18 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use crate::common::data_type::PyScalarValue;
-use crate::errors::{PyDataFusionError, PyDataFusionResult};
-use crate::TokioRuntime;
-use datafusion::common::ScalarValue;
-use datafusion::execution::context::SessionContext;
-use datafusion::logical_expr::Volatility;
-use pyo3::exceptions::PyValueError;
+use crate::{
+    common::data_type::PyScalarValue,
+    errors::{PyDataFusionError, PyDataFusionResult},
+    TokioRuntime,
+};
+use datafusion::{
+    common::ScalarValue, execution::context::SessionContext, logical_expr::Volatility,
+};
 use pyo3::prelude::*;
-use pyo3::types::PyCapsule;
-use std::future::Future;
-use std::sync::OnceLock;
-use tokio::runtime::Runtime;
-
+use pyo3::{exceptions::PyValueError, types::PyCapsule};
+use std::{future::Future, sync::OnceLock, time::Duration};
+use tokio::{runtime::Runtime, time::sleep};
 /// Utility to get the Tokio Runtime from Python
 #[inline]
 pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
@@ -47,14 +46,31 @@ pub(crate) fn get_global_ctx() -> &'static SessionContext {
     CTX.get_or_init(SessionContext::new)
 }
 
-/// Utility to collect rust futures with GIL released
-pub fn wait_for_future<F>(py: Python, f: F) -> F::Output
+/// Utility to collect rust futures with GIL released and respond to
+/// Python interrupts such as ``KeyboardInterrupt``. If a signal is
+/// received while the future is running, the future is aborted and the
+/// corresponding Python exception is raised.
+pub fn wait_for_future<F>(py: Python, fut: F) -> PyResult<F::Output>
 where
     F: Future + Send,
     F::Output: Send,
 {
     let runtime: &Runtime = &get_tokio_runtime().0;
-    py.allow_threads(|| runtime.block_on(f))
+    const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);
+
+    py.allow_threads(|| {
+        runtime.block_on(async {
+            tokio::pin!(fut);
+            loop {
+                tokio::select! {
+                    res = &mut fut => break Ok(res),
+                    _ = sleep(INTERVAL_CHECK_SIGNALS) => {
+                        Python::with_gil(|py| py.check_signals())?;
+                    }
+                }
+            }
+        })
+    })
 }
 
 pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to