This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ray.git


The following commit(s) were added to refs/heads/main by this push:
     new 4c83d76  fix query 15, update validator, formatting, ray scheduling fix (#83)
4c83d76 is described below

commit 4c83d7618e5c6fe44eda0892bd0aa1d2251201f6
Author: robtandy <[email protected]>
AuthorDate: Tue Mar 11 12:03:33 2025 -0400

    fix query 15, update validator, formatting, ray scheduling fix (#83)
---
 datafusion_ray/core.py |   4 +-
 datafusion_ray/util.py |   2 +-
 src/dataframe.rs       |   2 +-
 src/lib.rs             |   2 +-
 src/util.rs            | 185 ++++++++++++++++++++-----------------------------
 tpch/tpcbench.py       |  35 +++++-----
 6 files changed, 99 insertions(+), 131 deletions(-)

diff --git a/datafusion_ray/core.py b/datafusion_ray/core.py
index 0d1736a..4832c9f 100644
--- a/datafusion_ray/core.py
+++ b/datafusion_ray/core.py
@@ -249,7 +249,7 @@ class DFRayProcessorPool:
         log.info("all processors shutdown")
 
 
-@ray.remote(num_cpus=0)
+@ray.remote(num_cpus=0.01, scheduling_strategy="SPREAD")
 class DFRayProcessor:
     def __init__(self, processor_key):
         self.processor_key = processor_key
@@ -317,7 +317,7 @@ class InternalStageData:
         return f"""Stage: {self.stage_id}, pg: {self.partition_group}, 
child_stages:{self.child_stage_ids}, listening addr:{self.remote_addr}"""
 
 
-@ray.remote(num_cpus=0)
+@ray.remote(num_cpus=0.01, scheduling_strategy="SPREAD")
 class DFRayContextSupervisor:
     def __init__(
         self,
diff --git a/datafusion_ray/util.py b/datafusion_ray/util.py
index d3fecf8..bd1b387 100644
--- a/datafusion_ray/util.py
+++ b/datafusion_ray/util.py
@@ -1,4 +1,4 @@
 from datafusion_ray._datafusion_ray_internal import (
-    exec_sql_on_tables,
+    LocalValidator,
     prettify,
 )
diff --git a/src/dataframe.rs b/src/dataframe.rs
index adecf6b..7189c6d 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -50,10 +50,10 @@ use crate::max_rows::MaxRowsExec;
 use crate::pre_fetch::PrefetchExec;
 use crate::stage::DFRayStageExec;
 use crate::stage_reader::DFRayStageReaderExec;
+use crate::util::ResultExt;
 use crate::util::collect_from_stage;
 use crate::util::display_plan_with_partition_counts;
 use crate::util::physical_plan_to_bytes;
-use crate::util::ResultExt;
 
 /// Internal rust class behind the DFRayDataFrame python object
 ///
diff --git a/src/lib.rs b/src/lib.rs
index 5158f5c..484f122 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -43,8 +43,8 @@ fn _datafusion_ray_internal(m: &Bound<'_, PyModule>) -> 
PyResult<()> {
     m.add_class::<dataframe::DFRayDataFrame>()?;
     m.add_class::<dataframe::PyDFRayStage>()?;
     m.add_class::<processor_service::DFRayProcessorService>()?;
+    m.add_class::<util::LocalValidator>()?;
     m.add_function(wrap_pyfunction!(util::prettify, m)?)?;
-    m.add_function(wrap_pyfunction!(util::exec_sql_on_tables, m)?)?;
     Ok(())
 }
 
diff --git a/src/util.rs b/src/util.rs
index 1fa36e8..a35bdea 100644
--- a/src/util.rs
+++ b/src/util.rs
@@ -8,13 +8,12 @@ use std::task::{Context, Poll};
 use std::time::Duration;
 
 use arrow::array::RecordBatch;
-use arrow::compute::concat_batches;
 use arrow::datatypes::SchemaRef;
 use arrow::error::ArrowError;
 use arrow::ipc::convert::fb_to_schema;
 use arrow::ipc::reader::StreamReader;
 use arrow::ipc::writer::{IpcWriteOptions, StreamWriter};
-use arrow::ipc::{root_as_message, MetadataVersion};
+use arrow::ipc::{MetadataVersion, root_as_message};
 use arrow::pyarrow::*;
 use arrow::util::pretty;
 use arrow_flight::{FlightClient, FlightData, Ticket};
@@ -30,16 +29,16 @@ use datafusion::error::DataFusionError;
 use datafusion::execution::object_store::ObjectStoreUrl;
 use datafusion::execution::{RecordBatchStream, SendableRecordBatchStream, 
SessionStateBuilder};
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
-use datafusion::physical_plan::{displayable, ExecutionPlan, 
ExecutionPlanProperties};
-use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties, 
displayable};
+use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext};
 use datafusion_proto::physical_plan::AsExecutionPlan;
 use datafusion_python::utils::wait_for_future;
 use futures::{Stream, StreamExt};
 use log::debug;
+use object_store::ObjectStore;
 use object_store::aws::AmazonS3Builder;
 use object_store::gcp::GoogleCloudStorageBuilder;
 use object_store::http::HttpBuilder;
-use object_store::ObjectStore;
 use parking_lot::Mutex;
 use pyo3::prelude::*;
 use pyo3::types::{PyBytes, PyList};
@@ -411,62 +410,77 @@ fn print_node(plan: &Arc<dyn ExecutionPlan>, indent: 
usize, output: &mut String)
     }
 }
 
