This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ray.git
The following commit(s) were added to refs/heads/main by this push:
new 2783613 chore: Make query stage / shuffle code easier to understand
(#54)
2783613 is described below
commit 27836132c97b3324b4ed5969ac9fd08751fbc8af
Author: Andy Grove <[email protected]>
AuthorDate: Wed Dec 18 23:27:36 2024 -0700
chore: Make query stage / shuffle code easier to understand (#54)
---
datafusion_ray/context.py | 7 +++---
src/planner.rs | 2 +-
src/query_stage.rs | 42 +++++++++++++++------------------
src/shuffle/codec.rs | 2 +-
src/shuffle/writer.rs | 10 ++++----
testdata/expected-plans/q1.txt | 2 +-
testdata/expected-plans/q10.txt | 2 +-
testdata/expected-plans/q11.txt | 2 +-
testdata/expected-plans/q12.txt | 2 +-
testdata/expected-plans/q13.txt | 2 +-
testdata/expected-plans/q16.txt | 2 +-
testdata/expected-plans/q18.txt | 2 +-
testdata/expected-plans/q2.txt | 2 +-
testdata/expected-plans/q20.txt | 2 +-
testdata/expected-plans/q21.txt | 2 +-
testdata/expected-plans/q22.txt | 2 +-
testdata/expected-plans/q3.txt | 2 +-
testdata/expected-plans/q4.txt | 2 +-
testdata/expected-plans/q5.txt | 2 +-
testdata/expected-plans/q7.txt | 2 +-
testdata/expected-plans/q8.txt | 2 +-
testdata/expected-plans/q9.txt | 2 +-
tests/test_context.py | 52 ++++++++++++++++++++---------------------
23 files changed, 72 insertions(+), 77 deletions(-)
diff --git a/datafusion_ray/context.py b/datafusion_ray/context.py
index 0070220..8d354ff 100644
--- a/datafusion_ray/context.py
+++ b/datafusion_ray/context.py
@@ -50,7 +50,7 @@ def execute_query_stage(
# if the query stage has a single output partition then we need to execute
for the output
# partition, otherwise we need to execute in parallel for each input
partition
- concurrency = stage.get_input_partition_count()
+ concurrency = stage.get_execution_partition_count()
output_partitions_count = stage.get_output_partition_count()
if output_partitions_count == 1:
# reduce stage
@@ -159,5 +159,6 @@ class DatafusionRayContext:
)
_, partitions = ray.get(future)
# assert len(partitions) == 1, len(partitions)
- result_set = ray.get(partitions[0])
- return result_set
+ record_batches = ray.get(partitions[0])
+ # filter out empty batches
+ return [batch for batch in record_batches if batch.num_rows > 0]
diff --git a/src/planner.rs b/src/planner.rs
index 954d8e2..c1e7b41 100644
--- a/src/planner.rs
+++ b/src/planner.rs
@@ -399,7 +399,7 @@ mod test {
let query_stage = graph.query_stages.get(&id).unwrap();
output.push_str(&format!(
"Query Stage #{id} ({} -> {}):\n{}\n",
- query_stage.get_input_partition_count(),
+ query_stage.get_execution_partition_count(),
query_stage.get_output_partition_count(),
displayable(query_stage.plan.as_ref()).indent(false)
));
diff --git a/src/query_stage.rs b/src/query_stage.rs
index 05c090b..a5c9a08 100644
--- a/src/query_stage.rs
+++ b/src/query_stage.rs
@@ -16,7 +16,7 @@
// under the License.
use crate::context::serialize_execution_plan;
-use crate::shuffle::{ShuffleCodec, ShuffleReaderExec};
+use crate::shuffle::{ShuffleCodec, ShuffleReaderExec, ShuffleWriterExec};
use datafusion::error::Result;
use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanProperties,
Partitioning};
use datafusion::prelude::SessionContext;
@@ -60,8 +60,8 @@ impl PyQueryStage {
self.stage.get_child_stage_ids()
}
- pub fn get_input_partition_count(&self) -> usize {
- self.stage.get_input_partition_count()
+ pub fn get_execution_partition_count(&self) -> usize {
+ self.stage.get_execution_partition_count()
}
pub fn get_output_partition_count(&self) -> usize {
@@ -75,16 +75,6 @@ pub struct QueryStage {
pub plan: Arc<dyn ExecutionPlan>,
}
-fn _get_output_partition_count(plan: &dyn ExecutionPlan) -> usize {
- // UnknownPartitioning and HashPartitioning with empty expressions will
- // both return 1 partition.
- match plan.properties().output_partitioning() {
- Partitioning::UnknownPartitioning(_) => 1,
- Partitioning::Hash(expr, _) if expr.is_empty() => 1,
- p => p.partition_count(),
- }
-}
-
impl QueryStage {
pub fn new(id: usize, plan: Arc<dyn ExecutionPlan>) -> Self {
Self { id, plan }
@@ -96,21 +86,27 @@ impl QueryStage {
ids
}
- /// Get the input partition count. This is the same as the number of
concurrent tasks
- /// when we schedule this query stage for execution
- pub fn get_input_partition_count(&self) -> usize {
- if self.plan.children().is_empty() {
- // leaf node (file scan)
- self.plan.output_partitioning().partition_count()
+ /// Get the number of partitions that can be executed in parallel
+ pub fn get_execution_partition_count(&self) -> usize {
+ if let Some(shuffle) =
self.plan.as_any().downcast_ref::<ShuffleWriterExec>() {
+ // use the partitioning of the input to the shuffle write because
we are
+ // really executing that and then using the shuffle writer to
repartition
+ // the output
+ shuffle.input_plan.output_partitioning().partition_count()
} else {
- self.plan.children()[0]
- .output_partitioning()
- .partition_count()
+ // for any other plan, use its output partitioning
+ self.plan.output_partitioning().partition_count()
}
}
pub fn get_output_partition_count(&self) -> usize {
- _get_output_partition_count(self.plan.as_ref())
+ // UnknownPartitioning and HashPartitioning with empty expressions will
+ // both return 1 partition.
+ match self.plan.properties().output_partitioning() {
+ Partitioning::UnknownPartitioning(_) => 1,
+ Partitioning::Hash(expr, _) if expr.is_empty() => 1,
+ p => p.partition_count(),
+ }
}
}
diff --git a/src/shuffle/codec.rs b/src/shuffle/codec.rs
index 79af0b8..0420428 100644
--- a/src/shuffle/codec.rs
+++ b/src/shuffle/codec.rs
@@ -102,7 +102,7 @@ impl PhysicalExtensionCodec for ShuffleCodec {
};
PlanType::ShuffleReader(reader)
} else if let Some(writer) =
node.as_any().downcast_ref::<ShuffleWriterExec>() {
- let plan =
PhysicalPlanNode::try_from_physical_plan(writer.plan.clone(), self)?;
+ let plan =
PhysicalPlanNode::try_from_physical_plan(writer.input_plan.clone(), self)?;
let partitioning =
encode_partitioning_scheme(writer.properties().output_partitioning())?;
let writer = ShuffleWriterExecNode {
diff --git a/src/shuffle/writer.rs b/src/shuffle/writer.rs
index 069f99d..0e0f984 100644
--- a/src/shuffle/writer.rs
+++ b/src/shuffle/writer.rs
@@ -47,7 +47,7 @@ use std::sync::Arc;
#[derive(Debug)]
pub struct ShuffleWriterExec {
pub stage_id: usize,
- pub(crate) plan: Arc<dyn ExecutionPlan>,
+ pub(crate) input_plan: Arc<dyn ExecutionPlan>,
/// Output partitioning
properties: PlanProperties,
/// Directory to write shuffle files from
@@ -84,7 +84,7 @@ impl ShuffleWriterExec {
Self {
stage_id,
- plan,
+ input_plan: plan,
properties,
shuffle_dir: shuffle_dir.to_string(),
metrics: ExecutionPlanMetricsSet::new(),
@@ -98,11 +98,11 @@ impl ExecutionPlan for ShuffleWriterExec {
}
fn schema(&self) -> SchemaRef {
- self.plan.schema()
+ self.input_plan.schema()
}
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
- vec![&self.plan]
+ vec![&self.input_plan]
}
fn with_new_children(
@@ -122,7 +122,7 @@ impl ExecutionPlan for ShuffleWriterExec {
self.stage_id
);
- let mut stream = self.plan.execute(input_partition, context)?;
+ let mut stream = self.input_plan.execute(input_partition, context)?;
let write_time =
MetricBuilder::new(&self.metrics).subset_time("write_time",
input_partition);
let repart_time =
diff --git a/testdata/expected-plans/q1.txt b/testdata/expected-plans/q1.txt
index 282d5da..6f78394 100644
--- a/testdata/expected-plans/q1.txt
+++ b/testdata/expected-plans/q1.txt
@@ -42,7 +42,7 @@ ShuffleWriterExec(stage_id=1,
output_partitioning=Hash([Column { name: "l_return
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=0, input_partitioning=Hash([Column {
name: "l_returnflag", index: 0 }, Column { name: "l_linestatus", index: 1 }],
2))
-Query Stage #2 (2 -> 1):
+Query Stage #2 (1 -> 1):
SortPreservingMergeExec: [l_returnflag@0 ASC NULLS LAST, l_linestatus@1 ASC
NULLS LAST]
ShuffleReaderExec(stage_id=1, input_partitioning=Hash([Column { name:
"l_returnflag", index: 0 }, Column { name: "l_linestatus", index: 1 }], 2))
diff --git a/testdata/expected-plans/q10.txt b/testdata/expected-plans/q10.txt
index 046f69e..3825561 100644
--- a/testdata/expected-plans/q10.txt
+++ b/testdata/expected-plans/q10.txt
@@ -117,7 +117,7 @@ ShuffleWriterExec(stage_id=7,
output_partitioning=Hash([Column { name: "c_custke
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=6, input_partitioning=Hash([Column {
name: "c_custkey", index: 0 }, Column { name: "c_name", index: 1 }, Column {
name: "c_acctbal", index: 2 }, Column { name: "c_phone", index: 3 }, Column {
name: "n_name", index: 4 }, Column { name: "c_address", index: 5 }, Column {
name: "c_comment", index: 6 }], 2))
-Query Stage #8 (2 -> 1):
+Query Stage #8 (1 -> 1):
SortPreservingMergeExec: [revenue@2 DESC], fetch=20
ShuffleReaderExec(stage_id=7, input_partitioning=Hash([Column { name:
"c_custkey", index: 0 }, Column { name: "c_name", index: 1 }, Column { name:
"c_acctbal", index: 3 }, Column { name: "c_phone", index: 6 }, Column { name:
"n_name", index: 4 }, Column { name: "c_address", index: 5 }, Column { name:
"c_comment", index: 7 }], 2))
diff --git a/testdata/expected-plans/q11.txt b/testdata/expected-plans/q11.txt
index 74f74d7..2972d52 100644
--- a/testdata/expected-plans/q11.txt
+++ b/testdata/expected-plans/q11.txt
@@ -167,7 +167,7 @@ ShuffleWriterExec(stage_id=10,
output_partitioning=Hash([Column { name: "ps_part
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=9, input_partitioning=Hash([Column {
name: "ps_partkey", index: 0 }], 2))
-Query Stage #11 (2 -> 1):
+Query Stage #11 (1 -> 1):
SortPreservingMergeExec: [value@1 DESC]
ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column { name:
"ps_partkey", index: 0 }], 2))
diff --git a/testdata/expected-plans/q12.txt b/testdata/expected-plans/q12.txt
index c7ae269..4cf0596 100644
--- a/testdata/expected-plans/q12.txt
+++ b/testdata/expected-plans/q12.txt
@@ -65,7 +65,7 @@ ShuffleWriterExec(stage_id=3,
output_partitioning=Hash([Column { name: "l_shipmo
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=2, input_partitioning=Hash([Column {
name: "l_shipmode", index: 0 }], 2))
-Query Stage #4 (2 -> 1):
+Query Stage #4 (1 -> 1):
SortPreservingMergeExec: [l_shipmode@0 ASC NULLS LAST]
ShuffleReaderExec(stage_id=3, input_partitioning=Hash([Column { name:
"l_shipmode", index: 0 }], 2))
diff --git a/testdata/expected-plans/q13.txt b/testdata/expected-plans/q13.txt
index 366db12..da7e93a 100644
--- a/testdata/expected-plans/q13.txt
+++ b/testdata/expected-plans/q13.txt
@@ -70,7 +70,7 @@ ShuffleWriterExec(stage_id=3,
output_partitioning=Hash([Column { name: "c_count"
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=2, input_partitioning=Hash([Column {
name: "c_count", index: 0 }], 2))
-Query Stage #4 (2 -> 1):
+Query Stage #4 (1 -> 1):
SortPreservingMergeExec: [custdist@1 DESC, c_count@0 DESC]
ShuffleReaderExec(stage_id=3, input_partitioning=Hash([Column { name:
"c_count", index: 0 }], 2))
diff --git a/testdata/expected-plans/q16.txt b/testdata/expected-plans/q16.txt
index 24ecb18..b26e9a4 100644
--- a/testdata/expected-plans/q16.txt
+++ b/testdata/expected-plans/q16.txt
@@ -107,7 +107,7 @@ ShuffleWriterExec(stage_id=6,
output_partitioning=Hash([Column { name: "p_brand"
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=5, input_partitioning=Hash([Column {
name: "p_brand", index: 0 }, Column { name: "p_type", index: 1 }, Column {
name: "p_size", index: 2 }], 2))
-Query Stage #7 (2 -> 1):
+Query Stage #7 (1 -> 1):
SortPreservingMergeExec: [supplier_cnt@3 DESC, p_brand@0 ASC NULLS LAST,
p_type@1 ASC NULLS LAST, p_size@2 ASC NULLS LAST]
ShuffleReaderExec(stage_id=6, input_partitioning=Hash([Column { name:
"p_brand", index: 0 }, Column { name: "p_type", index: 1 }, Column { name:
"p_size", index: 2 }], 2))
diff --git a/testdata/expected-plans/q18.txt b/testdata/expected-plans/q18.txt
index 30179d0..a5d28e8 100644
--- a/testdata/expected-plans/q18.txt
+++ b/testdata/expected-plans/q18.txt
@@ -104,7 +104,7 @@ ShuffleWriterExec(stage_id=6,
output_partitioning=Hash([Column { name: "c_name",
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=5, input_partitioning=Hash([Column { name:
"c_name", index: 0 }, Column { name: "c_custkey", index: 1 }, Column { name:
"o_orderkey", index: 2 }, Column { name: "o_orderdate", index: 3 }, Column {
name: "o_totalprice", index: 4 }], 2))
-Query Stage #7 (2 -> 1):
+Query Stage #7 (1 -> 1):
SortPreservingMergeExec: [o_totalprice@4 DESC, o_orderdate@3 ASC NULLS LAST],
fetch=100
ShuffleReaderExec(stage_id=6, input_partitioning=Hash([Column { name:
"c_name", index: 0 }, Column { name: "c_custkey", index: 1 }, Column { name:
"o_orderkey", index: 2 }, Column { name: "o_orderdate", index: 3 }, Column {
name: "o_totalprice", index: 4 }], 2))
diff --git a/testdata/expected-plans/q2.txt b/testdata/expected-plans/q2.txt
index bc0713c..9778441 100644
--- a/testdata/expected-plans/q2.txt
+++ b/testdata/expected-plans/q2.txt
@@ -252,7 +252,7 @@ ShuffleWriterExec(stage_id=17,
output_partitioning=Hash([Column { name: "p_partk
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=16, input_partitioning=Hash([Column {
name: "ps_partkey", index: 1 }, Column { name: "min(partsupp.ps_supplycost)",
index: 0 }], 2))
-Query Stage #18 (2 -> 1):
+Query Stage #18 (1 -> 1):
SortPreservingMergeExec: [s_acctbal@0 DESC, n_name@2 ASC NULLS LAST, s_name@1
ASC NULLS LAST, p_partkey@3 ASC NULLS LAST], fetch=100
ShuffleReaderExec(stage_id=17, input_partitioning=Hash([Column { name:
"p_partkey", index: 3 }], 2))
diff --git a/testdata/expected-plans/q20.txt b/testdata/expected-plans/q20.txt
index 13b21c8..e1bc54c 100644
--- a/testdata/expected-plans/q20.txt
+++ b/testdata/expected-plans/q20.txt
@@ -142,7 +142,7 @@ ShuffleWriterExec(stage_id=8, output_partitioning=Hash([],
2))
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=7, input_partitioning=Hash([Column {
name: "ps_suppkey", index: 0 }], 2))
-Query Stage #9 (2 -> 1):
+Query Stage #9 (1 -> 1):
SortPreservingMergeExec: [s_name@0 ASC NULLS LAST]
ShuffleReaderExec(stage_id=8, input_partitioning=Hash([], 2))
diff --git a/testdata/expected-plans/q21.txt b/testdata/expected-plans/q21.txt
index b88bccc..8d6798f 100644
--- a/testdata/expected-plans/q21.txt
+++ b/testdata/expected-plans/q21.txt
@@ -172,7 +172,7 @@ ShuffleWriterExec(stage_id=10,
output_partitioning=Hash([Column { name: "s_name"
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=9, input_partitioning=Hash([Column {
name: "s_name", index: 0 }], 2))
-Query Stage #11 (2 -> 1):
+Query Stage #11 (1 -> 1):
SortPreservingMergeExec: [numwait@1 DESC, s_name@0 ASC NULLS LAST], fetch=100
ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column { name:
"s_name", index: 0 }], 2))
diff --git a/testdata/expected-plans/q22.txt b/testdata/expected-plans/q22.txt
index da693fb..7ad4ae1 100644
--- a/testdata/expected-plans/q22.txt
+++ b/testdata/expected-plans/q22.txt
@@ -91,7 +91,7 @@ ShuffleWriterExec(stage_id=4,
output_partitioning=Hash([Column { name: "cntrycod
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=3, input_partitioning=Hash([Column {
name: "cntrycode", index: 0 }], 2))
-Query Stage #5 (2 -> 1):
+Query Stage #5 (1 -> 1):
SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST]
ShuffleReaderExec(stage_id=4, input_partitioning=Hash([Column { name:
"cntrycode", index: 0 }], 2))
diff --git a/testdata/expected-plans/q3.txt b/testdata/expected-plans/q3.txt
index f9039d3..3af2ea0 100644
--- a/testdata/expected-plans/q3.txt
+++ b/testdata/expected-plans/q3.txt
@@ -97,7 +97,7 @@ ShuffleWriterExec(stage_id=5,
output_partitioning=Hash([Column { name: "l_orderk
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=4, input_partitioning=Hash([Column {
name: "l_orderkey", index: 0 }, Column { name: "o_orderdate", index: 1 },
Column { name: "o_shippriority", index: 2 }], 2))
-Query Stage #6 (2 -> 1):
+Query Stage #6 (1 -> 1):
SortPreservingMergeExec: [revenue@1 DESC, o_orderdate@2 ASC NULLS LAST],
fetch=10
ShuffleReaderExec(stage_id=5, input_partitioning=Hash([Column { name:
"l_orderkey", index: 0 }, Column { name: "o_orderdate", index: 2 }, Column {
name: "o_shippriority", index: 3 }], 2))
diff --git a/testdata/expected-plans/q4.txt b/testdata/expected-plans/q4.txt
index 20460e4..2504483 100644
--- a/testdata/expected-plans/q4.txt
+++ b/testdata/expected-plans/q4.txt
@@ -70,7 +70,7 @@ ShuffleWriterExec(stage_id=3,
output_partitioning=Hash([Column { name: "o_orderp
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=2, input_partitioning=Hash([Column {
name: "o_orderpriority", index: 0 }], 2))
-Query Stage #4 (2 -> 1):
+Query Stage #4 (1 -> 1):
SortPreservingMergeExec: [o_orderpriority@0 ASC NULLS LAST]
ShuffleReaderExec(stage_id=3, input_partitioning=Hash([Column { name:
"o_orderpriority", index: 0 }], 2))
diff --git a/testdata/expected-plans/q5.txt b/testdata/expected-plans/q5.txt
index 2bacb27..3e66ddb 100644
--- a/testdata/expected-plans/q5.txt
+++ b/testdata/expected-plans/q5.txt
@@ -167,7 +167,7 @@ ShuffleWriterExec(stage_id=11,
output_partitioning=Hash([Column { name: "n_name"
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column {
name: "n_name", index: 0 }], 2))
-Query Stage #12 (2 -> 1):
+Query Stage #12 (1 -> 1):
SortPreservingMergeExec: [revenue@1 DESC]
ShuffleReaderExec(stage_id=11, input_partitioning=Hash([Column { name:
"n_name", index: 0 }], 2))
diff --git a/testdata/expected-plans/q7.txt b/testdata/expected-plans/q7.txt
index 43bc031..9321b1b 100644
--- a/testdata/expected-plans/q7.txt
+++ b/testdata/expected-plans/q7.txt
@@ -176,7 +176,7 @@ ShuffleWriterExec(stage_id=11,
output_partitioning=Hash([Column { name: "supp_na
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column {
name: "supp_nation", index: 0 }, Column { name: "cust_nation", index: 1 },
Column { name: "l_year", index: 2 }], 2))
-Query Stage #12 (2 -> 1):
+Query Stage #12 (1 -> 1):
SortPreservingMergeExec: [supp_nation@0 ASC NULLS LAST, cust_nation@1 ASC
NULLS LAST, l_year@2 ASC NULLS LAST]
ShuffleReaderExec(stage_id=11, input_partitioning=Hash([Column { name:
"supp_nation", index: 0 }, Column { name: "cust_nation", index: 1 }, Column {
name: "l_year", index: 2 }], 2))
diff --git a/testdata/expected-plans/q8.txt b/testdata/expected-plans/q8.txt
index e9f5b91..c7ec1ec 100644
--- a/testdata/expected-plans/q8.txt
+++ b/testdata/expected-plans/q8.txt
@@ -230,7 +230,7 @@ ShuffleWriterExec(stage_id=15,
output_partitioning=Hash([Column { name: "o_year"
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=14, input_partitioning=Hash([Column {
name: "o_year", index: 0 }], 2))
-Query Stage #16 (2 -> 1):
+Query Stage #16 (1 -> 1):
SortPreservingMergeExec: [o_year@0 ASC NULLS LAST]
ShuffleReaderExec(stage_id=15, input_partitioning=Hash([Column { name:
"o_year", index: 0 }], 2))
diff --git a/testdata/expected-plans/q9.txt b/testdata/expected-plans/q9.txt
index 2c713b3..fa087f1 100644
--- a/testdata/expected-plans/q9.txt
+++ b/testdata/expected-plans/q9.txt
@@ -166,7 +166,7 @@ ShuffleWriterExec(stage_id=11,
output_partitioning=Hash([Column { name: "nation"
CoalesceBatchesExec: target_batch_size=8192
ShuffleReaderExec(stage_id=10, input_partitioning=Hash([Column {
name: "nation", index: 0 }, Column { name: "o_year", index: 1 }], 2))
-Query Stage #12 (2 -> 1):
+Query Stage #12 (1 -> 1):
SortPreservingMergeExec: [nation@0 ASC NULLS LAST, o_year@1 DESC]
ShuffleReaderExec(stage_id=11, input_partitioning=Hash([Column { name:
"nation", index: 0 }, Column { name: "o_year", index: 1 }], 2))
diff --git a/tests/test_context.py b/tests/test_context.py
index ecc3324..602f761 100644
--- a/tests/test_context.py
+++ b/tests/test_context.py
@@ -17,42 +17,42 @@
from datafusion_ray.context import DatafusionRayContext
from datafusion import SessionContext, SessionConfig, RuntimeConfig, col, lit,
functions as F
+import pytest
+@pytest.fixture
+def df_ctx():
+ """Fixture to create a DataFusion context."""
+ # used fixed partition count so that tests are deterministic on different
environments
+ config = SessionConfig().with_target_partitions(4)
+ return SessionContext(config=config)
-def test_basic_query_succeed():
- df_ctx = SessionContext()
- ctx = DatafusionRayContext(df_ctx)
+@pytest.fixture
+def ctx(df_ctx):
+ """Fixture to create a Datafusion Ray context."""
+ return DatafusionRayContext(df_ctx)
+
+def test_basic_query_succeed(df_ctx, ctx):
df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)
- # TODO why does this return a single batch and not a list of batches?
record_batches = ctx.sql("SELECT * FROM tips")
- assert record_batches[0].num_rows == 244
+ assert len(record_batches) <= 4
+ num_rows = sum(batch.num_rows for batch in record_batches)
+ assert num_rows == 244
-def test_aggregate_csv():
- df_ctx = SessionContext()
- ctx = DatafusionRayContext(df_ctx)
+def test_aggregate_csv(df_ctx, ctx):
df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)
record_batches = ctx.sql("select sex, smoker, avg(tip/total_bill) as
tip_pct from tips group by sex, smoker")
- assert isinstance(record_batches, list)
- # TODO why does this return many empty batches?
- num_rows = 0
- for record_batch in record_batches:
- num_rows += record_batch.num_rows
+ assert len(record_batches) <= 4
+ num_rows = sum(batch.num_rows for batch in record_batches)
assert num_rows == 4
-def test_aggregate_parquet():
- df_ctx = SessionContext()
- ctx = DatafusionRayContext(df_ctx)
+def test_aggregate_parquet(df_ctx, ctx):
df_ctx.register_parquet("tips", "examples/tips.parquet")
record_batches = ctx.sql("select sex, smoker, avg(tip/total_bill) as
tip_pct from tips group by sex, smoker")
- # TODO why does this return many empty batches?
- num_rows = 0
- for record_batch in record_batches:
- num_rows += record_batch.num_rows
+ assert len(record_batches) <= 4
+ num_rows = sum(batch.num_rows for batch in record_batches)
assert num_rows == 4
-def test_aggregate_parquet_dataframe():
- df_ctx = SessionContext()
- ray_ctx = DatafusionRayContext(df_ctx)
+def test_aggregate_parquet_dataframe(df_ctx, ctx):
df = df_ctx.read_parquet(f"examples/tips.parquet")
df = (
df.aggregate(
@@ -62,12 +62,10 @@ def test_aggregate_parquet_dataframe():
.filter(col("day") != lit("Dinner"))
.aggregate([col("sex"), col("smoker")],
[F.avg(col("tip_pct")).alias("avg_pct")])
)
- ray_results = ray_ctx.plan(df.execution_plan())
+ ray_results = ctx.plan(df.execution_plan())
df_ctx.create_dataframe([ray_results]).show()
-def test_no_result_query():
- df_ctx = SessionContext()
- ctx = DatafusionRayContext(df_ctx)
+def test_no_result_query(df_ctx, ctx):
df_ctx.register_csv("tips", "examples/tips.csv", has_header=True)
ctx.sql("CREATE VIEW tips_view AS SELECT * FROM tips")
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]