This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 8246631bfa fix: Implement `reset_state` for `LazyMemoryExec` (#19362)
8246631bfa is described below

commit 8246631bfacb771e2dc55c3a55d2c6e4358e325f
Author: Nuno Faria <[email protected]>
AuthorDate: Sun Dec 28 21:12:17 2025 +0000

    fix: Implement `reset_state` for `LazyMemoryExec` (#19362)
    
    ## Which issue does this PR close?
    
    - N/A.
    
    ## Rationale for this change
    
    Implement `ExecutionPlan::reset_state` for `LazyMemoryExec` (used by,
    e.g., `generate_series`) so that the same plan can be executed more than
    once, as happens when `generate_series` appears inside a recursive CTE.
    
    ## What changes are included in this PR?
    
    - Implemented `ExecutionPlan::reset_state` for `LazyMemoryExec`.
    - Added `reset_state` to the `LazyBatchGenerator` trait and implemented
    it for every existing implementor.
    - Added unit tests.
    
    ## Are these changes tested?
    
    Yes.
    
    ## Are there any user-facing changes?
    
    Yes: `LazyBatchGenerator` gains a new required method, `reset_state`, so
    downstream implementors of the trait must add it.
---
 datafusion/core/tests/execution/coop.rs            | 13 +++++-
 datafusion/functions-table/src/generate_series.rs  | 47 +++++++++++++++++++
 datafusion/physical-plan/src/memory.rs             | 54 ++++++++++++++++++++++
 .../sqllogictest/test_files/table_functions.slt    | 24 ++++++++++
 4 files changed, 136 insertions(+), 2 deletions(-)

diff --git a/datafusion/core/tests/execution/coop.rs b/datafusion/core/tests/execution/coop.rs
index ddc13fd5d9..27dacf598c 100644
--- a/datafusion/core/tests/execution/coop.rs
+++ b/datafusion/core/tests/execution/coop.rs
@@ -64,13 +64,14 @@ use std::time::Duration;
 use tokio::runtime::{Handle, Runtime};
 use tokio::select;
 
-#[derive(Debug)]
+// `Clone` lets `reset_state` copy the generator before rewinding the copy.
+#[derive(Debug, Clone)]
 struct RangeBatchGenerator {
     schema: SchemaRef,
     value_range: Range<i64>,
     boundedness: Boundedness,
     batch_size: usize,
     poll_count: usize,
+    /// The range this generator was constructed with; `reset_state` restores
+    /// `value_range` from it.
+    original_range: Range<i64>,
 }
 
 impl std::fmt::Display for RangeBatchGenerator {
@@ -110,6 +111,13 @@ impl LazyBatchGenerator for RangeBatchGenerator {
             RecordBatch::try_new(Arc::clone(&self.schema), 
vec![Arc::new(array)])?;
         Ok(Some(batch))
     }
+
+    /// Returns a copy of this generator rewound to its initial state:
+    /// `poll_count` is zeroed and `value_range` is restored from
+    /// `original_range`. All other fields are cloned unchanged.
+    fn reset_state(&self) -> Arc<RwLock<dyn LazyBatchGenerator>> {
+        let mut new = self.clone();
+        new.poll_count = 0;
+        new.value_range = new.original_range.clone();
+        Arc::new(RwLock::new(new))
+    }
 }
 
 fn make_lazy_exec(column_name: &str, pretend_infinite: bool) -> LazyMemoryExec 
{
@@ -139,9 +147,10 @@ fn make_lazy_exec_with_range(
     let batch_gen = RangeBatchGenerator {
         schema: Arc::clone(&schema),
         boundedness,
-        value_range: range,
+        value_range: range.clone(),
         batch_size: 8192,
         poll_count: 0,
+        original_range: range,
     };
 
     // Wrap the generator in a trait object behind Arc<RwLock<_>>
diff --git a/datafusion/functions-table/src/generate_series.rs b/datafusion/functions-table/src/generate_series.rs
index 9e58e9d0d0..b806798bce 100644
--- a/datafusion/functions-table/src/generate_series.rs
+++ b/datafusion/functions-table/src/generate_series.rs
@@ -56,6 +56,10 @@ impl LazyBatchGenerator for Empty {
     fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
         Ok(None)
     }
+
+    /// `Empty` never yields a batch and carries no state besides `name`, so a
+    /// new instance with the same `name` is already a reset generator.
+    fn reset_state(&self) -> Arc<RwLock<dyn LazyBatchGenerator>> {
+        Arc::new(RwLock::new(Empty { name: self.name }))
+    }
 }
 
 impl fmt::Display for Empty {
@@ -398,6 +402,12 @@ impl<T: SeriesValue> LazyBatchGenerator for GenericSeriesState<T> {
         let batch = RecordBatch::try_new(Arc::clone(&self.schema), 
vec![array])?;
         Ok(Some(batch))
     }
+
+    /// Rewinds the series: the returned copy has `current` set back to
+    /// `start`. Every other field (`schema`, `end`, `step`, `batch_size`,
+    /// `include_end`, `name`) is cloned unchanged.
+    fn reset_state(&self) -> Arc<RwLock<dyn LazyBatchGenerator>> {
+        let mut new = self.clone();
+        new.current = new.start.clone();
+        Arc::new(RwLock::new(new))
+    }
 }
 
 impl<T: SeriesValue> fmt::Display for GenericSeriesState<T> {
@@ -779,3 +789,40 @@ impl TableFunctionImpl for RangeFunc {
         impl_func.call(exprs)
     }
 }
+
+// Unit tests for `GenericSeriesState::reset_state`.
+#[cfg(test)]
+mod generate_series_tests {
+    use std::sync::Arc;
+
+    use arrow::datatypes::{DataType, Field, Schema};
+    use datafusion_common::Result;
+    use datafusion_physical_plan::memory::LazyBatchGenerator;
+
+    use crate::generate_series::GenericSeriesState;
+
+    /// Generates a batch from a fresh `i64` series state, then checks that the
+    /// generator returned by `reset_state` yields an identical batch.
+    #[test]
+    fn test_generic_series_state_reset() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", 
DataType::Int64, false)]));
+        let mut state = GenericSeriesState::<i64> {
+            schema,
+            start: 1,
+            end: 5,
+            step: 1,
+            current: 1,
+            batch_size: 8192,
+            include_end: true,
+            name: "test",
+        };
+        let batch = state.generate_next_batch()?.expect("missing batch");
+
+        // Reset after producing a batch; the copy is expected to start from
+        // `start` again rather than where `state` left off.
+        let state_reset = state.reset_state();
+        let reset_batch = state_reset
+            .write()
+            .generate_next_batch()?
+            .expect("missing reset batch");
+
+        assert_eq!(batch, reset_batch);
+
+        Ok(())
+    }
+}
diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs
index 65a3fe575e..4a406ca648 100644
--- a/datafusion/physical-plan/src/memory.rs
+++ b/datafusion/physical-plan/src/memory.rs
@@ -144,6 +144,9 @@ pub trait LazyBatchGenerator: Send + Sync + fmt::Debug + fmt::Display {
 
     /// Generate the next batch, return `None` when no more batches are 
available
     fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>>;
+
+    /// Returns a new generator positioned at this generator's initial state,
+    /// so that a plan holding it (such as [`LazyMemoryExec`]) can be executed
+    /// again after a previous run has consumed the original.
+    ///
+    /// The receiver is only borrowed (`&self`), so the existing generator is
+    /// not consumed; the result is a separate instance in its own
+    /// `Arc<RwLock<_>>`.
+    fn reset_state(&self) -> Arc<RwLock<dyn LazyBatchGenerator>>;
 }
 
 /// Execution plan for lazy in-memory batches of data