-async fn exec_sql(
-    query: String,
-    tables: Vec<(String, String)>,
-) -> Result<RecordBatch, DataFusionError> {
-    let ctx = SessionContext::new();
-    for (name, path) in tables {
-        let opt =
-            
ListingOptions::new(Arc::new(ParquetFormat::new())).with_file_extension(".parquet");
-        debug!("exec_sql: registering table {} at {}", name, path);
+#[pyclass]
+pub struct LocalValidator {
+    ctx: SessionContext,
+}
+
+#[pymethods]
+impl LocalValidator {
+    #[new]
+    fn new() -> Self {
+        let ctx = SessionContext::new();
+        Self { ctx }
+    }
+
+    pub fn register_parquet(&self, py: Python, name: String, path: String) -> 
PyResult<()> {
+        let options = ParquetReadOptions::default();
 
-        let url = ListingTableUrl::parse(&path)?;
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
 
-        maybe_register_object_store(&ctx, url.as_ref())?;
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+        debug!("register_parquet: registering table {} at {}", name, path);
 
-        ctx.register_listing_table(&name, &path, opt, None, None)
-            .await?;
+        wait_for_future(py, self.ctx.register_parquet(&name, &path, 
options.clone()))?;
+        Ok(())
     }
-    let df = ctx.sql(&query).await?;
-    let schema = df.schema().inner().clone();
-    let batches = df.collect().await?;
-    concat_batches(&schema, batches.iter()).map_err(|e| 
DataFusionError::ArrowError(e, None))
-}
 
-/// Executes a query on the specified tables using DataFusion without Ray.
-///
-/// Returns the query results as a RecordBatch that can be used to verify the
-/// correctness of DataFusion-Ray execution of the same query.
-///
-/// # Arguments
-///
-/// * `py`: the Python token
-/// * `query`: the SQL query string to execute
-/// * `tables`: a list of `(name, url)` tuples specifying the tables to query;
-///   the `url` identifies the parquet files for each listing table and see
-///   [`datafusion::datasource::listing::ListingTableUrl::parse`] for details
-///   of supported URL formats
-///  * `listing`: boolean indicating whether this is a listing table path or 
not
-#[pyfunction]
-#[pyo3(signature = (query, tables, listing=false))]
-pub fn exec_sql_on_tables(
-    py: Python,
-    query: String,
-    tables: Bound<'_, PyList>,
-    listing: bool,
-) -> PyResult<PyObject> {
-    let table_vec = {
-        let mut v = Vec::with_capacity(tables.len());
-        for entry in tables.iter() {
-            let (name, path) = entry.extract::<(String, String)>()?;
-            let path = if listing { format!("{path}/") } else { path };
-            v.push((name, path));
-        }
-        v
-    };
-    let batch = wait_for_future(py, exec_sql(query, table_vec))?;
-    batch.to_pyarrow(py)
+    #[pyo3(signature = (name, path, file_extension=".parquet"))]
+    pub fn register_listing_table(
+        &mut self,
+        py: Python,
+        name: &str,
+        path: &str,
+        file_extension: &str,
+    ) -> PyResult<()> {
+        let options =
+            
ListingOptions::new(Arc::new(ParquetFormat::new())).with_file_extension(file_extension);
+
+        let path = format!("{path}/");
+        let url = ListingTableUrl::parse(&path).to_py_err()?;
+
+        maybe_register_object_store(&self.ctx, url.as_ref()).to_py_err()?;
+
+        debug!(
+            "register_listing_table: registering table {} at {}",
+            name, path
+        );
+        wait_for_future(
+            py,
+            self.ctx
+                .register_listing_table(name, path, options, None, None),
+        )
+        .to_py_err()
+    }
+
+    #[pyo3(signature = (query))]
+    fn collect_sql(&self, py: Python, query: String) -> PyResult<PyObject> {
+        let fut = async || {
+            let df = self.ctx.sql(&query).await?;
+            let batches = df.collect().await?;
+
+            Ok::<_, DataFusionError>(batches)
+        };
+
+        let batches = wait_for_future(py, fut())
+            .to_py_err()?
+            .iter()
+            .map(|batch| batch.to_pyarrow(py))
+            .collect::<PyResult<Vec<_>>>()?;
+
+        let pylist = PyList::new(py, batches)?;
+        Ok(pylist.into())
+    }
 }
 
 pub(crate) fn register_object_store_for_paths_in_plan(
@@ -570,62 +584,14 @@ mod test {
     use std::{sync::Arc, vec};
 
     use arrow::{
-        array::{Int32Array, StringArray},
+        array::Int32Array,
         datatypes::{DataType, Field, Schema},
     };
-    use datafusion::{
-        parquet::file::properties::WriterProperties, 
test_util::parquet::TestParquetFile,
-    };
+    
     use futures::stream;
 
     use super::*;
 
-    #[tokio::test]
-    async fn test_exec_sql() {
-        let dir = tempfile::tempdir().unwrap();
-        let path = dir.path().join("people.parquet");
-
-        let batch = RecordBatch::try_new(
-            Arc::new(Schema::new(vec![
-                Field::new("age", DataType::Int32, false),
-                Field::new("name", DataType::Utf8, false),
-            ])),
-            vec![
-                Arc::new(Int32Array::from(vec![11, 12, 13])),
-                Arc::new(StringArray::from(vec!["alice", "bob", "cindy"])),
-            ],
-        )
-        .unwrap();
-        let props = WriterProperties::builder().build();
-        let file = TestParquetFile::try_new(path.clone(), props, 
Some(batch.clone())).unwrap();
-
-        // test with file
-        let tables = vec![(
-            "people".to_string(),
-            format!("file://{}", file.path().to_str().unwrap()),
-        )];
-        let query = "SELECT * FROM people ORDER BY age".to_string();
-        let res = exec_sql(query.clone(), tables).await.unwrap();
-        assert_eq!(
-            format!(
-                "{}",
-                pretty::pretty_format_batches(&[batch.clone()]).unwrap()
-            ),
-            format!("{}", pretty::pretty_format_batches(&[res]).unwrap()),
-        );
-
-        // test with dir
-        let tables = vec![(
-            "people".to_string(),
-            format!("file://{}/", dir.path().to_str().unwrap()),
-        )];
-        let res = exec_sql(query, tables).await.unwrap();
-        assert_eq!(
-            format!("{}", pretty::pretty_format_batches(&[batch]).unwrap()),
-            format!("{}", pretty::pretty_format_batches(&[res]).unwrap()),
-        );
-    }
-
     #[test]
     fn test_ipc_roundtrip() {
         let batch = RecordBatch::try_new(
@@ -641,10 +607,9 @@ mod test {
     #[tokio::test]
     async fn test_max_rows_stream() {
         let schema = Arc::new(Schema::new(vec![Field::new("a", 
DataType::Int32, false)]));
-        let batch = RecordBatch::try_new(
-            schema.clone(),
-            vec![Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
-        )
+        let batch = RecordBatch::try_new(schema.clone(), 
vec![Arc::new(Int32Array::from(vec![
+            1, 2, 3, 4, 5, 6, 7, 8,
+        ]))])
         .unwrap();
 
         // 24 total rows
diff --git a/tpch/tpcbench.py b/tpch/tpcbench.py
index dd6df1e..16288fb 100644
--- a/tpch/tpcbench.py
+++ b/tpch/tpcbench.py
@@ -18,7 +18,7 @@
 import argparse
 import ray
 from datafusion_ray import DFRayContext, df_ray_runtime_env
-from datafusion_ray.util import exec_sql_on_tables, prettify
+from datafusion_ray.util import LocalValidator, prettify
 from datetime import datetime
 import json
 import os
@@ -63,6 +63,8 @@ def main(
         worker_pool_min=worker_pool_min,
     )
 
+    local = LocalValidator()
+
     ctx.set("datafusion.execution.target_partitions", f"{concurrency}")
     # ctx.set("datafusion.execution.parquet.pushdown_filters", "true")
     ctx.set("datafusion.optimizer.enable_round_robin_repartition", "false")
@@ -73,8 +75,10 @@ def main(
         print(f"Registering table {table} using path {path}")
         if listing_tables:
             ctx.register_listing_table(table, path)
+            local.register_listing_table(table, path)
         else:
             ctx.register_parquet(table, path)
+            local.register_parquet(table, path)
 
     current_time_millis = int(datetime.now().timestamp() * 1000)
     results_path = f"datafusion-ray-tpch-{current_time_millis}.json"
@@ -99,28 +103,27 @@ def main(
     for qnum in queries:
         sql = tpch_query(qnum)
 
-        statements = sql.split(";")
-        sql = statements[0]
-
-        print("executing ", sql)
+        statements = list(
+            filter(lambda x: len(x) > 0, map(lambda x: x.strip(), 
sql.split(";")))
+        )
+        print(f"statements = {statements}")
 
         start_time = time.time()
-        df = ctx.sql(sql)
-        batches = df.collect()
+        all_batches = []
+        for sql in statements:
+            print("executing ", sql)
+            df = ctx.sql(sql)
+            all_batches.append(df.collect())
         end_time = time.time()
         results["queries"][qnum] = end_time - start_time
 
-        calculated = prettify(batches)
+        calculated = "\n".join([prettify(b) for b in all_batches])
         print(calculated)
         if validate:
-            tables = [
-                (name, os.path.join(data_path, f"{name}.parquet"))
-                for name in table_names
-            ]
-            answer_batches = [
-                b for b in [exec_sql_on_tables(sql, tables, listing_tables)] 
if b
-            ]
-            expected = prettify(answer_batches)
+            all_batches = []
+            for sql in statements:
+                all_batches.append(local.collect_sql(sql))
+            expected = "\n".join([prettify(b) for b in all_batches])
 
             results["validated"][qnum] = calculated == expected
         print(f"done with query {qnum}")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to