This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new dc0d35a2 Add Interruptible Query Execution in Jupyter via
KeyboardInterrupt Support (#1141)
dc0d35a2 is described below
commit dc0d35a2cdb47f96a04c98d5cf1e7e79b0ba1467
Author: kosiew <[email protected]>
AuthorDate: Mon Jun 16 19:53:40 2025 +0800
Add Interruptible Query Execution in Jupyter via KeyboardInterrupt Support
(#1141)
* fix: enhance error handling in async wait_for_future function
* feat: implement async execution for execution plans in PySessionContext
* fix: improve error message for execution failures in PySessionContext
* fix: enhance error handling and improve execution plan retrieval in
PyDataFrame
* fix: ensure 'static lifetime for futures in wait_for_future function
* fix: handle potential errors when caching DataFrame and retrieving
execution plan
* fix: flatten batches in PyDataFrame to ensure proper schema conversion
* fix: correct error handling in batch processing for schema conversion
* fix: flatten nested structure in PyDataFrame to ensure proper RecordBatch
iteration
* fix: improve error handling in PyDataFrame stream execution
* fix: add utility to get Tokio Runtime with time enabled and update
wait_for_future to use it
* fix: store result of converting RecordBatches to PyArrow for debugging
* fix: handle error from wait_for_future in PyDataFrame collect method
* fix: propagate error from wait_for_future in collect_partitioned method
* fix: enable IO in Tokio runtime with time support
* main register_listing_table
* Revert "main register_listing_table"
This reverts commit 52a5efe2001455a3ad881968d468e5c7538e1ced.
* fix: propagate error correctly from wait_for_future in PySessionContext
methods
* fix: simplify error handling in PySessionContext by unwrapping
wait_for_future result
* test: add interruption handling test for long-running queries in
DataFusion
* test: move test_collect_interrupted to test_dataframe.py
* fix: add const for interval in wait_for_future utility
* fix: use get_tokio_runtime instead of the custom get_runtime
* Revert "fix: use get_tokio_runtime instead of the custom get_runtime"
This reverts commit ca2d89289d0a702bbb38f34e88fb78ad61d20647.
* fix: use get_tokio_runtime instead of the custom get_runtime
* .
* Revert "."
This reverts commit b8ce3e446b74aac7a76f1cc8ce6501b453d4f13c.
* fix: improve query interruption handling in test_collect_interrupted
* fix: ensure proper handling of query interruption in
test_collect_interrupted
* fix: improve error handling in database table retrieval
* refactor: add helper for async move
* Revert "refactor: add helper for async move"
This reverts commit faabf6dd90ac505934e7cb6dc3b69fddbe89e661.
* move py_err_to_datafusion_err to errors.rs
* add create_csv_read_options
* fix
* create_csv_read_options -> PyDataFusionResult
* revert to before create_csv_read_options
* refactor: simplify file compression type parsing in PySessionContext
* fix: parse_compression_type once only
* add create_ndjson_read_options
* refactor comment for clarity in wait_for_future function
* refactor wait_for_future to avoid spawn
* remove unused py_err_to_datafusion_err function
* add comment to clarify error handling in next method of
PyRecordBatchStream
* handle error from wait_for_future in PySubstraitSerializer
* clarify comment on future pinning in wait_for_future function
* refactor wait_for_future to use Duration for signal check interval
* handle error from wait_for_future in count method of PyDataFrame
* fix ruff errors
* fix clippy errors
* remove unused get_and_enter_tokio_runtime function and simplify
wait_for_future
* Refactor async handling in PySessionContext and PyDataFrame
- Simplified async handling by removing unnecessary cloning of strings and
context in various methods.
- Streamlined the use of `wait_for_future` to directly handle futures
without intermediate variables.
- Improved error handling by directly propagating results from futures.
- Enhanced readability by reducing boilerplate code in methods related to
reading and writing data.
- Updated the `wait_for_future` function to improve signal checking and
future handling.
* Organize imports in utils.rs for improved readability
* map_err instead of panic
* Fix error handling in async stream execution for PySessionContext and
PyDataFrame
---
python/tests/test_dataframe.py | 121 +++++++++++++++++++++++++++++++++++++++++
src/catalog.rs | 2 +-
src/context.rs | 52 +++++++++++-------
src/dataframe.rs | 43 +++++++--------
src/record_batch.rs | 2 +-
src/substrait.rs | 11 ++--
src/utils.rs | 46 +++++++++++-----
7 files changed, 211 insertions(+), 66 deletions(-)
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index dd5f962b..64220ce9 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -14,9 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import ctypes
import datetime
import os
import re
+import threading
+import time
from typing import Any
import pyarrow as pa
@@ -2060,3 +2063,121 @@ def test_fill_null_all_null_column(ctx):
# Check that all nulls were filled
result = filled_df.collect()[0]
assert result.column(1).to_pylist() == ["filled", "filled", "filled"]
+
+
+def test_collect_interrupted():
+ """Test that a long-running query can be interrupted with Ctrl-C.
+
+ This test simulates a Ctrl-C keyboard interrupt by raising a KeyboardInterrupt
+ exception in the main thread during a long-running query execution.
+ """
+ # Create a context and a DataFrame with a query that will run for a while
+ ctx = SessionContext()
+
+ # Create a recursive computation that will run for some time
+ batches = []
+ for i in range(10):
+ batch = pa.RecordBatch.from_arrays(
+ [
+ pa.array(list(range(i * 1000, (i + 1) * 1000))),
+ pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
+ ],
+ names=["a", "b"],
+ )
+ batches.append(batch)
+
+ # Register tables
+ ctx.register_record_batches("t1", [batches])
+ ctx.register_record_batches("t2", [batches])
+
+ # Create a large join operation that will take time to process
+ df = ctx.sql("""
+ WITH t1_expanded AS (
+ SELECT
+ a,
+ b,
+ CAST(a AS DOUBLE) / 1.5 AS c,
+ CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
+ FROM t1
+ CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
+ ),
+ t2_expanded AS (
+ SELECT
+ a,
+ b,
+ CAST(a AS DOUBLE) * 2.5 AS e,
+ CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
+ FROM t2
+ CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
+ )
+ SELECT
+ t1.a, t1.b, t1.c, t1.d,
+ t2.a AS a2, t2.b AS b2, t2.e, t2.f
+ FROM t1_expanded t1
+ JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
+ WHERE t1.a > 100 AND t2.a > 100
+ """)
+
+ # Flag to track if the query was interrupted
+ interrupted = False
+ interrupt_error = None
+ main_thread = threading.main_thread()
+
+ # Shared flag to indicate query execution has started
+ query_started = threading.Event()
+ max_wait_time = 5.0 # Maximum wait time in seconds
+
+ # This function will be run in a separate thread and will raise
+ # KeyboardInterrupt in the main thread
+ def trigger_interrupt():
+ """Poll for query start, then raise KeyboardInterrupt in the main thread"""
+ # Poll for query to start with small sleep intervals
+ start_time = time.time()
+ while not query_started.is_set():
+ time.sleep(0.1) # Small sleep between checks
+ if time.time() - start_time > max_wait_time:
+ msg = f"Query did not start within {max_wait_time} seconds"
+ raise RuntimeError(msg)
+
+ # Check if thread ID is available
+ thread_id = main_thread.ident
+ if thread_id is None:
+ msg = "Cannot get main thread ID"
+ raise RuntimeError(msg)
+
+ # Use ctypes to raise exception in main thread
+ exception = ctypes.py_object(KeyboardInterrupt)
+ res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
+ ctypes.c_long(thread_id), exception
+ )
+ if res != 1:
+ # If res is 0, the thread ID was invalid
+ # If res > 1, we modified multiple threads
+ ctypes.pythonapi.PyThreadState_SetAsyncExc(
+ ctypes.c_long(thread_id), ctypes.py_object(0)
+ )
+ msg = "Failed to raise KeyboardInterrupt in main thread"
+ raise RuntimeError(msg)
+
+ # Start a thread to trigger the interrupt
+ interrupt_thread = threading.Thread(target=trigger_interrupt)
+ # we mark as daemon so the test process can exit even if this thread doesn't finish
+ interrupt_thread.daemon = True
+ interrupt_thread.start()
+
+ # Execute the query and expect it to be interrupted
+ try:
+ # Signal that we're about to start the query
+ query_started.set()
+ df.collect()
+ except KeyboardInterrupt:
+ interrupted = True
+ except Exception as e:
+ interrupt_error = e
+
+ # Assert that the query was interrupted properly
+ if not interrupted:
+ pytest.fail(f"Query was not interrupted; got error: {interrupt_error}")
+
+ # Make sure the interrupt thread has finished
+ interrupt_thread.join(timeout=1.0)
diff --git a/src/catalog.rs b/src/catalog.rs
index 1e189a5a..83f8d08c 100644
--- a/src/catalog.rs
+++ b/src/catalog.rs
@@ -97,7 +97,7 @@ impl PyDatabase {
}
fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
- if let Some(table) = wait_for_future(py, self.database.table(name))? {
+ if let Some(table) = wait_for_future(py, self.database.table(name))?? {
Ok(PyTable::new(table))
} else {
Err(PyDataFusionError::Common(format!(
diff --git a/src/context.rs b/src/context.rs
index cc3d8e8e..b0af566e 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -34,7 +34,7 @@ use pyo3::prelude::*;
use crate::catalog::{PyCatalog, PyTable};
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
-use crate::errors::{py_datafusion_err, PyDataFusionResult};
+use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
use crate::expr::sort_expr::PySortExpr;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
@@ -375,7 +375,7 @@ impl PySessionContext {
None => {
let state = self.ctx.state();
let schema = options.infer_schema(&state, &table_path);
- wait_for_future(py, schema)?
+ wait_for_future(py, schema)??
}
};
let config = ListingTableConfig::new(table_path)
@@ -400,7 +400,7 @@ impl PySessionContext {
/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
pub fn sql(&mut self, query: &str, py: Python) ->
PyDataFusionResult<PyDataFrame> {
let result = self.ctx.sql(query);
- let df = wait_for_future(py, result)?;
+ let df = wait_for_future(py, result)??;
Ok(PyDataFrame::new(df))
}
@@ -417,7 +417,7 @@ impl PySessionContext {
SQLOptions::new()
};
let result = self.ctx.sql_with_options(query, options);
- let df = wait_for_future(py, result)?;
+ let df = wait_for_future(py, result)??;
Ok(PyDataFrame::new(df))
}
@@ -451,7 +451,7 @@ impl PySessionContext {
self.ctx.register_table(&*table_name, Arc::new(table))?;
- let table = wait_for_future(py, self._table(&table_name))?;
+ let table = wait_for_future(py, self._table(&table_name))??;
let df = PyDataFrame::new(table);
Ok(df)
@@ -650,7 +650,7 @@ impl PySessionContext {
.collect();
let result = self.ctx.register_parquet(name, path, options);
- wait_for_future(py, result)?;
+ wait_for_future(py, result)??;
Ok(())
}
@@ -693,11 +693,11 @@ impl PySessionContext {
if path.is_instance_of::<PyList>() {
let paths = path.extract::<Vec<String>>()?;
let result = self.register_csv_from_multiple_paths(name, paths,
options);
- wait_for_future(py, result)?;
+ wait_for_future(py, result)??;
} else {
let path = path.extract::<String>()?;
let result = self.ctx.register_csv(name, &path, options);
- wait_for_future(py, result)?;
+ wait_for_future(py, result)??;
}
Ok(())
@@ -734,7 +734,7 @@ impl PySessionContext {
options.schema = schema.as_ref().map(|x| &x.0);
let result = self.ctx.register_json(name, path, options);
- wait_for_future(py, result)?;
+ wait_for_future(py, result)??;
Ok(())
}
@@ -764,7 +764,7 @@ impl PySessionContext {
options.schema = schema.as_ref().map(|x| &x.0);
let result = self.ctx.register_avro(name, path, options);
- wait_for_future(py, result)?;
+ wait_for_future(py, result)??;
Ok(())
}
@@ -825,9 +825,19 @@ impl PySessionContext {
}
pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
- let x = wait_for_future(py, self.ctx.table(name))
+ let res = wait_for_future(py, self.ctx.table(name))
.map_err(|e| PyKeyError::new_err(e.to_string()))?;
- Ok(PyDataFrame::new(x))
+ match res {
+ Ok(df) => Ok(PyDataFrame::new(df)),
+ Err(e) => {
+ if let datafusion::error::DataFusionError::Plan(msg) = &e {
+ if msg.contains("No table named") {
+ return Err(PyKeyError::new_err(msg.to_string()));
+ }
+ }
+ Err(py_datafusion_err(e))
+ }
+ }
}
pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
@@ -865,10 +875,10 @@ impl PySessionContext {
let df = if let Some(schema) = schema {
options.schema = Some(&schema.0);
let result = self.ctx.read_json(path, options);
- wait_for_future(py, result)?
+ wait_for_future(py, result)??
} else {
let result = self.ctx.read_json(path, options);
- wait_for_future(py, result)?
+ wait_for_future(py, result)??
};
Ok(PyDataFrame::new(df))
}
@@ -915,12 +925,12 @@ impl PySessionContext {
let paths = path.extract::<Vec<String>>()?;
let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
let result = self.ctx.read_csv(paths, options);
- let df = PyDataFrame::new(wait_for_future(py, result)?);
+ let df = PyDataFrame::new(wait_for_future(py, result)??);
Ok(df)
} else {
let path = path.extract::<String>()?;
let result = self.ctx.read_csv(path, options);
- let df = PyDataFrame::new(wait_for_future(py, result)?);
+ let df = PyDataFrame::new(wait_for_future(py, result)??);
Ok(df)
}
}
@@ -958,7 +968,7 @@ impl PySessionContext {
.collect();
let result = self.ctx.read_parquet(path, options);
- let df = PyDataFrame::new(wait_for_future(py, result)?);
+ let df = PyDataFrame::new(wait_for_future(py, result)??);
Ok(df)
}
@@ -978,10 +988,10 @@ impl PySessionContext {
let df = if let Some(schema) = schema {
options.schema = Some(&schema.0);
let read_future = self.ctx.read_avro(path, options);
- wait_for_future(py, read_future)?
+ wait_for_future(py, read_future)??
} else {
let read_future = self.ctx.read_avro(path, options);
- wait_for_future(py, read_future)?
+ wait_for_future(py, read_future)??
};
Ok(PyDataFrame::new(df))
}
@@ -1021,8 +1031,8 @@ impl PySessionContext {
let plan = plan.plan.clone();
let fut:
JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
- let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
- Ok(PyRecordBatchStream::new(stream?))
+ let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
+ Ok(PyRecordBatchStream::new(stream))
}
}
diff --git a/src/dataframe.rs b/src/dataframe.rs
index ece8c4e0..7711a078 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -43,7 +43,7 @@ use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
use tokio::task::JoinHandle;
use crate::catalog::PyTable;
-use crate::errors::{py_datafusion_err, PyDataFusionError};
+use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError};
use crate::expr::sort_expr::to_sort_expressions;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
@@ -233,7 +233,7 @@ impl PyDataFrame {
let (batches, has_more) = wait_for_future(
py,
collect_record_batches_to_display(self.df.as_ref().clone(),
config),
- )?;
+ )??;
if batches.is_empty() {
// This should not be reached, but do it for safety since we index
into the vector below
return Ok("No data to display".to_string());
@@ -256,7 +256,7 @@ impl PyDataFrame {
let (batches, has_more) = wait_for_future(
py,
collect_record_batches_to_display(self.df.as_ref().clone(),
config),
- )?;
+ )??;
if batches.is_empty() {
// This should not be reached, but do it for safety since we index
into the vector below
return Ok("No data to display".to_string());
@@ -288,7 +288,7 @@ impl PyDataFrame {
/// Calculate summary statistics for a DataFrame
fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
let df = self.df.as_ref().clone();
- let stat_df = wait_for_future(py, df.describe())?;
+ let stat_df = wait_for_future(py, df.describe())??;
Ok(Self::new(stat_df))
}
@@ -391,7 +391,7 @@ impl PyDataFrame {
/// Unless some order is specified in the plan, there is no
/// guarantee of the order of the result.
fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
- let batches = wait_for_future(py, self.df.as_ref().clone().collect())
+ let batches = wait_for_future(py, self.df.as_ref().clone().collect())?
.map_err(PyDataFusionError::from)?;
// cannot use PyResult<Vec<RecordBatch>> return type due to
// https://github.com/PyO3/pyo3/issues/1813
@@ -400,14 +400,14 @@ impl PyDataFrame {
/// Cache DataFrame.
fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
- let df = wait_for_future(py, self.df.as_ref().clone().cache())?;
+ let df = wait_for_future(py, self.df.as_ref().clone().cache())??;
Ok(Self::new(df))
}
/// Executes this DataFrame and collects all results into a vector of
vector of RecordBatch
/// maintaining the input partitioning.
fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
- let batches = wait_for_future(py,
self.df.as_ref().clone().collect_partitioned())
+ let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?
.map_err(PyDataFusionError::from)?;
batches
@@ -511,7 +511,7 @@ impl PyDataFrame {
/// Get the execution plan for this `DataFrame`
fn execution_plan(&self, py: Python) ->
PyDataFusionResult<PyExecutionPlan> {
- let plan = wait_for_future(py,
self.df.as_ref().clone().create_physical_plan())?;
+ let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())??;
Ok(plan.into())
}
@@ -624,7 +624,7 @@ impl PyDataFrame {
DataFrameWriteOptions::new(),
Some(csv_options),
),
- )?;
+ )??;
Ok(())
}
@@ -685,7 +685,7 @@ impl PyDataFrame {
DataFrameWriteOptions::new(),
Option::from(options),
),
- )?;
+ )??;
Ok(())
}
@@ -697,7 +697,7 @@ impl PyDataFrame {
.as_ref()
.clone()
.write_json(path, DataFrameWriteOptions::new(), None),
- )?;
+ )??;
Ok(())
}
@@ -720,7 +720,7 @@ impl PyDataFrame {
py: Python<'py>,
requested_schema: Option<Bound<'py, PyCapsule>>,
) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
- let mut batches = wait_for_future(py,
self.df.as_ref().clone().collect())?;
+ let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())??;
let mut schema: Schema = self.df.schema().to_owned().into();
if let Some(schema_capsule) = requested_schema {
@@ -753,8 +753,8 @@ impl PyDataFrame {
let df = self.df.as_ref().clone();
let fut:
JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
rt.spawn(async move { df.execute_stream().await });
- let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
- Ok(PyRecordBatchStream::new(stream?))
+ let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
+ Ok(PyRecordBatchStream::new(stream))
}
fn execute_stream_partitioned(&self, py: Python) ->
PyResult<Vec<PyRecordBatchStream>> {
@@ -763,14 +763,11 @@ impl PyDataFrame {
let df = self.df.as_ref().clone();
let fut:
JoinHandle<datafusion::common::Result<Vec<SendableRecordBatchStream>>> =
rt.spawn(async move { df.execute_stream_partitioned().await });
- let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
+ let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })?
+ .map_err(py_datafusion_err)?
+ .map_err(py_datafusion_err)?;
- match stream {
- Ok(batches) =>
Ok(batches.into_iter().map(PyRecordBatchStream::new).collect()),
- _ => Err(PyValueError::new_err(
- "Unable to execute stream partitioned",
- )),
- }
+ Ok(stream.into_iter().map(PyRecordBatchStream::new).collect())
}
/// Convert to pandas dataframe with pyarrow
@@ -815,7 +812,7 @@ impl PyDataFrame {
// Executes this DataFrame to get the total number of rows.
fn count(&self, py: Python) -> PyDataFusionResult<usize> {
- Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
+ Ok(wait_for_future(py, self.df.as_ref().clone().count())??)
}
/// Fill null values with a specified value for specific columns
@@ -841,7 +838,7 @@ impl PyDataFrame {
/// Print DataFrame
fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
// Get string representation of record batches
- let batches = wait_for_future(py, df.collect())?;
+ let batches = wait_for_future(py, df.collect())??;
let batches_as_string = pretty::pretty_format_batches(&batches);
let result = match batches_as_string {
Ok(batch) => format!("DataFrame()\n{batch}"),
diff --git a/src/record_batch.rs b/src/record_batch.rs
index ec61c263..a85f0542 100644
--- a/src/record_batch.rs
+++ b/src/record_batch.rs
@@ -63,7 +63,7 @@ impl PyRecordBatchStream {
impl PyRecordBatchStream {
fn next(&mut self, py: Python) -> PyResult<PyRecordBatch> {
let stream = self.stream.clone();
- wait_for_future(py, next_stream(stream, true))
+ wait_for_future(py, next_stream(stream, true))?
}
fn __next__(&mut self, py: Python) -> PyResult<PyRecordBatch> {
diff --git a/src/substrait.rs b/src/substrait.rs
index 1fefc0bb..4da3738f 100644
--- a/src/substrait.rs
+++ b/src/substrait.rs
@@ -72,7 +72,7 @@ impl PySubstraitSerializer {
path: &str,
py: Python,
) -> PyDataFusionResult<()> {
- wait_for_future(py, serializer::serialize(sql, &ctx.ctx, path))?;
+ wait_for_future(py, serializer::serialize(sql, &ctx.ctx, path))??;
Ok(())
}
@@ -94,19 +94,20 @@ impl PySubstraitSerializer {
ctx: PySessionContext,
py: Python,
) -> PyDataFusionResult<PyObject> {
- let proto_bytes: Vec<u8> = wait_for_future(py,
serializer::serialize_bytes(sql, &ctx.ctx))?;
+ let proto_bytes: Vec<u8> =
+ wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))??;
Ok(PyBytes::new(py, &proto_bytes).into())
}
#[staticmethod]
pub fn deserialize(path: &str, py: Python) -> PyDataFusionResult<PyPlan> {
- let plan = wait_for_future(py, serializer::deserialize(path))?;
+ let plan = wait_for_future(py, serializer::deserialize(path))??;
Ok(PyPlan { plan: *plan })
}
#[staticmethod]
pub fn deserialize_bytes(proto_bytes: Vec<u8>, py: Python) ->
PyDataFusionResult<PyPlan> {
- let plan = wait_for_future(py,
serializer::deserialize_bytes(proto_bytes))?;
+ let plan = wait_for_future(py, serializer::deserialize_bytes(proto_bytes))??;
Ok(PyPlan { plan: *plan })
}
}
@@ -143,7 +144,7 @@ impl PySubstraitConsumer {
) -> PyDataFusionResult<PyLogicalPlan> {
let session_state = ctx.ctx.state();
let result = consumer::from_substrait_plan(&session_state, &plan.plan);
- let logical_plan = wait_for_future(py, result)?;
+ let logical_plan = wait_for_future(py, result)??;
Ok(PyLogicalPlan::new(logical_plan))
}
}
diff --git a/src/utils.rs b/src/utils.rs
index 0a24ab25..90d65438 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -15,19 +15,18 @@
// specific language governing permissions and limitations
// under the License.
-use crate::common::data_type::PyScalarValue;
-use crate::errors::{PyDataFusionError, PyDataFusionResult};
-use crate::TokioRuntime;
-use datafusion::common::ScalarValue;
-use datafusion::execution::context::SessionContext;
-use datafusion::logical_expr::Volatility;
-use pyo3::exceptions::PyValueError;
+use crate::{
+ common::data_type::PyScalarValue,
+ errors::{PyDataFusionError, PyDataFusionResult},
+ TokioRuntime,
+};
+use datafusion::{
+ common::ScalarValue, execution::context::SessionContext, logical_expr::Volatility,
+};
use pyo3::prelude::*;
-use pyo3::types::PyCapsule;
-use std::future::Future;
-use std::sync::OnceLock;
-use tokio::runtime::Runtime;
-
+use pyo3::{exceptions::PyValueError, types::PyCapsule};
+use std::{future::Future, sync::OnceLock, time::Duration};
+use tokio::{runtime::Runtime, time::sleep};
/// Utility to get the Tokio Runtime from Python
#[inline]
pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
@@ -47,14 +46,31 @@ pub(crate) fn get_global_ctx() -> &'static SessionContext {
CTX.get_or_init(SessionContext::new)
}
-/// Utility to collect rust futures with GIL released
-pub fn wait_for_future<F>(py: Python, f: F) -> F::Output
+/// Utility to collect rust futures with GIL released and respond to
+/// Python interrupts such as ``KeyboardInterrupt``. If a signal is
+/// received while the future is running, the future is aborted and the
+/// corresponding Python exception is raised.
+pub fn wait_for_future<F>(py: Python, fut: F) -> PyResult<F::Output>
where
F: Future + Send,
F::Output: Send,
{
let runtime: &Runtime = &get_tokio_runtime().0;
- py.allow_threads(|| runtime.block_on(f))
+ const INTERVAL_CHECK_SIGNALS: Duration = Duration::from_millis(1_000);
+
+ py.allow_threads(|| {
+ runtime.block_on(async {
+ tokio::pin!(fut);
+ loop {
+ tokio::select! {
+ res = &mut fut => break Ok(res),
+ _ = sleep(INTERVAL_CHECK_SIGNALS) => {
+ Python::with_gil(|py| py.check_signals())?;
+ }
+ }
+ }
+ })
+ })
}
pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]