@@ -352,6 +355,21 @@ impl ExecutionPlan for LazyMemoryExec {
     fn statistics(&self) -> Result<Statistics> {
         Ok(Statistics::new_unknown(&self.schema))
     }
+
+    /// Returns a copy of this plan that can be executed again.
+    ///
+    /// Each batch generator is replaced by the result of its
+    /// [`LazyBatchGenerator::reset_state`]. `schema`, `cache` and `projection`
+    /// are carried over unchanged. `metrics` starts from a new, empty set, so
+    /// the next run does not accumulate the previous run's metrics.
+    fn reset_state(self: Arc<Self>) -> Result<Arc<dyn ExecutionPlan>> {
+        let generators = self
+            .generators()
+            .iter()
+            // `LazyBatchGenerator::reset_state` takes `&self`, so a read lock suffices.
+            .map(|g| g.read().reset_state())
+            .collect::<Vec<_>>();
+        Ok(Arc::new(LazyMemoryExec {
+            schema: Arc::clone(&self.schema),
+            batch_generators: generators,
+            cache: self.cache.clone(),
+            metrics: ExecutionPlanMetricsSet::new(),
+            projection: self.projection.clone(),
+        }))
+    }
 }
 
 /// Stream that generates record batches on demand
@@ -450,6 +468,15 @@ mod lazy_memory_tests {
                 vec![Arc::new(array)],
             )?))
         }
