This is an automated email from the ASF dual-hosted git repository.
mneumann pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 5ee524ec70 feat(substrait): replace SessionContext with a trait
(#13343)
5ee524ec70 is described below
commit 5ee524ec70b1f10a078caca62954ce37b2dc3cc6
Author: Filippo Rossi <[email protected]>
AuthorDate: Wed Nov 20 18:11:18 2024 +0100
feat(substrait): replace SessionContext with a trait (#13343)
* feat(substrait): replace SessionContext with SessionState
* feat(substrait): add logical plan context
* chore(substrait): add apache header
* docs: fix code in docs
* docs(substrait): rename and document context
* chore(substrait): context -> state
* chore: fmt
---
datafusion/core/src/execution/session_state.rs | 4 +-
datafusion/substrait/Cargo.toml | 1 +
datafusion/substrait/src/lib.rs | 4 +-
datafusion/substrait/src/logical_plan/consumer.rs | 286 ++++++++++++---------
datafusion/substrait/src/logical_plan/mod.rs | 1 +
datafusion/substrait/src/logical_plan/producer.rs | 196 +++++++-------
datafusion/substrait/src/logical_plan/state.rs | 63 +++++
datafusion/substrait/src/serializer.rs | 2 +-
.../substrait/tests/cases/consumer_integration.rs | 2 +-
.../substrait/tests/cases/emit_kind_tests.rs | 12 +-
datafusion/substrait/tests/cases/function_test.rs | 2 +-
datafusion/substrait/tests/cases/logical_plans.rs | 6 +-
.../tests/cases/roundtrip_logical_plan.rs | 40 +--
datafusion/substrait/tests/cases/serialize.rs | 12 +-
.../substrait/tests/cases/substrait_validations.rs | 10 +-
15 files changed, 379 insertions(+), 262 deletions(-)
diff --git a/datafusion/core/src/execution/session_state.rs
b/datafusion/core/src/execution/session_state.rs
index 9fc081dd53..e99cf82223 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -296,7 +296,9 @@ impl SessionState {
.resolve(&catalog.default_catalog, &catalog.default_schema)
}
- pub(crate) fn schema_for_ref(
+ /// Retrieve the [`SchemaProvider`] for a specific [`TableReference`], if it
+ /// exists.
+ pub fn schema_for_ref(
&self,
table_ref: impl Into<TableReference>,
) -> datafusion_common::Result<Arc<dyn SchemaProvider>> {
diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml
index 192fe26d6c..61cdf3e91e 100644
--- a/datafusion/substrait/Cargo.toml
+++ b/datafusion/substrait/Cargo.toml
@@ -34,6 +34,7 @@ workspace = true
[dependencies]
arrow-buffer = { workspace = true }
async-recursion = "1.0"
+async-trait = { workspace = true }
chrono = { workspace = true }
datafusion = { workspace = true, default-features = true }
itertools = { workspace = true }
diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs
index a6f7c033f9..1389cac75b 100644
--- a/datafusion/substrait/src/lib.rs
+++ b/datafusion/substrait/src/lib.rs
@@ -64,10 +64,10 @@
//! let plan = df.into_optimized_plan()?;
//!
//! // Convert the plan into a substrait (protobuf) Plan
-//! let substrait_plan = logical_plan::producer::to_substrait_plan(&plan,
&ctx)?;
+//! let substrait_plan = logical_plan::producer::to_substrait_plan(&plan,
&ctx.state())?;
//!
//! // Receive a substrait protobuf from somewhere, and turn it into a
LogicalPlan
-//! let logical_round_trip =
logical_plan::consumer::from_substrait_plan(&ctx, &substrait_plan).await?;
+//! let logical_round_trip =
logical_plan::consumer::from_substrait_plan(&ctx.state(),
&substrait_plan).await?;
//! let logical_round_trip = ctx.state().optimize(&logical_round_trip)?;
//! assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip));
//! # Ok(())
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs
b/datafusion/substrait/src/logical_plan/consumer.rs
index 1cce228527..77e9eb81f5 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -26,7 +26,7 @@ use datafusion::common::{
not_impl_err, plan_datafusion_err, substrait_datafusion_err,
substrait_err, DFSchema,
DFSchemaRef,
};
-use datafusion::execution::FunctionRegistry;
+use datafusion::datasource::provider_as_source;
use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};
use datafusion::logical_expr::{
@@ -56,7 +56,6 @@ use crate::variation_const::{
use datafusion::arrow::array::{new_empty_array, AsArray};
use datafusion::arrow::temporal_conversions::NANOSECONDS;
use datafusion::common::scalar::ScalarStructBuilder;
-use datafusion::dataframe::DataFrame;
use datafusion::logical_expr::builder::project;
use datafusion::logical_expr::expr::InList;
use datafusion::logical_expr::{
@@ -66,9 +65,7 @@ use datafusion::logical_expr::{
use datafusion::prelude::JoinType;
use datafusion::sql::TableReference;
use datafusion::{
- error::Result,
- logical_expr::utils::split_conjunction,
- prelude::{Column, SessionContext},
+ error::Result, logical_expr::utils::split_conjunction, prelude::Column,
scalar::ScalarValue,
};
use std::collections::HashSet;
@@ -102,6 +99,8 @@ use substrait::proto::{
};
use substrait::proto::{ExtendedExpression, FunctionArgument, SortField};
+use super::state::SubstraitPlanningState;
+
// Substrait PrecisionTimestampTz indicates that the timestamp is relative to
UTC, which
// is the same as the expectation for any non-empty timezone in DF, so any
non-empty timezone
// results in correct points on the timeline, and we pick UTC as a reasonable
default.
@@ -203,15 +202,15 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality(
async fn union_rels(
rels: &[Rel],
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
extensions: &Extensions,
is_all: bool,
) -> Result<LogicalPlan> {
let mut union_builder = Ok(LogicalPlanBuilder::from(
- from_substrait_rel(ctx, &rels[0], extensions).await?,
+ from_substrait_rel(state, &rels[0], extensions).await?,
));
for input in &rels[1..] {
- let rel_plan = from_substrait_rel(ctx, input, extensions).await?;
+ let rel_plan = from_substrait_rel(state, input, extensions).await?;
union_builder = if is_all {
union_builder?.union(rel_plan)
@@ -224,16 +223,16 @@ async fn union_rels(
async fn intersect_rels(
rels: &[Rel],
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
extensions: &Extensions,
is_all: bool,
) -> Result<LogicalPlan> {
- let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?;
+ let mut rel = from_substrait_rel(state, &rels[0], extensions).await?;
for input in &rels[1..] {
rel = LogicalPlanBuilder::intersect(
rel,
- from_substrait_rel(ctx, input, extensions).await?,
+ from_substrait_rel(state, input, extensions).await?,
is_all,
)?
}
@@ -243,16 +242,16 @@ async fn intersect_rels(
async fn except_rels(
rels: &[Rel],
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
extensions: &Extensions,
is_all: bool,
) -> Result<LogicalPlan> {
- let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?;
+ let mut rel = from_substrait_rel(state, &rels[0], extensions).await?;
for input in &rels[1..] {
rel = LogicalPlanBuilder::except(
rel,
- from_substrait_rel(ctx, input, extensions).await?,
+ from_substrait_rel(state, input, extensions).await?,
is_all,
)?
}
@@ -262,7 +261,7 @@ async fn except_rels(
/// Convert Substrait Plan to DataFusion LogicalPlan
pub async fn from_substrait_plan(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
plan: &Plan,
) -> Result<LogicalPlan> {
// Register function extension
@@ -277,10 +276,10 @@ pub async fn from_substrait_plan(
match plan.relations[0].rel_type.as_ref() {
Some(rt) => match rt {
plan_rel::RelType::Rel(rel) => {
- Ok(from_substrait_rel(ctx, rel, &extensions).await?)
+ Ok(from_substrait_rel(state, rel, &extensions).await?)
},
plan_rel::RelType::Root(root) => {
- let plan = from_substrait_rel(ctx,
root.input.as_ref().unwrap(), &extensions).await?;
+ let plan = from_substrait_rel(state,
root.input.as_ref().unwrap(), &extensions).await?;
if root.names.is_empty() {
// Backwards compatibility for plans missing names
return Ok(plan);
@@ -341,7 +340,7 @@ pub struct ExprContainer {
/// between systems. This is often useful for scenarios like pushdown where
filter
/// expressions need to be sent to remote systems.
pub async fn from_substrait_extended_expr(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
extended_expr: &ExtendedExpression,
) -> Result<ExprContainer> {
// Register function extension
@@ -370,7 +369,7 @@ pub async fn from_substrait_extended_expr(
}
}?;
let expr =
- from_substrait_rex(ctx, scalar_expr, &input_schema,
&extensions).await?;
+ from_substrait_rex(state, scalar_expr, &input_schema,
&extensions).await?;
let (output_type, expected_nullability) =
expr.data_type_and_nullable(&input_schema)?;
let output_field = Field::new("", output_type, expected_nullability);
@@ -561,7 +560,7 @@ fn make_renamed_schema(
#[allow(deprecated)]
#[async_recursion]
pub async fn from_substrait_rel(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
rel: &Rel,
extensions: &Extensions,
) -> Result<LogicalPlan> {
@@ -569,7 +568,7 @@ pub async fn from_substrait_rel(
Some(RelType::Project(p)) => {
if let Some(input) = p.input.as_ref() {
let mut input = LogicalPlanBuilder::from(
- from_substrait_rel(ctx, input, extensions).await?,
+ from_substrait_rel(state, input, extensions).await?,
);
let original_schema = input.schema().clone();
@@ -587,9 +586,13 @@ pub async fn from_substrait_rel(
let mut explicit_exprs: Vec<Expr> = vec![];
for expr in &p.expressions {
- let e =
- from_substrait_rex(ctx, expr, input.clone().schema(),
extensions)
- .await?;
+ let e = from_substrait_rex(
+ state,
+ expr,
+ input.clone().schema(),
+ extensions,
+ )
+ .await?;
// if the expression is WindowFunction, wrap in a Window
relation
if let Expr::WindowFunction(_) = &e {
// Adding the same expression here and in the project
below
@@ -617,11 +620,11 @@ pub async fn from_substrait_rel(
Some(RelType::Filter(filter)) => {
if let Some(input) = filter.input.as_ref() {
let input = LogicalPlanBuilder::from(
- from_substrait_rel(ctx, input, extensions).await?,
+ from_substrait_rel(state, input, extensions).await?,
);
if let Some(condition) = filter.condition.as_ref() {
let expr =
- from_substrait_rex(ctx, condition, input.schema(),
extensions)
+ from_substrait_rex(state, condition, input.schema(),
extensions)
.await?;
input.filter(expr)?.build()
} else {
@@ -634,7 +637,7 @@ pub async fn from_substrait_rel(
Some(RelType::Fetch(fetch)) => {
if let Some(input) = fetch.input.as_ref() {
let input = LogicalPlanBuilder::from(
- from_substrait_rel(ctx, input, extensions).await?,
+ from_substrait_rel(state, input, extensions).await?,
);
let offset = fetch.offset as usize;
// -1 means that ALL records should be returned
@@ -651,10 +654,10 @@ pub async fn from_substrait_rel(
Some(RelType::Sort(sort)) => {
if let Some(input) = sort.input.as_ref() {
let input = LogicalPlanBuilder::from(
- from_substrait_rel(ctx, input, extensions).await?,
+ from_substrait_rel(state, input, extensions).await?,
);
let sorts =
- from_substrait_sorts(ctx, &sort.sorts, input.schema(),
extensions)
+ from_substrait_sorts(state, &sort.sorts, input.schema(),
extensions)
.await?;
input.sort(sorts)?.build()
} else {
@@ -664,13 +667,13 @@ pub async fn from_substrait_rel(
Some(RelType::Aggregate(agg)) => {
if let Some(input) = agg.input.as_ref() {
let input = LogicalPlanBuilder::from(
- from_substrait_rel(ctx, input, extensions).await?,
+ from_substrait_rel(state, input, extensions).await?,
);
let mut ref_group_exprs = vec![];
for e in &agg.grouping_expressions {
let x =
- from_substrait_rex(ctx, e, input.schema(),
extensions).await?;
+ from_substrait_rex(state, e, input.schema(),
extensions).await?;
ref_group_exprs.push(x);
}
@@ -681,7 +684,7 @@ pub async fn from_substrait_rel(
1 => {
group_exprs.extend_from_slice(
&from_substrait_grouping(
- ctx,
+ state,
&agg.groupings[0],
&ref_group_exprs,
input.schema(),
@@ -694,7 +697,7 @@ pub async fn from_substrait_rel(
let mut grouping_sets = vec![];
for grouping in &agg.groupings {
let grouping_set = from_substrait_grouping(
- ctx,
+ state,
grouping,
&ref_group_exprs,
input.schema(),
@@ -716,7 +719,7 @@ pub async fn from_substrait_rel(
for m in &agg.measures {
let filter = match &m.filter {
Some(fil) => Some(Box::new(
- from_substrait_rex(ctx, fil, input.schema(),
extensions)
+ from_substrait_rex(state, fil, input.schema(),
extensions)
.await?,
)),
None => None,
@@ -739,7 +742,7 @@ pub async fn from_substrait_rel(
let order_by = if !f.sorts.is_empty() {
Some(
from_substrait_sorts(
- ctx,
+ state,
&f.sorts,
input.schema(),
extensions,
@@ -751,7 +754,7 @@ pub async fn from_substrait_rel(
};
from_substrait_agg_func(
- ctx,
+ state,
f,
input.schema(),
extensions,
@@ -780,10 +783,12 @@ pub async fn from_substrait_rel(
}
let left: LogicalPlanBuilder = LogicalPlanBuilder::from(
- from_substrait_rel(ctx, join.left.as_ref().unwrap(),
extensions).await?,
+ from_substrait_rel(state, join.left.as_ref().unwrap(),
extensions)
+ .await?,
);
let right = LogicalPlanBuilder::from(
- from_substrait_rel(ctx, join.right.as_ref().unwrap(),
extensions).await?,
+ from_substrait_rel(state, join.right.as_ref().unwrap(),
extensions)
+ .await?,
);
let (left, right) = requalify_sides_if_needed(left, right)?;
@@ -796,7 +801,7 @@ pub async fn from_substrait_rel(
// Otherwise, build join with only the filter, without join keys
match &join.expression.as_ref() {
Some(expr) => {
- let on = from_substrait_rex(ctx, expr, &in_join_schema,
extensions)
+ let on = from_substrait_rex(state, expr, &in_join_schema,
extensions)
.await?;
// The join expression can contain both equal and
non-equal ops.
// As of datafusion 31.0.0, the equal and non equal join
conditions are in separate fields.
@@ -831,26 +836,44 @@ pub async fn from_substrait_rel(
}
Some(RelType::Cross(cross)) => {
let left = LogicalPlanBuilder::from(
- from_substrait_rel(ctx, cross.left.as_ref().unwrap(),
extensions).await?,
+ from_substrait_rel(state, cross.left.as_ref().unwrap(),
extensions)
+ .await?,
);
let right = LogicalPlanBuilder::from(
- from_substrait_rel(ctx, cross.right.as_ref().unwrap(),
extensions)
+ from_substrait_rel(state, cross.right.as_ref().unwrap(),
extensions)
.await?,
);
let (left, right) = requalify_sides_if_needed(left, right)?;
left.cross_join(right.build()?)?.build()
}
Some(RelType::Read(read)) => {
- fn read_with_schema(
- df: DataFrame,
+ async fn read_with_schema(
+ state: &dyn SubstraitPlanningState,
+ table_ref: TableReference,
schema: DFSchema,
projection: &Option<MaskExpression>,
) -> Result<LogicalPlan> {
- ensure_schema_compatability(df.schema().to_owned(),
schema.clone())?;
+ let schema = schema.replace_qualifier(table_ref.clone());
+
+ let plan = {
+ let provider = match state.table(&table_ref).await? {
+ Some(ref provider) => Arc::clone(provider),
+ _ => return plan_err!("No table named '{table_ref}'"),
+ };
+
+ LogicalPlanBuilder::scan(
+ table_ref,
+ provider_as_source(Arc::clone(&provider)),
+ None,
+ )?
+ .build()?
+ };
+
+ ensure_schema_compatability(plan.schema(), schema.clone())?;
let schema = apply_masking(schema, projection)?;
- apply_projection(df, schema)
+ apply_projection(plan, schema)
}
let named_struct = read.base_schema.as_ref().ok_or_else(|| {
@@ -879,12 +902,13 @@ pub async fn from_substrait_rel(
},
};
- let t = ctx.table(table_reference.clone()).await?;
-
- let substrait_schema =
- substrait_schema.replace_qualifier(table_reference);
-
- read_with_schema(t, substrait_schema, &read.projection)
+ read_with_schema(
+ state,
+ table_reference,
+ substrait_schema,
+ &read.projection,
+ )
+ .await
}
Some(ReadType::VirtualTable(vt)) => {
if vt.values.is_empty() {
@@ -960,12 +984,14 @@ pub async fn from_substrait_rel(
let name = filename.unwrap();
// directly use unwrap here since we could determine it is
a valid one
let table_reference = TableReference::Bare { table:
name.into() };
- let t = ctx.table(table_reference.clone()).await?;
-
- let substrait_schema =
- substrait_schema.replace_qualifier(table_reference);
- read_with_schema(t, substrait_schema, &read.projection)
+ read_with_schema(
+ state,
+ table_reference,
+ substrait_schema,
+ &read.projection,
+ )
+ .await
}
_ => {
not_impl_err!("Unsupported ReadType: {:?}",
&read.as_ref().read_type)
@@ -979,31 +1005,31 @@ pub async fn from_substrait_rel(
} else {
match set_op {
set_rel::SetOp::UnionAll => {
- union_rels(&set.inputs, ctx, extensions,
true).await
+ union_rels(&set.inputs, state, extensions,
true).await
}
set_rel::SetOp::UnionDistinct => {
- union_rels(&set.inputs, ctx, extensions,
false).await
+ union_rels(&set.inputs, state, extensions,
false).await
}
set_rel::SetOp::IntersectionPrimary => {
LogicalPlanBuilder::intersect(
- from_substrait_rel(ctx, &set.inputs[0],
extensions)
+ from_substrait_rel(state, &set.inputs[0],
extensions)
.await?,
- union_rels(&set.inputs[1..], ctx, extensions,
true)
+ union_rels(&set.inputs[1..], state,
extensions, true)
.await?,
false,
)
}
set_rel::SetOp::IntersectionMultiset => {
- intersect_rels(&set.inputs, ctx, extensions,
false).await
+ intersect_rels(&set.inputs, state, extensions,
false).await
}
set_rel::SetOp::IntersectionMultisetAll => {
- intersect_rels(&set.inputs, ctx, extensions,
true).await
+ intersect_rels(&set.inputs, state, extensions,
true).await
}
set_rel::SetOp::MinusPrimary => {
- except_rels(&set.inputs, ctx, extensions,
false).await
+ except_rels(&set.inputs, state, extensions,
false).await
}
set_rel::SetOp::MinusPrimaryAll => {
- except_rels(&set.inputs, ctx, extensions,
true).await
+ except_rels(&set.inputs, state, extensions,
true).await
}
_ => not_impl_err!("Unsupported set operator:
{set_op:?}"),
}
@@ -1015,8 +1041,7 @@ pub async fn from_substrait_rel(
let Some(ext_detail) = &extension.detail else {
return substrait_err!("Unexpected empty detail in
ExtensionLeafRel");
};
- let plan = ctx
- .state()
+ let plan = state
.serializer_registry()
.deserialize_logical_plan(&ext_detail.type_url,
&ext_detail.value)?;
Ok(LogicalPlan::Extension(Extension { node: plan }))
@@ -1025,8 +1050,7 @@ pub async fn from_substrait_rel(
let Some(ext_detail) = &extension.detail else {
return substrait_err!("Unexpected empty detail in
ExtensionSingleRel");
};
- let plan = ctx
- .state()
+ let plan = state
.serializer_registry()
.deserialize_logical_plan(&ext_detail.type_url,
&ext_detail.value)?;
let Some(input_rel) = &extension.input else {
@@ -1034,7 +1058,7 @@ pub async fn from_substrait_rel(
"ExtensionSingleRel doesn't contains input rel. Try use
ExtensionLeafRel instead"
);
};
- let input_plan = from_substrait_rel(ctx, input_rel,
extensions).await?;
+ let input_plan = from_substrait_rel(state, input_rel,
extensions).await?;
let plan =
plan.with_exprs_and_inputs(plan.expressions(),
vec![input_plan])?;
Ok(LogicalPlan::Extension(Extension { node: plan }))
@@ -1043,13 +1067,12 @@ pub async fn from_substrait_rel(
let Some(ext_detail) = &extension.detail else {
return substrait_err!("Unexpected empty detail in
ExtensionSingleRel");
};
- let plan = ctx
- .state()
+ let plan = state
.serializer_registry()
.deserialize_logical_plan(&ext_detail.type_url,
&ext_detail.value)?;
let mut inputs = Vec::with_capacity(extension.inputs.len());
for input in &extension.inputs {
- let input_plan = from_substrait_rel(ctx, input,
extensions).await?;
+ let input_plan = from_substrait_rel(state, input,
extensions).await?;
inputs.push(input_plan);
}
let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
@@ -1059,7 +1082,7 @@ pub async fn from_substrait_rel(
let Some(input) = exchange.input.as_ref() else {
return substrait_err!("Unexpected empty input in ExchangeRel");
};
- let input = Arc::new(from_substrait_rel(ctx, input,
extensions).await?);
+ let input = Arc::new(from_substrait_rel(state, input,
extensions).await?);
let Some(exchange_kind) = &exchange.exchange_kind else {
return substrait_err!("Unexpected empty input in ExchangeRel");
@@ -1237,7 +1260,7 @@ impl NameTracker {
/// DataFusion schema may have MORE fields, but not the other way around.
/// 2. All fields are compatible. See [`ensure_field_compatability`] for
details
fn ensure_schema_compatability(
- table_schema: DFSchema,
+ table_schema: &DFSchema,
substrait_schema: DFSchema,
) -> Result<()> {
substrait_schema
@@ -1253,16 +1276,19 @@ fn ensure_schema_compatability(
/// This function returns a DataFrame with fields adjusted if necessary in the
event that the
/// Substrait schema is a subset of the DataFusion schema.
-fn apply_projection(table: DataFrame, substrait_schema: DFSchema) ->
Result<LogicalPlan> {
- let df_schema = table.schema().to_owned();
-
- let t = table.into_unoptimized_plan();
+fn apply_projection(
+ plan: LogicalPlan,
+ substrait_schema: DFSchema,
+) -> Result<LogicalPlan> {
+ let df_schema = plan.schema();
if df_schema.logically_equivalent_names_and_types(&substrait_schema) {
- return Ok(t);
+ return Ok(plan);
}
- match t {
+ let df_schema = df_schema.to_owned();
+
+ match plan {
LogicalPlan::TableScan(mut scan) => {
let column_indices: Vec<usize> = substrait_schema
.strip_qualifiers()
@@ -1389,7 +1415,7 @@ fn from_substrait_jointype(join_type: i32) ->
Result<JoinType> {
/// Convert Substrait Sorts to DataFusion Exprs
pub async fn from_substrait_sorts(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
substrait_sorts: &Vec<SortField>,
input_schema: &DFSchema,
extensions: &Extensions,
@@ -1397,7 +1423,7 @@ pub async fn from_substrait_sorts(
let mut sorts: Vec<Sort> = vec![];
for s in substrait_sorts {
let expr =
- from_substrait_rex(ctx, s.expr.as_ref().unwrap(), input_schema,
extensions)
+ from_substrait_rex(state, s.expr.as_ref().unwrap(), input_schema,
extensions)
.await?;
let asc_nullfirst = match &s.sort_kind {
Some(k) => match k {
@@ -1439,14 +1465,15 @@ pub async fn from_substrait_sorts(
/// Convert Substrait Expressions to DataFusion Exprs
pub async fn from_substrait_rex_vec(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
exprs: &Vec<Expression>,
input_schema: &DFSchema,
extensions: &Extensions,
) -> Result<Vec<Expr>> {
let mut expressions: Vec<Expr> = vec![];
for expr in exprs {
- let expression = from_substrait_rex(ctx, expr, input_schema,
extensions).await?;
+ let expression =
+ from_substrait_rex(state, expr, input_schema, extensions).await?;
expressions.push(expression);
}
Ok(expressions)
@@ -1454,7 +1481,7 @@ pub async fn from_substrait_rex_vec(
/// Convert Substrait FunctionArguments to DataFusion Exprs
pub async fn from_substrait_func_args(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
arguments: &Vec<FunctionArgument>,
input_schema: &DFSchema,
extensions: &Extensions,
@@ -1463,7 +1490,7 @@ pub async fn from_substrait_func_args(
for arg in arguments {
let arg_expr = match &arg.arg_type {
Some(ArgType::Value(e)) => {
- from_substrait_rex(ctx, e, input_schema, extensions).await
+ from_substrait_rex(state, e, input_schema, extensions).await
}
_ => not_impl_err!("Function argument non-Value type not
supported"),
};
@@ -1474,7 +1501,7 @@ pub async fn from_substrait_func_args(
/// Convert Substrait AggregateFunction to DataFusion Expr
pub async fn from_substrait_agg_func(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
f: &AggregateFunction,
input_schema: &DFSchema,
extensions: &Extensions,
@@ -1483,7 +1510,7 @@ pub async fn from_substrait_agg_func(
distinct: bool,
) -> Result<Arc<Expr>> {
let args =
- from_substrait_func_args(ctx, &f.arguments, input_schema,
extensions).await?;
+ from_substrait_func_args(state, &f.arguments, input_schema,
extensions).await?;
let Some(function_name) = extensions.functions.get(&f.function_reference)
else {
return plan_err!(
@@ -1494,7 +1521,7 @@ pub async fn from_substrait_agg_func(
let function_name = substrait_fun_name(function_name);
// try udaf first, then built-in aggr fn.
- if let Ok(fun) = ctx.udaf(function_name) {
+ if let Ok(fun) = state.udaf(function_name) {
// deal with situation that count(*) got no arguments
let args = if fun.name() == "count" && args.is_empty() {
vec![Expr::Literal(ScalarValue::Int64(Some(1)))]
@@ -1517,7 +1544,7 @@ pub async fn from_substrait_agg_func(
/// Convert Substrait Rex to DataFusion Expr
#[async_recursion]
pub async fn from_substrait_rex(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
e: &Expression,
input_schema: &DFSchema,
extensions: &Extensions,
@@ -1528,11 +1555,11 @@ pub async fn from_substrait_rex(
let substrait_list = s.options.as_ref();
Ok(Expr::InList(InList {
expr: Box::new(
- from_substrait_rex(ctx, substrait_expr, input_schema,
extensions)
+ from_substrait_rex(state, substrait_expr, input_schema,
extensions)
.await?,
),
list: from_substrait_rex_vec(
- ctx,
+ state,
substrait_list,
input_schema,
extensions,
@@ -1555,7 +1582,7 @@ pub async fn from_substrait_rex(
if if_expr.then.is_none() {
expr = Some(Box::new(
from_substrait_rex(
- ctx,
+ state,
if_expr.r#if.as_ref().unwrap(),
input_schema,
extensions,
@@ -1568,7 +1595,7 @@ pub async fn from_substrait_rex(
when_then_expr.push((
Box::new(
from_substrait_rex(
- ctx,
+ state,
if_expr.r#if.as_ref().unwrap(),
input_schema,
extensions,
@@ -1577,7 +1604,7 @@ pub async fn from_substrait_rex(
),
Box::new(
from_substrait_rex(
- ctx,
+ state,
if_expr.then.as_ref().unwrap(),
input_schema,
extensions,
@@ -1589,7 +1616,7 @@ pub async fn from_substrait_rex(
// Parse `else`
let else_expr = match &if_then.r#else {
Some(e) => Some(Box::new(
- from_substrait_rex(ctx, e, input_schema,
extensions).await?,
+ from_substrait_rex(state, e, input_schema,
extensions).await?,
)),
None => None,
};
@@ -1609,12 +1636,12 @@ pub async fn from_substrait_rex(
let fn_name = substrait_fun_name(fn_name);
let args =
- from_substrait_func_args(ctx, &f.arguments, input_schema,
extensions)
+ from_substrait_func_args(state, &f.arguments, input_schema,
extensions)
.await?;
// try to first match the requested function into registered udfs,
then built-in ops
// and finally built-in expressions
- if let Some(func) = ctx.state().scalar_functions().get(fn_name) {
+ if let Ok(func) = state.udf(fn_name) {
Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf(
func.to_owned(),
args,
@@ -1644,7 +1671,7 @@ pub async fn from_substrait_rex(
Ok(combined_expr)
} else if let Some(builder) =
BuiltinExprBuilder::try_from_name(fn_name) {
- builder.build(ctx, f, input_schema, extensions).await
+ builder.build(state, f, input_schema, extensions).await
} else {
not_impl_err!("Unsupported function name: {fn_name:?}")
}
@@ -1657,7 +1684,7 @@ pub async fn from_substrait_rex(
Some(output_type) => Ok(Expr::Cast(Cast::new(
Box::new(
from_substrait_rex(
- ctx,
+ state,
cast.as_ref().input.as_ref().unwrap().as_ref(),
input_schema,
extensions,
@@ -1679,9 +1706,9 @@ pub async fn from_substrait_rex(
let fn_name = substrait_fun_name(fn_name);
// check udwf first, then udaf, then built-in window and aggregate
functions
- let fun = if let Ok(udwf) = ctx.udwf(fn_name) {
+ let fun = if let Ok(udwf) = state.udwf(fn_name) {
Ok(WindowFunctionDefinition::WindowUDF(udwf))
- } else if let Ok(udaf) = ctx.udaf(fn_name) {
+ } else if let Ok(udaf) = state.udaf(fn_name) {
Ok(WindowFunctionDefinition::AggregateUDF(udaf))
} else {
not_impl_err!(
@@ -1692,7 +1719,7 @@ pub async fn from_substrait_rex(
}?;
let order_by =
- from_substrait_sorts(ctx, &window.sorts, input_schema,
extensions)
+ from_substrait_sorts(state, &window.sorts, input_schema,
extensions)
.await?;
let bound_units =
@@ -1715,14 +1742,14 @@ pub async fn from_substrait_rex(
Ok(Expr::WindowFunction(expr::WindowFunction {
fun,
args: from_substrait_func_args(
- ctx,
+ state,
&window.arguments,
input_schema,
extensions,
)
.await?,
partition_by: from_substrait_rex_vec(
- ctx,
+ state,
&window.partitions,
input_schema,
extensions,
@@ -1747,13 +1774,13 @@ pub async fn from_substrait_rex(
let haystack_expr = &in_predicate.haystack;
if let Some(haystack_expr) = haystack_expr {
let haystack_expr =
- from_substrait_rel(ctx, haystack_expr,
extensions)
+ from_substrait_rel(state, haystack_expr,
extensions)
.await?;
let outer_refs = haystack_expr.all_out_ref_exprs();
Ok(Expr::InSubquery(InSubquery {
expr: Box::new(
from_substrait_rex(
- ctx,
+ state,
needle_expr,
input_schema,
extensions,
@@ -1773,7 +1800,7 @@ pub async fn from_substrait_rex(
}
SubqueryType::Scalar(query) => {
let plan = from_substrait_rel(
- ctx,
+ state,
&(query.input.clone()).unwrap_or_default(),
extensions,
)
@@ -1790,7 +1817,7 @@ pub async fn from_substrait_rex(
PredicateOp::Exists => {
let relation = &predicate.tuples;
let plan = from_substrait_rel(
- ctx,
+ state,
&relation.clone().unwrap_or_default(),
extensions,
)
@@ -2772,7 +2799,7 @@ fn from_substrait_null(
#[allow(deprecated)]
async fn from_substrait_grouping(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
grouping: &Grouping,
expressions: &[Expr],
input_schema: &DFSchemaRef,
@@ -2781,7 +2808,7 @@ async fn from_substrait_grouping(
let mut group_exprs = vec![];
if !grouping.grouping_expressions.is_empty() {
for e in &grouping.grouping_expressions {
- let expr = from_substrait_rex(ctx, e, input_schema,
extensions).await?;
+ let expr = from_substrait_rex(state, e, input_schema,
extensions).await?;
group_exprs.push(expr);
}
return Ok(group_exprs);
@@ -2834,23 +2861,29 @@ impl BuiltinExprBuilder {
pub async fn build(
self,
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
f: &ScalarFunction,
input_schema: &DFSchema,
extensions: &Extensions,
) -> Result<Expr> {
match self.expr_name.as_str() {
"like" => {
- Self::build_like_expr(ctx, false, f, input_schema,
extensions).await
+ Self::build_like_expr(state, false, f, input_schema,
extensions).await
}
"ilike" => {
- Self::build_like_expr(ctx, true, f, input_schema,
extensions).await
+ Self::build_like_expr(state, true, f, input_schema,
extensions).await
}
"not" | "negative" | "negate" | "is_null" | "is_not_null" |
"is_true"
| "is_false" | "is_not_true" | "is_not_false" | "is_unknown"
| "is_not_unknown" => {
- Self::build_unary_expr(ctx, &self.expr_name, f, input_schema,
extensions)
- .await
+ Self::build_unary_expr(
+ state,
+ &self.expr_name,
+ f,
+ input_schema,
+ extensions,
+ )
+ .await
}
_ => {
not_impl_err!("Unsupported builtin expression: {}",
self.expr_name)
@@ -2859,7 +2892,7 @@ impl BuiltinExprBuilder {
}
async fn build_unary_expr(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
fn_name: &str,
f: &ScalarFunction,
input_schema: &DFSchema,
@@ -2872,7 +2905,7 @@ impl BuiltinExprBuilder {
return substrait_err!("Invalid arguments type for {fn_name} expr");
};
let arg =
- from_substrait_rex(ctx, expr_substrait, input_schema,
extensions).await?;
+ from_substrait_rex(state, expr_substrait, input_schema,
extensions).await?;
let arg = Box::new(arg);
let expr = match fn_name {
@@ -2893,7 +2926,7 @@ impl BuiltinExprBuilder {
}
async fn build_like_expr(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
case_insensitive: bool,
f: &ScalarFunction,
input_schema: &DFSchema,
@@ -2908,12 +2941,13 @@ impl BuiltinExprBuilder {
return substrait_err!("Invalid arguments type for `{fn_name}`
expr");
};
let expr =
- from_substrait_rex(ctx, expr_substrait, input_schema,
extensions).await?;
+ from_substrait_rex(state, expr_substrait, input_schema,
extensions).await?;
let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type
else {
return substrait_err!("Invalid arguments type for `{fn_name}`
expr");
};
let pattern =
- from_substrait_rex(ctx, pattern_substrait, input_schema,
extensions).await?;
+ from_substrait_rex(state, pattern_substrait, input_schema,
extensions)
+ .await?;
// Default case: escape character is Literal(Utf8(None))
let escape_char = if f.arguments.len() == 3 {
@@ -2922,9 +2956,13 @@ impl BuiltinExprBuilder {
return substrait_err!("Invalid arguments type for `{fn_name}`
expr");
};
- let escape_char_expr =
- from_substrait_rex(ctx, escape_char_substrait, input_schema,
extensions)
- .await?;
+ let escape_char_expr = from_substrait_rex(
+ state,
+ escape_char_substrait,
+ input_schema,
+ extensions,
+ )
+ .await?;
match escape_char_expr {
Expr::Literal(ScalarValue::Utf8(escape_char_string)) => {
diff --git a/datafusion/substrait/src/logical_plan/mod.rs
b/datafusion/substrait/src/logical_plan/mod.rs
index 6f8b8e493f..9e2fa9fa49 100644
--- a/datafusion/substrait/src/logical_plan/mod.rs
+++ b/datafusion/substrait/src/logical_plan/mod.rs
@@ -17,3 +17,4 @@
pub mod consumer;
pub mod producer;
+pub mod state;
diff --git a/datafusion/substrait/src/logical_plan/producer.rs
b/datafusion/substrait/src/logical_plan/producer.rs
index 4d864e4334..29019dfd74 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -29,7 +29,7 @@ use datafusion::{
arrow::datatypes::{DataType, TimeUnit},
error::{DataFusionError, Result},
logical_expr::{WindowFrame, WindowFrameBound},
- prelude::{JoinType, SessionContext},
+ prelude::JoinType,
scalar::ScalarValue,
};
@@ -100,8 +100,13 @@ use substrait::{
version,
};
+use super::state::SubstraitPlanningState;
+
/// Convert DataFusion LogicalPlan to Substrait Plan
-pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) ->
Result<Box<Plan>> {
+pub fn to_substrait_plan(
+ plan: &LogicalPlan,
+ state: &dyn SubstraitPlanningState,
+) -> Result<Box<Plan>> {
let mut extensions = Extensions::default();
// Parse relation nodes
// Generate PlanRel(s)
@@ -113,7 +118,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx:
&SessionContext) -> Result<Box
let plan_rels = vec![PlanRel {
rel_type: Some(plan_rel::RelType::Root(RelRoot {
- input: Some(*to_substrait_rel(&plan, ctx, &mut extensions)?),
+ input: Some(*to_substrait_rel(&plan, state, &mut extensions)?),
names: to_substrait_named_struct(plan.schema())?.names,
})),
}];
@@ -144,7 +149,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx:
&SessionContext) -> Result<Box
pub fn to_substrait_extended_expr(
exprs: &[(&Expr, &Field)],
schema: &DFSchemaRef,
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
) -> Result<Box<ExtendedExpression>> {
let mut extensions = Extensions::default();
@@ -152,7 +157,7 @@ pub fn to_substrait_extended_expr(
.iter()
.map(|(expr, field)| {
let substrait_expr = to_substrait_rex(
- ctx,
+ state,
expr,
schema,
/*col_ref_offset=*/ 0,
@@ -183,7 +188,7 @@ pub fn to_substrait_extended_expr(
#[allow(deprecated)]
pub fn to_substrait_rel(
plan: &LogicalPlan,
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
extensions: &mut Extensions,
) -> Result<Box<Rel>> {
match plan {
@@ -284,7 +289,7 @@ pub fn to_substrait_rel(
let expressions = p
.expr
.iter()
- .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0,
extensions))
+ .map(|e| to_substrait_rex(state, e, p.input.schema(), 0,
extensions))
.collect::<Result<Vec<_>>>()?;
let emit_kind = create_project_remapping(
@@ -300,16 +305,16 @@ pub fn to_substrait_rel(
Ok(Box::new(Rel {
rel_type: Some(RelType::Project(Box::new(ProjectRel {
common: Some(common),
- input: Some(to_substrait_rel(p.input.as_ref(), ctx,
extensions)?),
+ input: Some(to_substrait_rel(p.input.as_ref(), state,
extensions)?),
expressions,
advanced_extension: None,
}))),
}))
}
LogicalPlan::Filter(filter) => {
- let input = to_substrait_rel(filter.input.as_ref(), ctx,
extensions)?;
+ let input = to_substrait_rel(filter.input.as_ref(), state,
extensions)?;
let filter_expr = to_substrait_rex(
- ctx,
+ state,
&filter.predicate,
filter.input.schema(),
0,
@@ -325,7 +330,7 @@ pub fn to_substrait_rel(
}))
}
LogicalPlan::Limit(limit) => {
- let input = to_substrait_rel(limit.input.as_ref(), ctx,
extensions)?;
+ let input = to_substrait_rel(limit.input.as_ref(), state,
extensions)?;
let FetchType::Literal(fetch) = limit.get_fetch_type()? else {
return not_impl_err!("Non-literal limit fetch");
};
@@ -344,11 +349,11 @@ pub fn to_substrait_rel(
}))
}
LogicalPlan::Sort(sort) => {
- let input = to_substrait_rel(sort.input.as_ref(), ctx,
extensions)?;
+ let input = to_substrait_rel(sort.input.as_ref(), state,
extensions)?;
let sort_fields = sort
.expr
.iter()
- .map(|e| substrait_sort_field(ctx, e, sort.input.schema(),
extensions))
+ .map(|e| substrait_sort_field(state, e, sort.input.schema(),
extensions))
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(Rel {
rel_type: Some(RelType::Sort(Box::new(SortRel {
@@ -360,9 +365,9 @@ pub fn to_substrait_rel(
}))
}
LogicalPlan::Aggregate(agg) => {
- let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?;
+ let input = to_substrait_rel(agg.input.as_ref(), state,
extensions)?;
let (grouping_expressions, groupings) = to_substrait_groupings(
- ctx,
+ state,
&agg.group_expr,
agg.input.schema(),
extensions,
@@ -370,7 +375,9 @@ pub fn to_substrait_rel(
let measures = agg
.aggr_expr
.iter()
- .map(|e| to_substrait_agg_measure(ctx, e, agg.input.schema(),
extensions))
+ .map(|e| {
+ to_substrait_agg_measure(state, e, agg.input.schema(),
extensions)
+ })
.collect::<Result<Vec<_>>>()?;
Ok(Box::new(Rel {
@@ -386,7 +393,7 @@ pub fn to_substrait_rel(
}
LogicalPlan::Distinct(Distinct::All(plan)) => {
// Use Substrait's AggregateRel with empty measures to represent
`select distinct`
- let input = to_substrait_rel(plan.as_ref(), ctx, extensions)?;
+ let input = to_substrait_rel(plan.as_ref(), state, extensions)?;
// Get grouping keys from the input relation's number of output
fields
let grouping = (0..plan.schema().fields().len())
.map(substrait_field_ref)
@@ -407,8 +414,8 @@ pub fn to_substrait_rel(
}))
}
LogicalPlan::Join(join) => {
- let left = to_substrait_rel(join.left.as_ref(), ctx, extensions)?;
- let right = to_substrait_rel(join.right.as_ref(), ctx,
extensions)?;
+ let left = to_substrait_rel(join.left.as_ref(), state,
extensions)?;
+ let right = to_substrait_rel(join.right.as_ref(), state,
extensions)?;
let join_type = to_substrait_jointype(join.join_type);
// we only support basic joins so return an error for anything not
yet supported
match join.join_constraint {
@@ -421,7 +428,7 @@ pub fn to_substrait_rel(
let in_join_schema = join.left.schema().join(join.right.schema())?;
let join_filter = match &join.filter {
Some(filter) => Some(to_substrait_rex(
- ctx,
+ state,
filter,
&Arc::new(in_join_schema),
0,
@@ -438,7 +445,7 @@ pub fn to_substrait_rel(
Operator::Eq
};
let join_on = to_substrait_join_expr(
- ctx,
+ state,
&join.on,
eq_op,
join.left.schema(),
@@ -479,13 +486,13 @@ pub fn to_substrait_rel(
LogicalPlan::SubqueryAlias(alias) => {
// Do nothing if encounters SubqueryAlias
// since there is no corresponding relation type in Substrait
- to_substrait_rel(alias.input.as_ref(), ctx, extensions)
+ to_substrait_rel(alias.input.as_ref(), state, extensions)
}
LogicalPlan::Union(union) => {
let input_rels = union
.inputs
.iter()
- .map(|input| to_substrait_rel(input.as_ref(), ctx, extensions))
+ .map(|input| to_substrait_rel(input.as_ref(), state,
extensions))
.collect::<Result<Vec<_>>>()?
.into_iter()
.map(|ptr| *ptr)
@@ -500,7 +507,7 @@ pub fn to_substrait_rel(
}))
}
LogicalPlan::Window(window) => {
- let input = to_substrait_rel(window.input.as_ref(), ctx,
extensions)?;
+ let input = to_substrait_rel(window.input.as_ref(), state,
extensions)?;
// create a field reference for each input field
let mut expressions = (0..window.input.schema().fields().len())
@@ -510,7 +517,7 @@ pub fn to_substrait_rel(
// process and add each window function expression
for expr in &window.window_expr {
expressions.push(to_substrait_rex(
- ctx,
+ state,
expr,
window.input.schema(),
0,
@@ -539,7 +546,7 @@ pub fn to_substrait_rel(
}))
}
LogicalPlan::Repartition(repartition) => {
- let input = to_substrait_rel(repartition.input.as_ref(), ctx,
extensions)?;
+ let input = to_substrait_rel(repartition.input.as_ref(), state,
extensions)?;
let partition_count = match repartition.partitioning_scheme {
Partitioning::RoundRobinBatch(num) => num,
Partitioning::Hash(_, num) => num,
@@ -585,8 +592,7 @@ pub fn to_substrait_rel(
}))
}
LogicalPlan::Extension(extension_plan) => {
- let extension_bytes = ctx
- .state()
+ let extension_bytes = state
.serializer_registry()
.serialize_logical_plan(extension_plan.node.as_ref())?;
let detail = ProtoAny {
@@ -597,7 +603,7 @@ pub fn to_substrait_rel(
.node
.inputs()
.into_iter()
- .map(|plan| to_substrait_rel(plan, ctx, extensions))
+ .map(|plan| to_substrait_rel(plan, state, extensions))
.collect::<Result<Vec<_>>>()?;
let rel_type = match inputs_rel.len() {
0 => RelType::ExtensionLeaf(ExtensionLeafRel {
@@ -687,7 +693,7 @@ fn to_substrait_named_struct(schema: &DFSchemaRef) ->
Result<NamedStruct> {
}
fn to_substrait_join_expr(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
join_conditions: &Vec<(Expr, Expr)>,
eq_op: Operator,
left_schema: &DFSchemaRef,
@@ -698,10 +704,10 @@ fn to_substrait_join_expr(
let mut exprs: Vec<Expression> = vec![];
for (left, right) in join_conditions {
// Parse left
- let l = to_substrait_rex(ctx, left, left_schema, 0, extensions)?;
+ let l = to_substrait_rex(state, left, left_schema, 0, extensions)?;
// Parse right
let r = to_substrait_rex(
- ctx,
+ state,
right,
right_schema,
left_schema.fields().len(), // offset to return the correct index
@@ -770,7 +776,7 @@ pub fn operator_to_name(op: Operator) -> &'static str {
#[allow(deprecated)]
pub fn parse_flat_grouping_exprs(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
exprs: &[Expr],
schema: &DFSchemaRef,
extensions: &mut Extensions,
@@ -780,7 +786,7 @@ pub fn parse_flat_grouping_exprs(
let mut grouping_expressions = vec![];
for e in exprs {
- let rex = to_substrait_rex(ctx, e, schema, 0, extensions)?;
+ let rex = to_substrait_rex(state, e, schema, 0, extensions)?;
grouping_expressions.push(rex.clone());
ref_group_exprs.push(rex);
expression_references.push((ref_group_exprs.len() - 1) as u32);
@@ -792,7 +798,7 @@ pub fn parse_flat_grouping_exprs(
}
pub fn to_substrait_groupings(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
exprs: &[Expr],
schema: &DFSchemaRef,
extensions: &mut Extensions,
@@ -808,7 +814,7 @@ pub fn to_substrait_groupings(
.iter()
.map(|set| {
parse_flat_grouping_exprs(
- ctx,
+ state,
set,
schema,
extensions,
@@ -826,7 +832,7 @@ pub fn to_substrait_groupings(
.rev()
.map(|set| {
parse_flat_grouping_exprs(
- ctx,
+ state,
set,
schema,
extensions,
@@ -837,7 +843,7 @@ pub fn to_substrait_groupings(
}
},
_ => Ok(vec![parse_flat_grouping_exprs(
- ctx,
+ state,
exprs,
schema,
extensions,
@@ -845,7 +851,7 @@ pub fn to_substrait_groupings(
)?]),
},
_ => Ok(vec![parse_flat_grouping_exprs(
- ctx,
+ state,
exprs,
schema,
extensions,
@@ -857,7 +863,7 @@ pub fn to_substrait_groupings(
#[allow(deprecated)]
pub fn to_substrait_agg_measure(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
expr: &Expr,
schema: &DFSchemaRef,
extensions: &mut Extensions,
@@ -865,13 +871,13 @@ pub fn to_substrait_agg_measure(
match expr {
Expr::AggregateFunction(expr::AggregateFunction { func, args,
distinct, filter, order_by, null_treatment: _, }) => {
let sorts = if let Some(order_by) = order_by {
- order_by.iter().map(|expr|
to_substrait_sort_field(ctx, expr, schema,
extensions)).collect::<Result<Vec<_>>>()?
+ order_by.iter().map(|expr|
to_substrait_sort_field(state, expr, schema,
extensions)).collect::<Result<Vec<_>>>()?
} else {
vec![]
};
let mut arguments: Vec<FunctionArgument> = vec![];
for arg in args {
- arguments.push(FunctionArgument { arg_type:
Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extensions)?)) });
+ arguments.push(FunctionArgument { arg_type:
Some(ArgType::Value(to_substrait_rex(state, arg, schema, 0, extensions)?)) });
}
let function_anchor =
extensions.register_function(func.name().to_string());
Ok(Measure {
@@ -889,14 +895,14 @@ pub fn to_substrait_agg_measure(
options: vec![],
}),
filter: match filter {
- Some(f) => Some(to_substrait_rex(ctx, f, schema,
0, extensions)?),
+ Some(f) => Some(to_substrait_rex(state, f, schema,
0, extensions)?),
None => None
}
})
}
Expr::Alias(Alias{expr,..})=> {
- to_substrait_agg_measure(ctx, expr, schema, extensions)
+ to_substrait_agg_measure(state, expr, schema, extensions)
}
_ => internal_err!(
"Expression must be compatible with aggregation. Unsupported
expression: {:?}. ExpressionType: {:?}",
@@ -908,7 +914,7 @@ pub fn to_substrait_agg_measure(
/// Converts sort expression to corresponding substrait `SortField`
fn to_substrait_sort_field(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
sort: &Sort,
schema: &DFSchemaRef,
extensions: &mut Extensions,
@@ -920,7 +926,7 @@ fn to_substrait_sort_field(
(false, false) => SortDirection::DescNullsLast,
};
Ok(SortField {
- expr: Some(to_substrait_rex(ctx, &sort.expr, schema, 0, extensions)?),
+ expr: Some(to_substrait_rex(state, &sort.expr, schema, 0,
extensions)?),
sort_kind: Some(SortKind::Direction(sort_kind.into())),
})
}
@@ -977,7 +983,7 @@ pub fn make_binary_op_scalar_func(
/// * `extensions` - Substrait extension info. Contains registered function
information
#[allow(deprecated)]
pub fn to_substrait_rex(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
expr: &Expr,
schema: &DFSchemaRef,
col_ref_offset: usize,
@@ -991,10 +997,10 @@ pub fn to_substrait_rex(
}) => {
let substrait_list = list
.iter()
- .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset,
extensions))
+ .map(|x| to_substrait_rex(state, x, schema, col_ref_offset,
extensions))
.collect::<Result<Vec<Expression>>>()?;
let substrait_expr =
- to_substrait_rex(ctx, expr, schema, col_ref_offset,
extensions)?;
+ to_substrait_rex(state, expr, schema, col_ref_offset,
extensions)?;
let substrait_or_list = Expression {
rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList
{
@@ -1026,7 +1032,7 @@ pub fn to_substrait_rex(
for arg in &fun.args {
arguments.push(FunctionArgument {
arg_type: Some(ArgType::Value(to_substrait_rex(
- ctx,
+ state,
arg,
schema,
col_ref_offset,
@@ -1055,11 +1061,11 @@ pub fn to_substrait_rex(
if *negated {
// `expr NOT BETWEEN low AND high` can be translated into
(expr < low OR high < expr)
let substrait_expr =
- to_substrait_rex(ctx, expr, schema, col_ref_offset,
extensions)?;
+ to_substrait_rex(state, expr, schema, col_ref_offset,
extensions)?;
let substrait_low =
- to_substrait_rex(ctx, low, schema, col_ref_offset,
extensions)?;
+ to_substrait_rex(state, low, schema, col_ref_offset,
extensions)?;
let substrait_high =
- to_substrait_rex(ctx, high, schema, col_ref_offset,
extensions)?;
+ to_substrait_rex(state, high, schema, col_ref_offset,
extensions)?;
let l_expr = make_binary_op_scalar_func(
&substrait_expr,
@@ -1083,11 +1089,11 @@ pub fn to_substrait_rex(
} else {
// `expr BETWEEN low AND high` can be translated into (low <=
expr AND expr <= high)
let substrait_expr =
- to_substrait_rex(ctx, expr, schema, col_ref_offset,
extensions)?;
+ to_substrait_rex(state, expr, schema, col_ref_offset,
extensions)?;
let substrait_low =
- to_substrait_rex(ctx, low, schema, col_ref_offset,
extensions)?;
+ to_substrait_rex(state, low, schema, col_ref_offset,
extensions)?;
let substrait_high =
- to_substrait_rex(ctx, high, schema, col_ref_offset,
extensions)?;
+ to_substrait_rex(state, high, schema, col_ref_offset,
extensions)?;
let l_expr = make_binary_op_scalar_func(
&substrait_low,
@@ -1115,8 +1121,8 @@ pub fn to_substrait_rex(
substrait_field_ref(index + col_ref_offset)
}
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
- let l = to_substrait_rex(ctx, left, schema, col_ref_offset,
extensions)?;
- let r = to_substrait_rex(ctx, right, schema, col_ref_offset,
extensions)?;
+ let l = to_substrait_rex(state, left, schema, col_ref_offset,
extensions)?;
+ let r = to_substrait_rex(state, right, schema, col_ref_offset,
extensions)?;
Ok(make_binary_op_scalar_func(&l, &r, *op, extensions))
}
@@ -1131,7 +1137,7 @@ pub fn to_substrait_rex(
// Base expression exists
ifs.push(IfClause {
r#if: Some(to_substrait_rex(
- ctx,
+ state,
e,
schema,
col_ref_offset,
@@ -1144,14 +1150,14 @@ pub fn to_substrait_rex(
for (r#if, then) in when_then_expr {
ifs.push(IfClause {
r#if: Some(to_substrait_rex(
- ctx,
+ state,
r#if,
schema,
col_ref_offset,
extensions,
)?),
then: Some(to_substrait_rex(
- ctx,
+ state,
then,
schema,
col_ref_offset,
@@ -1163,7 +1169,7 @@ pub fn to_substrait_rex(
// Parse outer `else`
let r#else: Option<Box<Expression>> = match else_expr {
Some(e) => Some(Box::new(to_substrait_rex(
- ctx,
+ state,
e,
schema,
col_ref_offset,
@@ -1182,7 +1188,7 @@ pub fn to_substrait_rex(
substrait::proto::expression::Cast {
r#type: Some(to_substrait_type(data_type, true)?),
input: Some(Box::new(to_substrait_rex(
- ctx,
+ state,
expr,
schema,
col_ref_offset,
@@ -1195,7 +1201,7 @@ pub fn to_substrait_rex(
}
Expr::Literal(value) => to_substrait_literal_expr(value, extensions),
Expr::Alias(Alias { expr, .. }) => {
- to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)
+ to_substrait_rex(state, expr, schema, col_ref_offset, extensions)
}
Expr::WindowFunction(WindowFunction {
fun,
@@ -1212,7 +1218,7 @@ pub fn to_substrait_rex(
for arg in args {
arguments.push(FunctionArgument {
arg_type: Some(ArgType::Value(to_substrait_rex(
- ctx,
+ state,
arg,
schema,
col_ref_offset,
@@ -1223,12 +1229,12 @@ pub fn to_substrait_rex(
// partition by expressions
let partition_by = partition_by
.iter()
- .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset,
extensions))
+ .map(|e| to_substrait_rex(state, e, schema, col_ref_offset,
extensions))
.collect::<Result<Vec<_>>>()?;
// order by expressions
let order_by = order_by
.iter()
- .map(|e| substrait_sort_field(ctx, e, schema, extensions))
+ .map(|e| substrait_sort_field(state, e, schema, extensions))
.collect::<Result<Vec<_>>>()?;
// window frame
let bounds = to_substrait_bounds(window_frame)?;
@@ -1249,7 +1255,7 @@ pub fn to_substrait_rex(
escape_char,
case_insensitive,
}) => make_substrait_like_expr(
- ctx,
+ state,
*case_insensitive,
*negated,
expr,
@@ -1265,10 +1271,10 @@ pub fn to_substrait_rex(
negated,
}) => {
let substrait_expr =
- to_substrait_rex(ctx, expr, schema, col_ref_offset,
extensions)?;
+ to_substrait_rex(state, expr, schema, col_ref_offset,
extensions)?;
let subquery_plan =
- to_substrait_rel(subquery.subquery.as_ref(), ctx, extensions)?;
+ to_substrait_rel(subquery.subquery.as_ref(), state,
extensions)?;
let substrait_subquery = Expression {
rex_type: Some(RexType::Subquery(Box::new(Subquery {
@@ -1301,7 +1307,7 @@ pub fn to_substrait_rex(
}
}
Expr::Not(arg) => to_substrait_unary_scalar_fn(
- ctx,
+ state,
"not",
arg,
schema,
@@ -1309,7 +1315,7 @@ pub fn to_substrait_rex(
extensions,
),
Expr::IsNull(arg) => to_substrait_unary_scalar_fn(
- ctx,
+ state,
"is_null",
arg,
schema,
@@ -1317,7 +1323,7 @@ pub fn to_substrait_rex(
extensions,
),
Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn(
- ctx,
+ state,
"is_not_null",
arg,
schema,
@@ -1325,7 +1331,7 @@ pub fn to_substrait_rex(
extensions,
),
Expr::IsTrue(arg) => to_substrait_unary_scalar_fn(
- ctx,
+ state,
"is_true",
arg,
schema,
@@ -1333,7 +1339,7 @@ pub fn to_substrait_rex(
extensions,
),
Expr::IsFalse(arg) => to_substrait_unary_scalar_fn(
- ctx,
+ state,
"is_false",
arg,
schema,
@@ -1341,7 +1347,7 @@ pub fn to_substrait_rex(
extensions,
),
Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn(
- ctx,
+ state,
"is_unknown",
arg,
schema,
@@ -1349,7 +1355,7 @@ pub fn to_substrait_rex(
extensions,
),
Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn(
- ctx,
+ state,
"is_not_true",
arg,
schema,
@@ -1357,7 +1363,7 @@ pub fn to_substrait_rex(
extensions,
),
Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn(
- ctx,
+ state,
"is_not_false",
arg,
schema,
@@ -1365,7 +1371,7 @@ pub fn to_substrait_rex(
extensions,
),
Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn(
- ctx,
+ state,
"is_not_unknown",
arg,
schema,
@@ -1373,7 +1379,7 @@ pub fn to_substrait_rex(
extensions,
),
Expr::Negative(arg) => to_substrait_unary_scalar_fn(
- ctx,
+ state,
"negate",
arg,
schema,
@@ -1674,7 +1680,7 @@ fn make_substrait_window_function(
#[allow(deprecated)]
#[allow(clippy::too_many_arguments)]
fn make_substrait_like_expr(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
ignore_case: bool,
negated: bool,
expr: &Expr,
@@ -1689,8 +1695,8 @@ fn make_substrait_like_expr(
} else {
extensions.register_function("like".to_string())
};
- let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset,
extensions)?;
- let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset,
extensions)?;
+ let expr = to_substrait_rex(state, expr, schema, col_ref_offset,
extensions)?;
+ let pattern = to_substrait_rex(state, pattern, schema, col_ref_offset,
extensions)?;
let escape_char = to_substrait_literal_expr(
&ScalarValue::Utf8(escape_char.map(|c| c.to_string())),
extensions,
@@ -2088,7 +2094,7 @@ fn to_substrait_literal_expr(
/// Util to generate substrait [RexType::ScalarFunction] with one argument
fn to_substrait_unary_scalar_fn(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
fn_name: &str,
arg: &Expr,
schema: &DFSchemaRef,
@@ -2096,7 +2102,8 @@ fn to_substrait_unary_scalar_fn(
extensions: &mut Extensions,
) -> Result<Expression> {
let function_anchor = extensions.register_function(fn_name.to_string());
- let substrait_expr = to_substrait_rex(ctx, arg, schema, col_ref_offset,
extensions)?;
+ let substrait_expr =
+ to_substrait_rex(state, arg, schema, col_ref_offset, extensions)?;
Ok(Expression {
rex_type: Some(RexType::ScalarFunction(ScalarFunction {
@@ -2137,7 +2144,7 @@ fn try_to_substrait_field_reference(
}
fn substrait_sort_field(
- ctx: &SessionContext,
+ state: &dyn SubstraitPlanningState,
sort: &Sort,
schema: &DFSchemaRef,
extensions: &mut Extensions,
@@ -2147,7 +2154,7 @@ fn substrait_sort_field(
asc,
nulls_first,
} = sort;
- let e = to_substrait_rex(ctx, expr, schema, 0, extensions)?;
+ let e = to_substrait_rex(state, expr, schema, 0, extensions)?;
let d = match (asc, nulls_first) {
(true, true) => SortDirection::AscNullsFirst,
(true, false) => SortDirection::AscNullsLast,
@@ -2190,6 +2197,7 @@ mod test {
use datafusion::arrow::datatypes::{Field, Fields, Schema};
use datafusion::common::scalar::ScalarStructBuilder;
use datafusion::common::DFSchema;
+ use datafusion::execution::SessionStateBuilder;
#[test]
fn round_trip_literals() -> Result<()> {
@@ -2433,15 +2441,15 @@ mod test {
#[tokio::test]
async fn extended_expressions() -> Result<()> {
- let ctx = SessionContext::new();
+ let state = SessionStateBuilder::default().build();
// One expression, empty input schema
let expr = Expr::Literal(ScalarValue::Int32(Some(42)));
let field = Field::new("out", DataType::Int32, false);
let empty_schema = DFSchemaRef::new(DFSchema::empty());
let substrait =
- to_substrait_extended_expr(&[(&expr, &field)], &empty_schema,
&ctx)?;
- let roundtrip_expr = from_substrait_extended_expr(&ctx,
&substrait).await?;
+ to_substrait_extended_expr(&[(&expr, &field)], &empty_schema,
&state)?;
+ let roundtrip_expr = from_substrait_extended_expr(&state,
&substrait).await?;
assert_eq!(roundtrip_expr.input_schema, empty_schema);
assert_eq!(roundtrip_expr.exprs.len(), 1);
@@ -2463,9 +2471,9 @@ mod test {
let substrait = to_substrait_extended_expr(
&[(&expr1, &out1), (&expr2, &out2)],
&input_schema,
- &ctx,
+ &state,
)?;
- let roundtrip_expr = from_substrait_extended_expr(&ctx,
&substrait).await?;
+ let roundtrip_expr = from_substrait_extended_expr(&state,
&substrait).await?;
assert_eq!(roundtrip_expr.input_schema, input_schema);
assert_eq!(roundtrip_expr.exprs.len(), 2);
@@ -2485,14 +2493,14 @@ mod test {
#[tokio::test]
async fn invalid_extended_expression() {
- let ctx = SessionContext::new();
+ let state = SessionStateBuilder::default().build();
// Not ok if input schema is missing field referenced by expr
let expr = Expr::Column("missing".into());
let field = Field::new("out", DataType::Int32, false);
let empty_schema = DFSchemaRef::new(DFSchema::empty());
- let err = to_substrait_extended_expr(&[(&expr, &field)],
&empty_schema, &ctx);
+ let err = to_substrait_extended_expr(&[(&expr, &field)],
&empty_schema, &state);
assert!(matches!(err, Err(DataFusionError::SchemaError(_, _))));
}
diff --git a/datafusion/substrait/src/logical_plan/state.rs
b/datafusion/substrait/src/logical_plan/state.rs
new file mode 100644
index 0000000000..0bd749c110
--- /dev/null
+++ b/datafusion/substrait/src/logical_plan/state.rs
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use datafusion::{
+ catalog::TableProvider,
+ error::{DataFusionError, Result},
+ execution::{registry::SerializerRegistry, FunctionRegistry, SessionState},
+ sql::TableReference,
+};
+
+/// This trait provides the context needed to transform a substrait plan into a
+/// [`datafusion::logical_expr::LogicalPlan`] (via
[`super::consumer::from_substrait_plan`])
+/// and back again into a substrait plan (via
[`super::producer::to_substrait_plan`]).
+///
+/// The context is declared as a trait to decouple the substrait plan encoder /
+/// decoder from the [`SessionState`], potentially allowing users to define
+/// their own slimmer context just for serializing and deserializing substrait.
+///
+/// [`SessionState`] implements this trait.
+#[async_trait]
+pub trait SubstraitPlanningState: Sync + Send + FunctionRegistry {
+ /// Return [SerializerRegistry] for extensions
+ fn serializer_registry(&self) -> &Arc<dyn SerializerRegistry>;
+
+ async fn table(
+ &self,
+ reference: &TableReference,
+ ) -> Result<Option<Arc<dyn TableProvider>>>;
+}
+
+#[async_trait]
+impl SubstraitPlanningState for SessionState {
+ fn serializer_registry(&self) -> &Arc<dyn SerializerRegistry> {
+ self.serializer_registry()
+ }
+
+ async fn table(
+ &self,
+ reference: &TableReference,
+ ) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
+ let table = reference.table().to_string();
+ let schema = self.schema_for_ref(reference.clone())?;
+ let table_provider = schema.table(&table).await?;
+ Ok(table_provider)
+ }
+}
diff --git a/datafusion/substrait/src/serializer.rs
b/datafusion/substrait/src/serializer.rs
index 6b81e33dfc..4278671777 100644
--- a/datafusion/substrait/src/serializer.rs
+++ b/datafusion/substrait/src/serializer.rs
@@ -38,7 +38,7 @@ pub async fn serialize(sql: &str, ctx: &SessionContext, path:
&str) -> Result<()
pub async fn serialize_bytes(sql: &str, ctx: &SessionContext) ->
Result<Vec<u8>> {
let df = ctx.sql(sql).await?;
let plan = df.into_optimized_plan()?;
- let proto = producer::to_substrait_plan(&plan, ctx)?;
+ let proto = producer::to_substrait_plan(&plan, &ctx.state())?;
let mut protobuf_out = Vec::<u8>::new();
proto.encode(&mut protobuf_out).map_err(|e| {
diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs
b/datafusion/substrait/tests/cases/consumer_integration.rs
index bc38ef8297..219f656bb4 100644
--- a/datafusion/substrait/tests/cases/consumer_integration.rs
+++ b/datafusion/substrait/tests/cases/consumer_integration.rs
@@ -41,7 +41,7 @@ mod tests {
.expect("failed to parse json");
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?;
- let plan = from_substrait_plan(&ctx, &proto).await?;
+ let plan = from_substrait_plan(&ctx.state(), &proto).await?;
Ok(format!("{}", plan))
}
diff --git a/datafusion/substrait/tests/cases/emit_kind_tests.rs
b/datafusion/substrait/tests/cases/emit_kind_tests.rs
index ac66177ed7..08537d0d11 100644
--- a/datafusion/substrait/tests/cases/emit_kind_tests.rs
+++ b/datafusion/substrait/tests/cases/emit_kind_tests.rs
@@ -33,7 +33,7 @@ mod tests {
"tests/testdata/test_plans/emit_kind/direct_on_project.substrait.json",
);
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
- let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+ let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
let plan_str = format!("{}", plan);
@@ -51,7 +51,7 @@ mod tests {
"tests/testdata/test_plans/emit_kind/emit_on_filter.substrait.json",
);
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
- let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+ let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
let plan_str = format!("{}", plan);
@@ -91,8 +91,8 @@ mod tests {
\n TableScan: data"
);
- let proto = to_substrait_plan(&plan, &ctx)?;
- let plan2 = from_substrait_plan(&ctx, &proto).await?;
+ let proto = to_substrait_plan(&plan, &ctx.state())?;
+ let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
// note how the Projections are not flattened
assert_eq!(
format!("{}", plan2),
@@ -115,8 +115,8 @@ mod tests {
\n TableScan: data"
);
- let proto = to_substrait_plan(&plan, &ctx)?;
- let plan2 = from_substrait_plan(&ctx, &proto).await?;
+ let proto = to_substrait_plan(&plan, &ctx.state())?;
+ let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
let plan1str = format!("{plan}");
let plan2str = format!("{plan2}");
diff --git a/datafusion/substrait/tests/cases/function_test.rs
b/datafusion/substrait/tests/cases/function_test.rs
index b136b0af19..0438084561 100644
--- a/datafusion/substrait/tests/cases/function_test.rs
+++ b/datafusion/substrait/tests/cases/function_test.rs
@@ -29,7 +29,7 @@ mod tests {
async fn contains_function_test() -> Result<()> {
let proto_plan =
read_json("tests/testdata/contains_plan.substrait.json");
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
- let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+ let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
let plan_str = format!("{}", plan);
diff --git a/datafusion/substrait/tests/cases/logical_plans.rs
b/datafusion/substrait/tests/cases/logical_plans.rs
index f4e34af35d..65f404bbda 100644
--- a/datafusion/substrait/tests/cases/logical_plans.rs
+++ b/datafusion/substrait/tests/cases/logical_plans.rs
@@ -38,7 +38,7 @@ mod tests {
let proto_plan =
read_json("tests/testdata/test_plans/select_not_bool.substrait.json");
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
- let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+ let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
assert_eq!(
format!("{}", plan),
@@ -63,7 +63,7 @@ mod tests {
let proto_plan =
read_json("tests/testdata/test_plans/select_window.substrait.json");
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
- let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+ let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
assert_eq!(
format!("{}", plan),
@@ -82,7 +82,7 @@ mod tests {
let proto_plan =
read_json("tests/testdata/test_plans/non_nullable_lists.substrait.json");
let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
- let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+ let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
assert_eq!(format!("{}", &plan), "Values: (List([1, 2]))");
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index d4e2d48885..d03ab51820 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -979,8 +979,8 @@ async fn extension_logical_plan() -> Result<()> {
}),
});
- let proto = to_substrait_plan(&ext_plan, &ctx)?;
- let plan2 = from_substrait_plan(&ctx, &proto).await?;
+ let proto = to_substrait_plan(&ext_plan, &ctx.state())?;
+ let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
let plan1str = format!("{ext_plan}");
let plan2str = format!("{plan2}");
@@ -1081,8 +1081,8 @@ async fn roundtrip_repartition_roundrobin() -> Result<()>
{
partitioning_scheme: Partitioning::RoundRobinBatch(8),
});
- let proto = to_substrait_plan(&plan, &ctx)?;
- let plan2 = from_substrait_plan(&ctx, &proto).await?;
+ let proto = to_substrait_plan(&plan, &ctx.state())?;
+ let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;
assert_eq!(format!("{plan}"), format!("{plan2}"));
@@ -1098,8 +1098,8 @@ async fn roundtrip_repartition_hash() -> Result<()> {
partitioning_scheme: Partitioning::Hash(vec![col("data.a")], 8),
});
- let proto = to_substrait_plan(&plan, &ctx)?;
- let plan2 = from_substrait_plan(&ctx, &proto).await?;
+ let proto = to_substrait_plan(&plan, &ctx.state())?;
+ let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;
assert_eq!(format!("{plan}"), format!("{plan2}"));
@@ -1199,8 +1199,8 @@ async fn assert_expected_plan_unoptimized(
let ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan = df.into_unoptimized_plan();
- let proto = to_substrait_plan(&plan, &ctx)?;
- let plan2 = from_substrait_plan(&ctx, &proto).await?;
+ let proto = to_substrait_plan(&plan, &ctx.state())?;
+ let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
println!("{plan}");
println!("{plan2}");
@@ -1225,8 +1225,8 @@ async fn assert_expected_plan(
let ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan = df.into_optimized_plan()?;
- let proto = to_substrait_plan(&plan, &ctx)?;
- let plan2 = from_substrait_plan(&ctx, &proto).await?;
+ let proto = to_substrait_plan(&plan, &ctx.state())?;
+ let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;
println!("{plan}");
@@ -1250,7 +1250,7 @@ async fn assert_expected_plan_substrait(
) -> Result<()> {
let ctx = create_context().await?;
- let plan = from_substrait_plan(&ctx, &substrait_plan).await?;
+ let plan = from_substrait_plan(&ctx.state(), &substrait_plan).await?;
let plan = ctx.state().optimize(&plan)?;
@@ -1265,7 +1265,7 @@ async fn assert_substrait_sql(substrait_plan: Plan, sql:
&str) -> Result<()> {
let expected = ctx.sql(sql).await?.into_optimized_plan()?;
- let plan = from_substrait_plan(&ctx, &substrait_plan).await?;
+ let plan = from_substrait_plan(&ctx.state(), &substrait_plan).await?;
let plan = ctx.state().optimize(&plan)?;
@@ -1280,8 +1280,8 @@ async fn roundtrip_fill_na(sql: &str) -> Result<()> {
let ctx = create_context().await?;
let df = ctx.sql(sql).await?;
let plan = df.into_optimized_plan()?;
- let proto = to_substrait_plan(&plan, &ctx)?;
- let plan2 = from_substrait_plan(&ctx, &proto).await?;
+ let proto = to_substrait_plan(&plan, &ctx.state())?;
+ let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;
// Format plan string and replace all None's with 0
@@ -1301,12 +1301,12 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias:
&str) -> Result<()> {
let ctx = create_context().await?;
let df_a = ctx.sql(sql_with_alias).await?;
- let proto_a = to_substrait_plan(&df_a.into_optimized_plan()?, &ctx)?;
- let plan_with_alias = from_substrait_plan(&ctx, &proto_a).await?;
+ let proto_a = to_substrait_plan(&df_a.into_optimized_plan()?,
&ctx.state())?;
+ let plan_with_alias = from_substrait_plan(&ctx.state(), &proto_a).await?;
let df = ctx.sql(sql_no_alias).await?;
- let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx)?;
- let plan = from_substrait_plan(&ctx, &proto).await?;
+ let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx.state())?;
+ let plan = from_substrait_plan(&ctx.state(), &proto).await?;
println!("{plan_with_alias}");
println!("{plan}");
@@ -1323,8 +1323,8 @@ async fn roundtrip_logical_plan_with_ctx(
plan: LogicalPlan,
ctx: SessionContext,
) -> Result<Box<Plan>> {
- let proto = to_substrait_plan(&plan, &ctx)?;
- let plan2 = from_substrait_plan(&ctx, &proto).await?;
+ let proto = to_substrait_plan(&plan, &ctx.state())?;
+ let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
let plan2 = ctx.state().optimize(&plan2)?;
println!("{plan}");
diff --git a/datafusion/substrait/tests/cases/serialize.rs
b/datafusion/substrait/tests/cases/serialize.rs
index 54d55d1b6f..e28c633127 100644
--- a/datafusion/substrait/tests/cases/serialize.rs
+++ b/datafusion/substrait/tests/cases/serialize.rs
@@ -45,7 +45,7 @@ mod tests {
// Read substrait plan from file
let proto = serializer::deserialize(path).await?;
// Check plan equality
- let plan = from_substrait_plan(&ctx, &proto).await?;
+ let plan = from_substrait_plan(&ctx.state(), &proto).await?;
let plan_str_ref = format!("{plan_ref}");
let plan_str = format!("{plan}");
assert_eq!(plan_str_ref, plan_str);
@@ -60,7 +60,7 @@ mod tests {
let ctx = create_context().await?;
let table = provider_as_source(ctx.table_provider("data").await?);
let table_scan = LogicalPlanBuilder::scan("data", table,
None)?.build()?;
- let convert_result = to_substrait_plan(&table_scan, &ctx);
+ let convert_result = to_substrait_plan(&table_scan, &ctx.state());
assert!(convert_result.is_ok());
Ok(())
@@ -78,7 +78,9 @@ mod tests {
\n TableScan: data projection=[a, b]",
);
- let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone();
+ let plan = to_substrait_plan(&datafusion_plan, &ctx.state())?
+ .as_ref()
+ .clone();
let relation = plan.relations.first().unwrap().rel_type.as_ref();
let root_rel = match relation {
@@ -121,7 +123,9 @@ mod tests {
\n TableScan: data projection=[a, b, c]",
);
- let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone();
+ let plan = to_substrait_plan(&datafusion_plan, &ctx.state())?
+ .as_ref()
+ .clone();
let relation = plan.relations.first().unwrap().rel_type.as_ref();
let root_rel = match relation {
diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs
b/datafusion/substrait/tests/cases/substrait_validations.rs
index 5ae586afe5..c77bf1489f 100644
--- a/datafusion/substrait/tests/cases/substrait_validations.rs
+++ b/datafusion/substrait/tests/cases/substrait_validations.rs
@@ -65,7 +65,7 @@ mod tests {
vec![("a", DataType::Int32, false), ("b", DataType::Int32,
true)];
let ctx = generate_context_with_table("DATA", df_schema)?;
- let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+ let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
assert_eq!(
format!("{}", plan),
@@ -86,7 +86,7 @@ mod tests {
("c", DataType::Int32, false),
];
let ctx = generate_context_with_table("DATA", df_schema)?;
- let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+ let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
assert_eq!(
format!("{}", plan),
@@ -109,7 +109,7 @@ mod tests {
("b", DataType::Int32, false),
];
let ctx = generate_context_with_table("DATA", df_schema)?;
- let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+ let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
assert_eq!(
format!("{}", plan),
@@ -128,7 +128,7 @@ mod tests {
vec![("a", DataType::Int32, false), ("c", DataType::Int32,
true)];
let ctx = generate_context_with_table("DATA", df_schema)?;
- let res = from_substrait_plan(&ctx, &proto_plan).await;
+ let res = from_substrait_plan(&ctx.state(), &proto_plan).await;
assert!(res.is_err());
Ok(())
}
@@ -140,7 +140,7 @@ mod tests {
let ctx =
generate_context_with_table("DATA", vec![("a",
DataType::Date32, true)])?;
- let res = from_substrait_plan(&ctx, &proto_plan).await;
+ let res = from_substrait_plan(&ctx.state(), &proto_plan).await;
assert!(res.is_err());
Ok(())
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]