This is an automated email from the ASF dual-hosted git repository.

jeffreyvo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 61e96054b3 Add `SessionContext::record_batches` (#9197)
61e96054b3 is described below

commit 61e96054b3db4ce5b139e0056a9989b2f12feae4
Author: Lordworms <[email protected]>
AuthorDate: Wed Feb 14 03:53:42 2024 -0600

    Add `SessionContext::record_batches` (#9197)
    
    * feat: issue #9157 adding record_batches for Vec<RecordBatch>
    
    * fix bugs
    
    * optimize code and tests
    
    * optimize test
    
    * optimize tests
    
    * abandon useless schema
    
    * collect into a single batch
---
 datafusion/core/src/execution/context/mod.rs | 25 ++++++++++-
 datafusion/core/tests/dataframe/mod.rs       | 66 ++++++++++++++++++++++++++++
 2 files changed, 90 insertions(+), 1 deletion(-)

diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index c81e011b45..e2854fcd2b 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -36,6 +36,7 @@ use crate::{
     optimizer::optimizer::Optimizer,
     physical_optimizer::optimizer::{PhysicalOptimizer, PhysicalOptimizerRule},
 };
+use arrow_schema::Schema;
 use datafusion_common::{
     alias::AliasGenerator,
     exec_err, not_impl_err, plan_datafusion_err, plan_err,
@@ -934,7 +935,29 @@ impl SessionContext {
             .build()?,
         ))
     }
-
+    /// Create a [`DataFrame`] for reading a `Vec` of [`RecordBatch`]es
+    pub fn read_batches(
+        &self,
+        batches: impl IntoIterator<Item = RecordBatch>,
+    ) -> Result<DataFrame> {
+        // check schema uniqueness
+        let mut batches = batches.into_iter().peekable();
+        let schema = if let Some(batch) = batches.peek() {
+            batch.schema().clone()
+        } else {
+            Arc::new(Schema::empty())
+        };
+        let provider = MemTable::try_new(schema, vec![batches.collect()])?;
+        Ok(DataFrame::new(
+            self.state(),
+            LogicalPlanBuilder::scan(
+                UNNAMED_TABLE,
+                provider_as_source(Arc::new(provider)),
+                None,
+            )?
+            .build()?,
+        ))
+    }
     /// Registers a [`ListingTable`] that can assemble multiple files
     /// from locations in an [`ObjectStore`] instance into a single
     /// table.
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index dc347ed9c8..f650e9e39d 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -28,6 +28,7 @@ use arrow::{
     },
     record_batch::RecordBatch,
 };
+use arrow_array::Float32Array;
 use arrow_schema::ArrowError;
 use std::sync::Arc;
 
@@ -1431,6 +1432,71 @@ async fn unnest_analyze_metrics() -> Result<()> {
 
     Ok(())
 }
+#[tokio::test]
+async fn test_read_batches() -> Result<()> {
+    let config = SessionConfig::new();
+    let runtime = Arc::new(RuntimeEnv::default());
+    let state = SessionState::new_with_config_rt(config, runtime);
+    let ctx = SessionContext::new_with_state(state);
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("number", DataType::Float32, false),
+    ]));
+
+    let batches = vec![
+        RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
+                Arc::new(Float32Array::from(vec![1.12, 3.40, 2.33, 9.10, 6.66])),
+            ],
+        )
+        .unwrap(),
+        RecordBatch::try_new(
+            schema.clone(),
+            vec![
+                Arc::new(Int32Array::from(vec![3, 4, 5])),
+                Arc::new(Float32Array::from(vec![1.11, 2.22, 3.33])),
+            ],
+        )
+        .unwrap(),
+    ];
+    let df = ctx.read_batches(batches).unwrap();
+    df.clone().show().await.unwrap();
+    let result = df.collect().await?;
+    let expected = [
+        "+----+--------+",
+        "| id | number |",
+        "+----+--------+",
+        "| 1  | 1.12   |",
+        "| 2  | 3.4    |",
+        "| 3  | 2.33   |",
+        "| 4  | 9.1    |",
+        "| 5  | 6.66   |",
+        "| 3  | 1.11   |",
+        "| 4  | 2.22   |",
+        "| 5  | 3.33   |",
+        "+----+--------+",
+    ];
+    assert_batches_sorted_eq!(expected, &result);
+    Ok(())
+}
+#[tokio::test]
+async fn test_read_batches_empty() -> Result<()> {
+    let config = SessionConfig::new();
+    let runtime = Arc::new(RuntimeEnv::default());
+    let state = SessionState::new_with_config_rt(config, runtime);
+    let ctx = SessionContext::new_with_state(state);
+
+    let batches = vec![];
+    let df = ctx.read_batches(batches).unwrap();
+    df.clone().show().await.unwrap();
+    let result = df.collect().await?;
+    let expected = ["++", "++"];
+    assert_batches_sorted_eq!(expected, &result);
+    Ok(())
+}
 
 #[tokio::test]
 async fn consecutive_projection_same_schema() -> Result<()> {

Reply via email to