+
+        /// Returns a new `TestGenerator` with `counter` back at 0 and the same
+        /// `max_batches`, `batch_size` and `schema`.
+        fn reset_state(&self) -> Arc<RwLock<dyn LazyBatchGenerator>> {
+            Arc::new(RwLock::new(TestGenerator {
+                counter: 0,
+                max_batches: self.max_batches,
+                batch_size: self.batch_size,
+                schema: Arc::clone(&self.schema),
+            }))
+        }
     }
 
     #[tokio::test]
@@ -568,4 +595,31 @@ mod lazy_memory_tests {
 
         Ok(())
     }
+
+    /// Runs a `LazyMemoryExec` to completion, then executes the plan returned
+    /// by `reset_state` and checks that it produces the same batches.
+    #[tokio::test]
+    async fn test_lazy_memory_exec_reset_state() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", 
DataType::Int64, false)]));
+        let generator = TestGenerator {
+            counter: 0,
+            max_batches: 3,
+            batch_size: 2,
+            schema: Arc::clone(&schema),
+        };
+
+        let exec = Arc::new(LazyMemoryExec::try_new(
+            schema,
+            vec![Arc::new(RwLock::new(generator))],
+        )?);
+        let stream = exec.execute(0, Arc::new(TaskContext::default()))?;
+        let batches = collect(stream).await?;
+
+        // `reset_state` takes `self: Arc<Self>`, so this consumes `exec`.
+        let exec_reset = exec.reset_state()?;
+        let stream = exec_reset.execute(0, Arc::new(TaskContext::default()))?;
+        let batches_reset = collect(stream).await?;
+
+        // The first run consumed the generator; if `reset_state` did not
+        // produce a fresh one, `batches_reset` would come back empty.
+        assert_eq!(batches, batches_reset);
+
+        Ok(())
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/table_functions.slt b/datafusion/sqllogictest/test_files/table_functions.slt
index c843400efc..cf8a091880 100644
--- a/datafusion/sqllogictest/test_files/table_functions.slt
+++ b/datafusion/sqllogictest/test_files/table_functions.slt
@@ -509,3 +509,27 @@ SELECT c, f.*  FROM json_table, LATERAL generate_series(1,2) f;
 1 2
 2 1
 2 2
+
+
+# Test generate_series in a recursive CTE to ensure the state is correctly reset
+query I rowsort
+WITH RECURSIVE t AS (
+    SELECT 1 i
+    UNION ALL
+    SELECT g.i
+    FROM generate_series(1, 1) g(i), t
+)
+SELECT *
+FROM t
+LIMIT 10;
+----
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to