This is an automated email from the ASF dual-hosted git repository.

mneumann pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 5ee524ec70 feat(substrait): replace SessionContext with a trait 
(#13343)
5ee524ec70 is described below

commit 5ee524ec70b1f10a078caca62954ce37b2dc3cc6
Author: Filippo Rossi <[email protected]>
AuthorDate: Wed Nov 20 18:11:18 2024 +0100

    feat(substrait): replace SessionContext with a trait (#13343)
    
    * feat(substrait): replace SessionContext with SessionState
    
    * feat(substrait): add logical plan context
    
    * chore(substrait): add apache header
    
    * docs: fix code in docs
    
    * docs(substrait): rename and document context
    
    * chore(substrait): context -> state
    
    * chore: fmt
---
 datafusion/core/src/execution/session_state.rs     |   4 +-
 datafusion/substrait/Cargo.toml                    |   1 +
 datafusion/substrait/src/lib.rs                    |   4 +-
 datafusion/substrait/src/logical_plan/consumer.rs  | 286 ++++++++++++---------
 datafusion/substrait/src/logical_plan/mod.rs       |   1 +
 datafusion/substrait/src/logical_plan/producer.rs  | 196 +++++++-------
 datafusion/substrait/src/logical_plan/state.rs     |  63 +++++
 datafusion/substrait/src/serializer.rs             |   2 +-
 .../substrait/tests/cases/consumer_integration.rs  |   2 +-
 .../substrait/tests/cases/emit_kind_tests.rs       |  12 +-
 datafusion/substrait/tests/cases/function_test.rs  |   2 +-
 datafusion/substrait/tests/cases/logical_plans.rs  |   6 +-
 .../tests/cases/roundtrip_logical_plan.rs          |  40 +--
 datafusion/substrait/tests/cases/serialize.rs      |  12 +-
 .../substrait/tests/cases/substrait_validations.rs |  10 +-
 15 files changed, 379 insertions(+), 262 deletions(-)

diff --git a/datafusion/core/src/execution/session_state.rs 
b/datafusion/core/src/execution/session_state.rs
index 9fc081dd53..e99cf82223 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -296,7 +296,9 @@ impl SessionState {
             .resolve(&catalog.default_catalog, &catalog.default_schema)
     }
 
-    pub(crate) fn schema_for_ref(
+    /// Retrieve the [`SchemaProvider`] for a specific [`TableReference`], if 
it
+    /// exists.
+    pub fn schema_for_ref(
         &self,
         table_ref: impl Into<TableReference>,
     ) -> datafusion_common::Result<Arc<dyn SchemaProvider>> {
diff --git a/datafusion/substrait/Cargo.toml b/datafusion/substrait/Cargo.toml
index 192fe26d6c..61cdf3e91e 100644
--- a/datafusion/substrait/Cargo.toml
+++ b/datafusion/substrait/Cargo.toml
@@ -34,6 +34,7 @@ workspace = true
 [dependencies]
 arrow-buffer = { workspace = true }
 async-recursion = "1.0"
+async-trait = { workspace = true }
 chrono = { workspace = true }
 datafusion = { workspace = true, default-features = true }
 itertools = { workspace = true }
diff --git a/datafusion/substrait/src/lib.rs b/datafusion/substrait/src/lib.rs
index a6f7c033f9..1389cac75b 100644
--- a/datafusion/substrait/src/lib.rs
+++ b/datafusion/substrait/src/lib.rs
@@ -64,10 +64,10 @@
 //!  let plan = df.into_optimized_plan()?;
 //!
 //!  // Convert the plan into a substrait (protobuf) Plan
-//!  let substrait_plan = logical_plan::producer::to_substrait_plan(&plan, 
&ctx)?;
+//!  let substrait_plan = logical_plan::producer::to_substrait_plan(&plan, 
&ctx.state())?;
 //!
 //!  // Receive a substrait protobuf from somewhere, and turn it into a 
LogicalPlan
-//!  let logical_round_trip = 
logical_plan::consumer::from_substrait_plan(&ctx, &substrait_plan).await?;
+//!  let logical_round_trip = 
logical_plan::consumer::from_substrait_plan(&ctx.state(), 
&substrait_plan).await?;
 //!  let logical_round_trip = ctx.state().optimize(&logical_round_trip)?;
 //!  assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip));
 //! # Ok(())
diff --git a/datafusion/substrait/src/logical_plan/consumer.rs 
b/datafusion/substrait/src/logical_plan/consumer.rs
index 1cce228527..77e9eb81f5 100644
--- a/datafusion/substrait/src/logical_plan/consumer.rs
+++ b/datafusion/substrait/src/logical_plan/consumer.rs
@@ -26,7 +26,7 @@ use datafusion::common::{
     not_impl_err, plan_datafusion_err, substrait_datafusion_err, 
substrait_err, DFSchema,
     DFSchemaRef,
 };
-use datafusion::execution::FunctionRegistry;
+use datafusion::datasource::provider_as_source;
 use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};
 
 use datafusion::logical_expr::{
@@ -56,7 +56,6 @@ use crate::variation_const::{
 use datafusion::arrow::array::{new_empty_array, AsArray};
 use datafusion::arrow::temporal_conversions::NANOSECONDS;
 use datafusion::common::scalar::ScalarStructBuilder;
-use datafusion::dataframe::DataFrame;
 use datafusion::logical_expr::builder::project;
 use datafusion::logical_expr::expr::InList;
 use datafusion::logical_expr::{
@@ -66,9 +65,7 @@ use datafusion::logical_expr::{
 use datafusion::prelude::JoinType;
 use datafusion::sql::TableReference;
 use datafusion::{
-    error::Result,
-    logical_expr::utils::split_conjunction,
-    prelude::{Column, SessionContext},
+    error::Result, logical_expr::utils::split_conjunction, prelude::Column,
     scalar::ScalarValue,
 };
 use std::collections::HashSet;
@@ -102,6 +99,8 @@ use substrait::proto::{
 };
 use substrait::proto::{ExtendedExpression, FunctionArgument, SortField};
 
+use super::state::SubstraitPlanningState;
+
 // Substrait PrecisionTimestampTz indicates that the timestamp is relative to 
UTC, which
 // is the same as the expectation for any non-empty timezone in DF, so any 
non-empty timezone
 // results in correct points on the timeline, and we pick UTC as a reasonable 
default.
@@ -203,15 +202,15 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality(
 
 async fn union_rels(
     rels: &[Rel],
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     extensions: &Extensions,
     is_all: bool,
 ) -> Result<LogicalPlan> {
     let mut union_builder = Ok(LogicalPlanBuilder::from(
-        from_substrait_rel(ctx, &rels[0], extensions).await?,
+        from_substrait_rel(state, &rels[0], extensions).await?,
     ));
     for input in &rels[1..] {
-        let rel_plan = from_substrait_rel(ctx, input, extensions).await?;
+        let rel_plan = from_substrait_rel(state, input, extensions).await?;
 
         union_builder = if is_all {
             union_builder?.union(rel_plan)
@@ -224,16 +223,16 @@ async fn union_rels(
 
 async fn intersect_rels(
     rels: &[Rel],
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     extensions: &Extensions,
     is_all: bool,
 ) -> Result<LogicalPlan> {
-    let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?;
+    let mut rel = from_substrait_rel(state, &rels[0], extensions).await?;
 
     for input in &rels[1..] {
         rel = LogicalPlanBuilder::intersect(
             rel,
-            from_substrait_rel(ctx, input, extensions).await?,
+            from_substrait_rel(state, input, extensions).await?,
             is_all,
         )?
     }
@@ -243,16 +242,16 @@ async fn intersect_rels(
 
 async fn except_rels(
     rels: &[Rel],
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     extensions: &Extensions,
     is_all: bool,
 ) -> Result<LogicalPlan> {
-    let mut rel = from_substrait_rel(ctx, &rels[0], extensions).await?;
+    let mut rel = from_substrait_rel(state, &rels[0], extensions).await?;
 
     for input in &rels[1..] {
         rel = LogicalPlanBuilder::except(
             rel,
-            from_substrait_rel(ctx, input, extensions).await?,
+            from_substrait_rel(state, input, extensions).await?,
             is_all,
         )?
     }
@@ -262,7 +261,7 @@ async fn except_rels(
 
 /// Convert Substrait Plan to DataFusion LogicalPlan
 pub async fn from_substrait_plan(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     plan: &Plan,
 ) -> Result<LogicalPlan> {
     // Register function extension
@@ -277,10 +276,10 @@ pub async fn from_substrait_plan(
             match plan.relations[0].rel_type.as_ref() {
                 Some(rt) => match rt {
                     plan_rel::RelType::Rel(rel) => {
-                        Ok(from_substrait_rel(ctx, rel, &extensions).await?)
+                        Ok(from_substrait_rel(state, rel, &extensions).await?)
                     },
                     plan_rel::RelType::Root(root) => {
-                        let plan = from_substrait_rel(ctx, 
root.input.as_ref().unwrap(), &extensions).await?;
+                        let plan = from_substrait_rel(state, 
root.input.as_ref().unwrap(), &extensions).await?;
                         if root.names.is_empty() {
                             // Backwards compatibility for plans missing names
                             return Ok(plan);
@@ -341,7 +340,7 @@ pub struct ExprContainer {
 /// between systems.  This is often useful for scenarios like pushdown where 
filter
 /// expressions need to be sent to remote systems.
 pub async fn from_substrait_extended_expr(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     extended_expr: &ExtendedExpression,
 ) -> Result<ExprContainer> {
     // Register function extension
@@ -370,7 +369,7 @@ pub async fn from_substrait_extended_expr(
             }
         }?;
         let expr =
-            from_substrait_rex(ctx, scalar_expr, &input_schema, 
&extensions).await?;
+            from_substrait_rex(state, scalar_expr, &input_schema, 
&extensions).await?;
         let (output_type, expected_nullability) =
             expr.data_type_and_nullable(&input_schema)?;
         let output_field = Field::new("", output_type, expected_nullability);
@@ -561,7 +560,7 @@ fn make_renamed_schema(
 #[allow(deprecated)]
 #[async_recursion]
 pub async fn from_substrait_rel(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     rel: &Rel,
     extensions: &Extensions,
 ) -> Result<LogicalPlan> {
@@ -569,7 +568,7 @@ pub async fn from_substrait_rel(
         Some(RelType::Project(p)) => {
             if let Some(input) = p.input.as_ref() {
                 let mut input = LogicalPlanBuilder::from(
-                    from_substrait_rel(ctx, input, extensions).await?,
+                    from_substrait_rel(state, input, extensions).await?,
                 );
                 let original_schema = input.schema().clone();
 
@@ -587,9 +586,13 @@ pub async fn from_substrait_rel(
 
                 let mut explicit_exprs: Vec<Expr> = vec![];
                 for expr in &p.expressions {
-                    let e =
-                        from_substrait_rex(ctx, expr, input.clone().schema(), 
extensions)
-                            .await?;
+                    let e = from_substrait_rex(
+                        state,
+                        expr,
+                        input.clone().schema(),
+                        extensions,
+                    )
+                    .await?;
                     // if the expression is WindowFunction, wrap in a Window 
relation
                     if let Expr::WindowFunction(_) = &e {
                         // Adding the same expression here and in the project 
below
@@ -617,11 +620,11 @@ pub async fn from_substrait_rel(
         Some(RelType::Filter(filter)) => {
             if let Some(input) = filter.input.as_ref() {
                 let input = LogicalPlanBuilder::from(
-                    from_substrait_rel(ctx, input, extensions).await?,
+                    from_substrait_rel(state, input, extensions).await?,
                 );
                 if let Some(condition) = filter.condition.as_ref() {
                     let expr =
-                        from_substrait_rex(ctx, condition, input.schema(), 
extensions)
+                        from_substrait_rex(state, condition, input.schema(), 
extensions)
                             .await?;
                     input.filter(expr)?.build()
                 } else {
@@ -634,7 +637,7 @@ pub async fn from_substrait_rel(
         Some(RelType::Fetch(fetch)) => {
             if let Some(input) = fetch.input.as_ref() {
                 let input = LogicalPlanBuilder::from(
-                    from_substrait_rel(ctx, input, extensions).await?,
+                    from_substrait_rel(state, input, extensions).await?,
                 );
                 let offset = fetch.offset as usize;
                 // -1 means that ALL records should be returned
@@ -651,10 +654,10 @@ pub async fn from_substrait_rel(
         Some(RelType::Sort(sort)) => {
             if let Some(input) = sort.input.as_ref() {
                 let input = LogicalPlanBuilder::from(
-                    from_substrait_rel(ctx, input, extensions).await?,
+                    from_substrait_rel(state, input, extensions).await?,
                 );
                 let sorts =
-                    from_substrait_sorts(ctx, &sort.sorts, input.schema(), 
extensions)
+                    from_substrait_sorts(state, &sort.sorts, input.schema(), 
extensions)
                         .await?;
                 input.sort(sorts)?.build()
             } else {
@@ -664,13 +667,13 @@ pub async fn from_substrait_rel(
         Some(RelType::Aggregate(agg)) => {
             if let Some(input) = agg.input.as_ref() {
                 let input = LogicalPlanBuilder::from(
-                    from_substrait_rel(ctx, input, extensions).await?,
+                    from_substrait_rel(state, input, extensions).await?,
                 );
                 let mut ref_group_exprs = vec![];
 
                 for e in &agg.grouping_expressions {
                     let x =
-                        from_substrait_rex(ctx, e, input.schema(), 
extensions).await?;
+                        from_substrait_rex(state, e, input.schema(), 
extensions).await?;
                     ref_group_exprs.push(x);
                 }
 
@@ -681,7 +684,7 @@ pub async fn from_substrait_rel(
                     1 => {
                         group_exprs.extend_from_slice(
                             &from_substrait_grouping(
-                                ctx,
+                                state,
                                 &agg.groupings[0],
                                 &ref_group_exprs,
                                 input.schema(),
@@ -694,7 +697,7 @@ pub async fn from_substrait_rel(
                         let mut grouping_sets = vec![];
                         for grouping in &agg.groupings {
                             let grouping_set = from_substrait_grouping(
-                                ctx,
+                                state,
                                 grouping,
                                 &ref_group_exprs,
                                 input.schema(),
@@ -716,7 +719,7 @@ pub async fn from_substrait_rel(
                 for m in &agg.measures {
                     let filter = match &m.filter {
                         Some(fil) => Some(Box::new(
-                            from_substrait_rex(ctx, fil, input.schema(), 
extensions)
+                            from_substrait_rex(state, fil, input.schema(), 
extensions)
                                 .await?,
                         )),
                         None => None,
@@ -739,7 +742,7 @@ pub async fn from_substrait_rel(
                             let order_by = if !f.sorts.is_empty() {
                                 Some(
                                     from_substrait_sorts(
-                                        ctx,
+                                        state,
                                         &f.sorts,
                                         input.schema(),
                                         extensions,
@@ -751,7 +754,7 @@ pub async fn from_substrait_rel(
                             };
 
                             from_substrait_agg_func(
-                                ctx,
+                                state,
                                 f,
                                 input.schema(),
                                 extensions,
@@ -780,10 +783,12 @@ pub async fn from_substrait_rel(
             }
 
             let left: LogicalPlanBuilder = LogicalPlanBuilder::from(
-                from_substrait_rel(ctx, join.left.as_ref().unwrap(), 
extensions).await?,
+                from_substrait_rel(state, join.left.as_ref().unwrap(), 
extensions)
+                    .await?,
             );
             let right = LogicalPlanBuilder::from(
-                from_substrait_rel(ctx, join.right.as_ref().unwrap(), 
extensions).await?,
+                from_substrait_rel(state, join.right.as_ref().unwrap(), 
extensions)
+                    .await?,
             );
             let (left, right) = requalify_sides_if_needed(left, right)?;
 
@@ -796,7 +801,7 @@ pub async fn from_substrait_rel(
             // Otherwise, build join with only the filter, without join keys
             match &join.expression.as_ref() {
                 Some(expr) => {
-                    let on = from_substrait_rex(ctx, expr, &in_join_schema, 
extensions)
+                    let on = from_substrait_rex(state, expr, &in_join_schema, 
extensions)
                         .await?;
                     // The join expression can contain both equal and 
non-equal ops.
                     // As of datafusion 31.0.0, the equal and non equal join 
conditions are in separate fields.
@@ -831,26 +836,44 @@ pub async fn from_substrait_rel(
         }
         Some(RelType::Cross(cross)) => {
             let left = LogicalPlanBuilder::from(
-                from_substrait_rel(ctx, cross.left.as_ref().unwrap(), 
extensions).await?,
+                from_substrait_rel(state, cross.left.as_ref().unwrap(), 
extensions)
+                    .await?,
             );
             let right = LogicalPlanBuilder::from(
-                from_substrait_rel(ctx, cross.right.as_ref().unwrap(), 
extensions)
+                from_substrait_rel(state, cross.right.as_ref().unwrap(), 
extensions)
                     .await?,
             );
             let (left, right) = requalify_sides_if_needed(left, right)?;
             left.cross_join(right.build()?)?.build()
         }
         Some(RelType::Read(read)) => {
-            fn read_with_schema(
-                df: DataFrame,
+            async fn read_with_schema(
+                state: &dyn SubstraitPlanningState,
+                table_ref: TableReference,
                 schema: DFSchema,
                 projection: &Option<MaskExpression>,
             ) -> Result<LogicalPlan> {
-                ensure_schema_compatability(df.schema().to_owned(), 
schema.clone())?;
+                let schema = schema.replace_qualifier(table_ref.clone());
+
+                let plan = {
+                    let provider = match state.table(&table_ref).await? {
+                        Some(ref provider) => Arc::clone(provider),
+                        _ => return plan_err!("No table named '{table_ref}'"),
+                    };
+
+                    LogicalPlanBuilder::scan(
+                        table_ref,
+                        provider_as_source(Arc::clone(&provider)),
+                        None,
+                    )?
+                    .build()?
+                };
+
+                ensure_schema_compatability(plan.schema(), schema.clone())?;
 
                 let schema = apply_masking(schema, projection)?;
 
-                apply_projection(df, schema)
+                apply_projection(plan, schema)
             }
 
             let named_struct = read.base_schema.as_ref().ok_or_else(|| {
@@ -879,12 +902,13 @@ pub async fn from_substrait_rel(
                         },
                     };
 
-                    let t = ctx.table(table_reference.clone()).await?;
-
-                    let substrait_schema =
-                        substrait_schema.replace_qualifier(table_reference);
-
-                    read_with_schema(t, substrait_schema, &read.projection)
+                    read_with_schema(
+                        state,
+                        table_reference,
+                        substrait_schema,
+                        &read.projection,
+                    )
+                    .await
                 }
                 Some(ReadType::VirtualTable(vt)) => {
                     if vt.values.is_empty() {
@@ -960,12 +984,14 @@ pub async fn from_substrait_rel(
                     let name = filename.unwrap();
                     // directly use unwrap here since we could determine it is 
a valid one
                     let table_reference = TableReference::Bare { table: 
name.into() };
-                    let t = ctx.table(table_reference.clone()).await?;
-
-                    let substrait_schema =
-                        substrait_schema.replace_qualifier(table_reference);
 
-                    read_with_schema(t, substrait_schema, &read.projection)
+                    read_with_schema(
+                        state,
+                        table_reference,
+                        substrait_schema,
+                        &read.projection,
+                    )
+                    .await
                 }
                 _ => {
                     not_impl_err!("Unsupported ReadType: {:?}", 
&read.as_ref().read_type)
@@ -979,31 +1005,31 @@ pub async fn from_substrait_rel(
                 } else {
                     match set_op {
                         set_rel::SetOp::UnionAll => {
-                            union_rels(&set.inputs, ctx, extensions, 
true).await
+                            union_rels(&set.inputs, state, extensions, 
true).await
                         }
                         set_rel::SetOp::UnionDistinct => {
-                            union_rels(&set.inputs, ctx, extensions, 
false).await
+                            union_rels(&set.inputs, state, extensions, 
false).await
                         }
                         set_rel::SetOp::IntersectionPrimary => {
                             LogicalPlanBuilder::intersect(
-                                from_substrait_rel(ctx, &set.inputs[0], 
extensions)
+                                from_substrait_rel(state, &set.inputs[0], 
extensions)
                                     .await?,
-                                union_rels(&set.inputs[1..], ctx, extensions, 
true)
+                                union_rels(&set.inputs[1..], state, 
extensions, true)
                                     .await?,
                                 false,
                             )
                         }
                         set_rel::SetOp::IntersectionMultiset => {
-                            intersect_rels(&set.inputs, ctx, extensions, 
false).await
+                            intersect_rels(&set.inputs, state, extensions, 
false).await
                         }
                         set_rel::SetOp::IntersectionMultisetAll => {
-                            intersect_rels(&set.inputs, ctx, extensions, 
true).await
+                            intersect_rels(&set.inputs, state, extensions, 
true).await
                         }
                         set_rel::SetOp::MinusPrimary => {
-                            except_rels(&set.inputs, ctx, extensions, 
false).await
+                            except_rels(&set.inputs, state, extensions, 
false).await
                         }
                         set_rel::SetOp::MinusPrimaryAll => {
-                            except_rels(&set.inputs, ctx, extensions, 
true).await
+                            except_rels(&set.inputs, state, extensions, 
true).await
                         }
                         _ => not_impl_err!("Unsupported set operator: 
{set_op:?}"),
                     }
@@ -1015,8 +1041,7 @@ pub async fn from_substrait_rel(
             let Some(ext_detail) = &extension.detail else {
                 return substrait_err!("Unexpected empty detail in 
ExtensionLeafRel");
             };
-            let plan = ctx
-                .state()
+            let plan = state
                 .serializer_registry()
                 .deserialize_logical_plan(&ext_detail.type_url, 
&ext_detail.value)?;
             Ok(LogicalPlan::Extension(Extension { node: plan }))
@@ -1025,8 +1050,7 @@ pub async fn from_substrait_rel(
             let Some(ext_detail) = &extension.detail else {
                 return substrait_err!("Unexpected empty detail in 
ExtensionSingleRel");
             };
-            let plan = ctx
-                .state()
+            let plan = state
                 .serializer_registry()
                 .deserialize_logical_plan(&ext_detail.type_url, 
&ext_detail.value)?;
             let Some(input_rel) = &extension.input else {
@@ -1034,7 +1058,7 @@ pub async fn from_substrait_rel(
                     "ExtensionSingleRel doesn't contains input rel. Try use 
ExtensionLeafRel instead"
                 );
             };
-            let input_plan = from_substrait_rel(ctx, input_rel, 
extensions).await?;
+            let input_plan = from_substrait_rel(state, input_rel, 
extensions).await?;
             let plan =
                 plan.with_exprs_and_inputs(plan.expressions(), 
vec![input_plan])?;
             Ok(LogicalPlan::Extension(Extension { node: plan }))
@@ -1043,13 +1067,12 @@ pub async fn from_substrait_rel(
             let Some(ext_detail) = &extension.detail else {
                 return substrait_err!("Unexpected empty detail in 
ExtensionSingleRel");
             };
-            let plan = ctx
-                .state()
+            let plan = state
                 .serializer_registry()
                 .deserialize_logical_plan(&ext_detail.type_url, 
&ext_detail.value)?;
             let mut inputs = Vec::with_capacity(extension.inputs.len());
             for input in &extension.inputs {
-                let input_plan = from_substrait_rel(ctx, input, 
extensions).await?;
+                let input_plan = from_substrait_rel(state, input, 
extensions).await?;
                 inputs.push(input_plan);
             }
             let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
@@ -1059,7 +1082,7 @@ pub async fn from_substrait_rel(
             let Some(input) = exchange.input.as_ref() else {
                 return substrait_err!("Unexpected empty input in ExchangeRel");
             };
-            let input = Arc::new(from_substrait_rel(ctx, input, 
extensions).await?);
+            let input = Arc::new(from_substrait_rel(state, input, 
extensions).await?);
 
             let Some(exchange_kind) = &exchange.exchange_kind else {
                 return substrait_err!("Unexpected empty input in ExchangeRel");
@@ -1237,7 +1260,7 @@ impl NameTracker {
 ///    DataFusion schema may have MORE fields, but not the other way around.
 /// 2. All fields are compatible. See [`ensure_field_compatability`] for 
details
 fn ensure_schema_compatability(
-    table_schema: DFSchema,
+    table_schema: &DFSchema,
     substrait_schema: DFSchema,
 ) -> Result<()> {
     substrait_schema
@@ -1253,16 +1276,19 @@ fn ensure_schema_compatability(
 
 /// This function returns a LogicalPlan with fields adjusted if necessary in the 
event that the
 /// Substrait schema is a subset of the DataFusion schema.
-fn apply_projection(table: DataFrame, substrait_schema: DFSchema) -> 
Result<LogicalPlan> {
-    let df_schema = table.schema().to_owned();
-
-    let t = table.into_unoptimized_plan();
+fn apply_projection(
+    plan: LogicalPlan,
+    substrait_schema: DFSchema,
+) -> Result<LogicalPlan> {
+    let df_schema = plan.schema();
 
     if df_schema.logically_equivalent_names_and_types(&substrait_schema) {
-        return Ok(t);
+        return Ok(plan);
     }
 
-    match t {
+    let df_schema = df_schema.to_owned();
+
+    match plan {
         LogicalPlan::TableScan(mut scan) => {
             let column_indices: Vec<usize> = substrait_schema
                 .strip_qualifiers()
@@ -1389,7 +1415,7 @@ fn from_substrait_jointype(join_type: i32) -> 
Result<JoinType> {
 
 /// Convert Substrait Sorts to DataFusion Exprs
 pub async fn from_substrait_sorts(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     substrait_sorts: &Vec<SortField>,
     input_schema: &DFSchema,
     extensions: &Extensions,
@@ -1397,7 +1423,7 @@ pub async fn from_substrait_sorts(
     let mut sorts: Vec<Sort> = vec![];
     for s in substrait_sorts {
         let expr =
-            from_substrait_rex(ctx, s.expr.as_ref().unwrap(), input_schema, 
extensions)
+            from_substrait_rex(state, s.expr.as_ref().unwrap(), input_schema, 
extensions)
                 .await?;
         let asc_nullfirst = match &s.sort_kind {
             Some(k) => match k {
@@ -1439,14 +1465,15 @@ pub async fn from_substrait_sorts(
 
 /// Convert Substrait Expressions to DataFusion Exprs
 pub async fn from_substrait_rex_vec(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     exprs: &Vec<Expression>,
     input_schema: &DFSchema,
     extensions: &Extensions,
 ) -> Result<Vec<Expr>> {
     let mut expressions: Vec<Expr> = vec![];
     for expr in exprs {
-        let expression = from_substrait_rex(ctx, expr, input_schema, 
extensions).await?;
+        let expression =
+            from_substrait_rex(state, expr, input_schema, extensions).await?;
         expressions.push(expression);
     }
     Ok(expressions)
@@ -1454,7 +1481,7 @@ pub async fn from_substrait_rex_vec(
 
 /// Convert Substrait FunctionArguments to DataFusion Exprs
 pub async fn from_substrait_func_args(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     arguments: &Vec<FunctionArgument>,
     input_schema: &DFSchema,
     extensions: &Extensions,
@@ -1463,7 +1490,7 @@ pub async fn from_substrait_func_args(
     for arg in arguments {
         let arg_expr = match &arg.arg_type {
             Some(ArgType::Value(e)) => {
-                from_substrait_rex(ctx, e, input_schema, extensions).await
+                from_substrait_rex(state, e, input_schema, extensions).await
             }
             _ => not_impl_err!("Function argument non-Value type not 
supported"),
         };
@@ -1474,7 +1501,7 @@ pub async fn from_substrait_func_args(
 
 /// Convert Substrait AggregateFunction to DataFusion Expr
 pub async fn from_substrait_agg_func(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     f: &AggregateFunction,
     input_schema: &DFSchema,
     extensions: &Extensions,
@@ -1483,7 +1510,7 @@ pub async fn from_substrait_agg_func(
     distinct: bool,
 ) -> Result<Arc<Expr>> {
     let args =
-        from_substrait_func_args(ctx, &f.arguments, input_schema, 
extensions).await?;
+        from_substrait_func_args(state, &f.arguments, input_schema, 
extensions).await?;
 
     let Some(function_name) = extensions.functions.get(&f.function_reference) 
else {
         return plan_err!(
@@ -1494,7 +1521,7 @@ pub async fn from_substrait_agg_func(
 
     let function_name = substrait_fun_name(function_name);
     // try udaf first, then built-in aggr fn.
-    if let Ok(fun) = ctx.udaf(function_name) {
+    if let Ok(fun) = state.udaf(function_name) {
         // deal with situation that count(*) got no arguments
         let args = if fun.name() == "count" && args.is_empty() {
             vec![Expr::Literal(ScalarValue::Int64(Some(1)))]
@@ -1517,7 +1544,7 @@ pub async fn from_substrait_agg_func(
 /// Convert Substrait Rex to DataFusion Expr
 #[async_recursion]
 pub async fn from_substrait_rex(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     e: &Expression,
     input_schema: &DFSchema,
     extensions: &Extensions,
@@ -1528,11 +1555,11 @@ pub async fn from_substrait_rex(
             let substrait_list = s.options.as_ref();
             Ok(Expr::InList(InList {
                 expr: Box::new(
-                    from_substrait_rex(ctx, substrait_expr, input_schema, 
extensions)
+                    from_substrait_rex(state, substrait_expr, input_schema, 
extensions)
                         .await?,
                 ),
                 list: from_substrait_rex_vec(
-                    ctx,
+                    state,
                     substrait_list,
                     input_schema,
                     extensions,
@@ -1555,7 +1582,7 @@ pub async fn from_substrait_rex(
                     if if_expr.then.is_none() {
                         expr = Some(Box::new(
                             from_substrait_rex(
-                                ctx,
+                                state,
                                 if_expr.r#if.as_ref().unwrap(),
                                 input_schema,
                                 extensions,
@@ -1568,7 +1595,7 @@ pub async fn from_substrait_rex(
                 when_then_expr.push((
                     Box::new(
                         from_substrait_rex(
-                            ctx,
+                            state,
                             if_expr.r#if.as_ref().unwrap(),
                             input_schema,
                             extensions,
@@ -1577,7 +1604,7 @@ pub async fn from_substrait_rex(
                     ),
                     Box::new(
                         from_substrait_rex(
-                            ctx,
+                            state,
                             if_expr.then.as_ref().unwrap(),
                             input_schema,
                             extensions,
@@ -1589,7 +1616,7 @@ pub async fn from_substrait_rex(
             // Parse `else`
             let else_expr = match &if_then.r#else {
                 Some(e) => Some(Box::new(
-                    from_substrait_rex(ctx, e, input_schema, 
extensions).await?,
+                    from_substrait_rex(state, e, input_schema, 
extensions).await?,
                 )),
                 None => None,
             };
@@ -1609,12 +1636,12 @@ pub async fn from_substrait_rex(
             let fn_name = substrait_fun_name(fn_name);
 
             let args =
-                from_substrait_func_args(ctx, &f.arguments, input_schema, 
extensions)
+                from_substrait_func_args(state, &f.arguments, input_schema, 
extensions)
                     .await?;
 
             // try to first match the requested function into registered udfs, then built-in ops
             // and finally built-in expressions
-            if let Some(func) = ctx.state().scalar_functions().get(fn_name) {
+            if let Ok(func) = state.udf(fn_name) {
                 Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf(
                     func.to_owned(),
                     args,
@@ -1644,7 +1671,7 @@ pub async fn from_substrait_rex(
 
                 Ok(combined_expr)
             } else if let Some(builder) = 
BuiltinExprBuilder::try_from_name(fn_name) {
-                builder.build(ctx, f, input_schema, extensions).await
+                builder.build(state, f, input_schema, extensions).await
             } else {
                 not_impl_err!("Unsupported function name: {fn_name:?}")
             }
@@ -1657,7 +1684,7 @@ pub async fn from_substrait_rex(
             Some(output_type) => Ok(Expr::Cast(Cast::new(
                 Box::new(
                     from_substrait_rex(
-                        ctx,
+                        state,
                         cast.as_ref().input.as_ref().unwrap().as_ref(),
                         input_schema,
                         extensions,
@@ -1679,9 +1706,9 @@ pub async fn from_substrait_rex(
             let fn_name = substrait_fun_name(fn_name);
 
             // check udwf first, then udaf, then built-in window and aggregate functions
-            let fun = if let Ok(udwf) = ctx.udwf(fn_name) {
+            let fun = if let Ok(udwf) = state.udwf(fn_name) {
                 Ok(WindowFunctionDefinition::WindowUDF(udwf))
-            } else if let Ok(udaf) = ctx.udaf(fn_name) {
+            } else if let Ok(udaf) = state.udaf(fn_name) {
                 Ok(WindowFunctionDefinition::AggregateUDF(udaf))
             } else {
                 not_impl_err!(
@@ -1692,7 +1719,7 @@ pub async fn from_substrait_rex(
             }?;
 
             let order_by =
-                from_substrait_sorts(ctx, &window.sorts, input_schema, 
extensions)
+                from_substrait_sorts(state, &window.sorts, input_schema, 
extensions)
                     .await?;
 
             let bound_units =
@@ -1715,14 +1742,14 @@ pub async fn from_substrait_rex(
             Ok(Expr::WindowFunction(expr::WindowFunction {
                 fun,
                 args: from_substrait_func_args(
-                    ctx,
+                    state,
                     &window.arguments,
                     input_schema,
                     extensions,
                 )
                 .await?,
                 partition_by: from_substrait_rex_vec(
-                    ctx,
+                    state,
                     &window.partitions,
                     input_schema,
                     extensions,
@@ -1747,13 +1774,13 @@ pub async fn from_substrait_rex(
                         let haystack_expr = &in_predicate.haystack;
                         if let Some(haystack_expr) = haystack_expr {
                             let haystack_expr =
-                                from_substrait_rel(ctx, haystack_expr, 
extensions)
+                                from_substrait_rel(state, haystack_expr, 
extensions)
                                     .await?;
                             let outer_refs = haystack_expr.all_out_ref_exprs();
                             Ok(Expr::InSubquery(InSubquery {
                                 expr: Box::new(
                                     from_substrait_rex(
-                                        ctx,
+                                        state,
                                         needle_expr,
                                         input_schema,
                                         extensions,
@@ -1773,7 +1800,7 @@ pub async fn from_substrait_rex(
                 }
                 SubqueryType::Scalar(query) => {
                     let plan = from_substrait_rel(
-                        ctx,
+                        state,
                         &(query.input.clone()).unwrap_or_default(),
                         extensions,
                     )
@@ -1790,7 +1817,7 @@ pub async fn from_substrait_rex(
                         PredicateOp::Exists => {
                             let relation = &predicate.tuples;
                             let plan = from_substrait_rel(
-                                ctx,
+                                state,
                                 &relation.clone().unwrap_or_default(),
                                 extensions,
                             )
@@ -2772,7 +2799,7 @@ fn from_substrait_null(
 
 #[allow(deprecated)]
 async fn from_substrait_grouping(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     grouping: &Grouping,
     expressions: &[Expr],
     input_schema: &DFSchemaRef,
@@ -2781,7 +2808,7 @@ async fn from_substrait_grouping(
     let mut group_exprs = vec![];
     if !grouping.grouping_expressions.is_empty() {
         for e in &grouping.grouping_expressions {
-            let expr = from_substrait_rex(ctx, e, input_schema, 
extensions).await?;
+            let expr = from_substrait_rex(state, e, input_schema, 
extensions).await?;
             group_exprs.push(expr);
         }
         return Ok(group_exprs);
@@ -2834,23 +2861,29 @@ impl BuiltinExprBuilder {
 
     pub async fn build(
         self,
-        ctx: &SessionContext,
+        state: &dyn SubstraitPlanningState,
         f: &ScalarFunction,
         input_schema: &DFSchema,
         extensions: &Extensions,
     ) -> Result<Expr> {
         match self.expr_name.as_str() {
             "like" => {
-                Self::build_like_expr(ctx, false, f, input_schema, 
extensions).await
+                Self::build_like_expr(state, false, f, input_schema, 
extensions).await
             }
             "ilike" => {
-                Self::build_like_expr(ctx, true, f, input_schema, 
extensions).await
+                Self::build_like_expr(state, true, f, input_schema, 
extensions).await
             }
             "not" | "negative" | "negate" | "is_null" | "is_not_null" | 
"is_true"
             | "is_false" | "is_not_true" | "is_not_false" | "is_unknown"
             | "is_not_unknown" => {
-                Self::build_unary_expr(ctx, &self.expr_name, f, input_schema, 
extensions)
-                    .await
+                Self::build_unary_expr(
+                    state,
+                    &self.expr_name,
+                    f,
+                    input_schema,
+                    extensions,
+                )
+                .await
             }
             _ => {
                 not_impl_err!("Unsupported builtin expression: {}", 
self.expr_name)
@@ -2859,7 +2892,7 @@ impl BuiltinExprBuilder {
     }
 
     async fn build_unary_expr(
-        ctx: &SessionContext,
+        state: &dyn SubstraitPlanningState,
         fn_name: &str,
         f: &ScalarFunction,
         input_schema: &DFSchema,
@@ -2872,7 +2905,7 @@ impl BuiltinExprBuilder {
             return substrait_err!("Invalid arguments type for {fn_name} expr");
         };
         let arg =
-            from_substrait_rex(ctx, expr_substrait, input_schema, 
extensions).await?;
+            from_substrait_rex(state, expr_substrait, input_schema, 
extensions).await?;
         let arg = Box::new(arg);
 
         let expr = match fn_name {
@@ -2893,7 +2926,7 @@ impl BuiltinExprBuilder {
     }
 
     async fn build_like_expr(
-        ctx: &SessionContext,
+        state: &dyn SubstraitPlanningState,
         case_insensitive: bool,
         f: &ScalarFunction,
         input_schema: &DFSchema,
@@ -2908,12 +2941,13 @@ impl BuiltinExprBuilder {
             return substrait_err!("Invalid arguments type for `{fn_name}` expr");
         };
         let expr =
-            from_substrait_rex(ctx, expr_substrait, input_schema, 
extensions).await?;
+            from_substrait_rex(state, expr_substrait, input_schema, 
extensions).await?;
         let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type 
else {
             return substrait_err!("Invalid arguments type for `{fn_name}` expr");
         };
         let pattern =
-            from_substrait_rex(ctx, pattern_substrait, input_schema, 
extensions).await?;
+            from_substrait_rex(state, pattern_substrait, input_schema, 
extensions)
+                .await?;
 
         // Default case: escape character is Literal(Utf8(None))
         let escape_char = if f.arguments.len() == 3 {
@@ -2922,9 +2956,13 @@ impl BuiltinExprBuilder {
                 return substrait_err!("Invalid arguments type for `{fn_name}` expr");
             };
 
-            let escape_char_expr =
-                from_substrait_rex(ctx, escape_char_substrait, input_schema, 
extensions)
-                    .await?;
+            let escape_char_expr = from_substrait_rex(
+                state,
+                escape_char_substrait,
+                input_schema,
+                extensions,
+            )
+            .await?;
 
             match escape_char_expr {
                 Expr::Literal(ScalarValue::Utf8(escape_char_string)) => {
diff --git a/datafusion/substrait/src/logical_plan/mod.rs b/datafusion/substrait/src/logical_plan/mod.rs
index 6f8b8e493f..9e2fa9fa49 100644
--- a/datafusion/substrait/src/logical_plan/mod.rs
+++ b/datafusion/substrait/src/logical_plan/mod.rs
@@ -17,3 +17,4 @@
 
 pub mod consumer;
 pub mod producer;
+pub mod state;
diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs
index 4d864e4334..29019dfd74 100644
--- a/datafusion/substrait/src/logical_plan/producer.rs
+++ b/datafusion/substrait/src/logical_plan/producer.rs
@@ -29,7 +29,7 @@ use datafusion::{
     arrow::datatypes::{DataType, TimeUnit},
     error::{DataFusionError, Result},
     logical_expr::{WindowFrame, WindowFrameBound},
-    prelude::{JoinType, SessionContext},
+    prelude::JoinType,
     scalar::ScalarValue,
 };
 
@@ -100,8 +100,13 @@ use substrait::{
     version,
 };
 
+use super::state::SubstraitPlanningState;
+
 /// Convert DataFusion LogicalPlan to Substrait Plan
-pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> 
Result<Box<Plan>> {
+pub fn to_substrait_plan(
+    plan: &LogicalPlan,
+    state: &dyn SubstraitPlanningState,
+) -> Result<Box<Plan>> {
     let mut extensions = Extensions::default();
     // Parse relation nodes
     // Generate PlanRel(s)
@@ -113,7 +118,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result<Box
 
     let plan_rels = vec![PlanRel {
         rel_type: Some(plan_rel::RelType::Root(RelRoot {
-            input: Some(*to_substrait_rel(&plan, ctx, &mut extensions)?),
+            input: Some(*to_substrait_rel(&plan, state, &mut extensions)?),
             names: to_substrait_named_struct(plan.schema())?.names,
         })),
     }];
@@ -144,7 +149,7 @@ pub fn to_substrait_plan(plan: &LogicalPlan, ctx: &SessionContext) -> Result<Box
 pub fn to_substrait_extended_expr(
     exprs: &[(&Expr, &Field)],
     schema: &DFSchemaRef,
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
 ) -> Result<Box<ExtendedExpression>> {
     let mut extensions = Extensions::default();
 
@@ -152,7 +157,7 @@ pub fn to_substrait_extended_expr(
         .iter()
         .map(|(expr, field)| {
             let substrait_expr = to_substrait_rex(
-                ctx,
+                state,
                 expr,
                 schema,
                 /*col_ref_offset=*/ 0,
@@ -183,7 +188,7 @@ pub fn to_substrait_extended_expr(
 #[allow(deprecated)]
 pub fn to_substrait_rel(
     plan: &LogicalPlan,
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     extensions: &mut Extensions,
 ) -> Result<Box<Rel>> {
     match plan {
@@ -284,7 +289,7 @@ pub fn to_substrait_rel(
             let expressions = p
                 .expr
                 .iter()
-                .map(|e| to_substrait_rex(ctx, e, p.input.schema(), 0, 
extensions))
+                .map(|e| to_substrait_rex(state, e, p.input.schema(), 0, 
extensions))
                 .collect::<Result<Vec<_>>>()?;
 
             let emit_kind = create_project_remapping(
@@ -300,16 +305,16 @@ pub fn to_substrait_rel(
             Ok(Box::new(Rel {
                 rel_type: Some(RelType::Project(Box::new(ProjectRel {
                     common: Some(common),
-                    input: Some(to_substrait_rel(p.input.as_ref(), ctx, 
extensions)?),
+                    input: Some(to_substrait_rel(p.input.as_ref(), state, 
extensions)?),
                     expressions,
                     advanced_extension: None,
                 }))),
             }))
         }
         LogicalPlan::Filter(filter) => {
-            let input = to_substrait_rel(filter.input.as_ref(), ctx, 
extensions)?;
+            let input = to_substrait_rel(filter.input.as_ref(), state, 
extensions)?;
             let filter_expr = to_substrait_rex(
-                ctx,
+                state,
                 &filter.predicate,
                 filter.input.schema(),
                 0,
@@ -325,7 +330,7 @@ pub fn to_substrait_rel(
             }))
         }
         LogicalPlan::Limit(limit) => {
-            let input = to_substrait_rel(limit.input.as_ref(), ctx, 
extensions)?;
+            let input = to_substrait_rel(limit.input.as_ref(), state, 
extensions)?;
             let FetchType::Literal(fetch) = limit.get_fetch_type()? else {
                 return not_impl_err!("Non-literal limit fetch");
             };
@@ -344,11 +349,11 @@ pub fn to_substrait_rel(
             }))
         }
         LogicalPlan::Sort(sort) => {
-            let input = to_substrait_rel(sort.input.as_ref(), ctx, 
extensions)?;
+            let input = to_substrait_rel(sort.input.as_ref(), state, 
extensions)?;
             let sort_fields = sort
                 .expr
                 .iter()
-                .map(|e| substrait_sort_field(ctx, e, sort.input.schema(), 
extensions))
+                .map(|e| substrait_sort_field(state, e, sort.input.schema(), 
extensions))
                 .collect::<Result<Vec<_>>>()?;
             Ok(Box::new(Rel {
                 rel_type: Some(RelType::Sort(Box::new(SortRel {
@@ -360,9 +365,9 @@ pub fn to_substrait_rel(
             }))
         }
         LogicalPlan::Aggregate(agg) => {
-            let input = to_substrait_rel(agg.input.as_ref(), ctx, extensions)?;
+            let input = to_substrait_rel(agg.input.as_ref(), state, 
extensions)?;
             let (grouping_expressions, groupings) = to_substrait_groupings(
-                ctx,
+                state,
                 &agg.group_expr,
                 agg.input.schema(),
                 extensions,
@@ -370,7 +375,9 @@ pub fn to_substrait_rel(
             let measures = agg
                 .aggr_expr
                 .iter()
-                .map(|e| to_substrait_agg_measure(ctx, e, agg.input.schema(), 
extensions))
+                .map(|e| {
+                    to_substrait_agg_measure(state, e, agg.input.schema(), 
extensions)
+                })
                 .collect::<Result<Vec<_>>>()?;
 
             Ok(Box::new(Rel {
@@ -386,7 +393,7 @@ pub fn to_substrait_rel(
         }
         LogicalPlan::Distinct(Distinct::All(plan)) => {
             // Use Substrait's AggregateRel with empty measures to represent `select distinct`
-            let input = to_substrait_rel(plan.as_ref(), ctx, extensions)?;
+            let input = to_substrait_rel(plan.as_ref(), state, extensions)?;
             // Get grouping keys from the input relation's number of output fields
             let grouping = (0..plan.schema().fields().len())
                 .map(substrait_field_ref)
@@ -407,8 +414,8 @@ pub fn to_substrait_rel(
             }))
         }
         LogicalPlan::Join(join) => {
-            let left = to_substrait_rel(join.left.as_ref(), ctx, extensions)?;
-            let right = to_substrait_rel(join.right.as_ref(), ctx, 
extensions)?;
+            let left = to_substrait_rel(join.left.as_ref(), state, 
extensions)?;
+            let right = to_substrait_rel(join.right.as_ref(), state, 
extensions)?;
             let join_type = to_substrait_jointype(join.join_type);
             // we only support basic joins so return an error for anything not yet supported
             match join.join_constraint {
@@ -421,7 +428,7 @@ pub fn to_substrait_rel(
             let in_join_schema = join.left.schema().join(join.right.schema())?;
             let join_filter = match &join.filter {
                 Some(filter) => Some(to_substrait_rex(
-                    ctx,
+                    state,
                     filter,
                     &Arc::new(in_join_schema),
                     0,
@@ -438,7 +445,7 @@ pub fn to_substrait_rel(
                 Operator::Eq
             };
             let join_on = to_substrait_join_expr(
-                ctx,
+                state,
                 &join.on,
                 eq_op,
                 join.left.schema(),
@@ -479,13 +486,13 @@ pub fn to_substrait_rel(
         LogicalPlan::SubqueryAlias(alias) => {
             // Do nothing if encounters SubqueryAlias
             // since there is no corresponding relation type in Substrait
-            to_substrait_rel(alias.input.as_ref(), ctx, extensions)
+            to_substrait_rel(alias.input.as_ref(), state, extensions)
         }
         LogicalPlan::Union(union) => {
             let input_rels = union
                 .inputs
                 .iter()
-                .map(|input| to_substrait_rel(input.as_ref(), ctx, extensions))
+                .map(|input| to_substrait_rel(input.as_ref(), state, 
extensions))
                 .collect::<Result<Vec<_>>>()?
                 .into_iter()
                 .map(|ptr| *ptr)
@@ -500,7 +507,7 @@ pub fn to_substrait_rel(
             }))
         }
         LogicalPlan::Window(window) => {
-            let input = to_substrait_rel(window.input.as_ref(), ctx, 
extensions)?;
+            let input = to_substrait_rel(window.input.as_ref(), state, 
extensions)?;
 
             // create a field reference for each input field
             let mut expressions = (0..window.input.schema().fields().len())
@@ -510,7 +517,7 @@ pub fn to_substrait_rel(
             // process and add each window function expression
             for expr in &window.window_expr {
                 expressions.push(to_substrait_rex(
-                    ctx,
+                    state,
                     expr,
                     window.input.schema(),
                     0,
@@ -539,7 +546,7 @@ pub fn to_substrait_rel(
             }))
         }
         LogicalPlan::Repartition(repartition) => {
-            let input = to_substrait_rel(repartition.input.as_ref(), ctx, 
extensions)?;
+            let input = to_substrait_rel(repartition.input.as_ref(), state, 
extensions)?;
             let partition_count = match repartition.partitioning_scheme {
                 Partitioning::RoundRobinBatch(num) => num,
                 Partitioning::Hash(_, num) => num,
@@ -585,8 +592,7 @@ pub fn to_substrait_rel(
             }))
         }
         LogicalPlan::Extension(extension_plan) => {
-            let extension_bytes = ctx
-                .state()
+            let extension_bytes = state
                 .serializer_registry()
                 .serialize_logical_plan(extension_plan.node.as_ref())?;
             let detail = ProtoAny {
@@ -597,7 +603,7 @@ pub fn to_substrait_rel(
                 .node
                 .inputs()
                 .into_iter()
-                .map(|plan| to_substrait_rel(plan, ctx, extensions))
+                .map(|plan| to_substrait_rel(plan, state, extensions))
                 .collect::<Result<Vec<_>>>()?;
             let rel_type = match inputs_rel.len() {
                 0 => RelType::ExtensionLeaf(ExtensionLeafRel {
@@ -687,7 +693,7 @@ fn to_substrait_named_struct(schema: &DFSchemaRef) -> Result<NamedStruct> {
 }
 
 fn to_substrait_join_expr(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     join_conditions: &Vec<(Expr, Expr)>,
     eq_op: Operator,
     left_schema: &DFSchemaRef,
@@ -698,10 +704,10 @@ fn to_substrait_join_expr(
     let mut exprs: Vec<Expression> = vec![];
     for (left, right) in join_conditions {
         // Parse left
-        let l = to_substrait_rex(ctx, left, left_schema, 0, extensions)?;
+        let l = to_substrait_rex(state, left, left_schema, 0, extensions)?;
         // Parse right
         let r = to_substrait_rex(
-            ctx,
+            state,
             right,
             right_schema,
             left_schema.fields().len(), // offset to return the correct index
@@ -770,7 +776,7 @@ pub fn operator_to_name(op: Operator) -> &'static str {
 
 #[allow(deprecated)]
 pub fn parse_flat_grouping_exprs(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     exprs: &[Expr],
     schema: &DFSchemaRef,
     extensions: &mut Extensions,
@@ -780,7 +786,7 @@ pub fn parse_flat_grouping_exprs(
     let mut grouping_expressions = vec![];
 
     for e in exprs {
-        let rex = to_substrait_rex(ctx, e, schema, 0, extensions)?;
+        let rex = to_substrait_rex(state, e, schema, 0, extensions)?;
         grouping_expressions.push(rex.clone());
         ref_group_exprs.push(rex);
         expression_references.push((ref_group_exprs.len() - 1) as u32);
@@ -792,7 +798,7 @@ pub fn parse_flat_grouping_exprs(
 }
 
 pub fn to_substrait_groupings(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     exprs: &[Expr],
     schema: &DFSchemaRef,
     extensions: &mut Extensions,
@@ -808,7 +814,7 @@ pub fn to_substrait_groupings(
                     .iter()
                     .map(|set| {
                         parse_flat_grouping_exprs(
-                            ctx,
+                            state,
                             set,
                             schema,
                             extensions,
@@ -826,7 +832,7 @@ pub fn to_substrait_groupings(
                         .rev()
                         .map(|set| {
                             parse_flat_grouping_exprs(
-                                ctx,
+                                state,
                                 set,
                                 schema,
                                 extensions,
@@ -837,7 +843,7 @@ pub fn to_substrait_groupings(
                 }
             },
             _ => Ok(vec![parse_flat_grouping_exprs(
-                ctx,
+                state,
                 exprs,
                 schema,
                 extensions,
@@ -845,7 +851,7 @@ pub fn to_substrait_groupings(
             )?]),
         },
         _ => Ok(vec![parse_flat_grouping_exprs(
-            ctx,
+            state,
             exprs,
             schema,
             extensions,
@@ -857,7 +863,7 @@ pub fn to_substrait_groupings(
 
 #[allow(deprecated)]
 pub fn to_substrait_agg_measure(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     expr: &Expr,
     schema: &DFSchemaRef,
     extensions: &mut Extensions,
@@ -865,13 +871,13 @@ pub fn to_substrait_agg_measure(
     match expr {
         Expr::AggregateFunction(expr::AggregateFunction { func, args, 
distinct, filter, order_by, null_treatment: _, }) => {
                     let sorts = if let Some(order_by) = order_by {
-                        order_by.iter().map(|expr| 
to_substrait_sort_field(ctx, expr, schema, 
extensions)).collect::<Result<Vec<_>>>()?
+                        order_by.iter().map(|expr| 
to_substrait_sort_field(state, expr, schema, 
extensions)).collect::<Result<Vec<_>>>()?
                     } else {
                         vec![]
                     };
                     let mut arguments: Vec<FunctionArgument> = vec![];
                     for arg in args {
-                        arguments.push(FunctionArgument { arg_type: 
Some(ArgType::Value(to_substrait_rex(ctx, arg, schema, 0, extensions)?)) });
+                        arguments.push(FunctionArgument { arg_type: 
Some(ArgType::Value(to_substrait_rex(state, arg, schema, 0, extensions)?)) });
                     }
                     let function_anchor = 
extensions.register_function(func.name().to_string());
                     Ok(Measure {
@@ -889,14 +895,14 @@ pub fn to_substrait_agg_measure(
                             options: vec![],
                         }),
                         filter: match filter {
-                            Some(f) => Some(to_substrait_rex(ctx, f, schema, 
0, extensions)?),
+                            Some(f) => Some(to_substrait_rex(state, f, schema, 
0, extensions)?),
                             None => None
                         }
                     })
 
         }
         Expr::Alias(Alias{expr,..})=> {
-            to_substrait_agg_measure(ctx, expr, schema, extensions)
+            to_substrait_agg_measure(state, expr, schema, extensions)
         }
         _ => internal_err!(
             "Expression must be compatible with aggregation. Unsupported expression: {:?}. ExpressionType: {:?}",
@@ -908,7 +914,7 @@ pub fn to_substrait_agg_measure(
 
 /// Converts sort expression to corresponding substrait `SortField`
 fn to_substrait_sort_field(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     sort: &Sort,
     schema: &DFSchemaRef,
     extensions: &mut Extensions,
@@ -920,7 +926,7 @@ fn to_substrait_sort_field(
         (false, false) => SortDirection::DescNullsLast,
     };
     Ok(SortField {
-        expr: Some(to_substrait_rex(ctx, &sort.expr, schema, 0, extensions)?),
+        expr: Some(to_substrait_rex(state, &sort.expr, schema, 0, 
extensions)?),
         sort_kind: Some(SortKind::Direction(sort_kind.into())),
     })
 }
@@ -977,7 +983,7 @@ pub fn make_binary_op_scalar_func(
 /// * `extensions` - Substrait extension info. Contains registered function information
 #[allow(deprecated)]
 pub fn to_substrait_rex(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     expr: &Expr,
     schema: &DFSchemaRef,
     col_ref_offset: usize,
@@ -991,10 +997,10 @@ pub fn to_substrait_rex(
         }) => {
             let substrait_list = list
                 .iter()
-                .map(|x| to_substrait_rex(ctx, x, schema, col_ref_offset, 
extensions))
+                .map(|x| to_substrait_rex(state, x, schema, col_ref_offset, 
extensions))
                 .collect::<Result<Vec<Expression>>>()?;
             let substrait_expr =
-                to_substrait_rex(ctx, expr, schema, col_ref_offset, 
extensions)?;
+                to_substrait_rex(state, expr, schema, col_ref_offset, 
extensions)?;
 
             let substrait_or_list = Expression {
                 rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList 
{
@@ -1026,7 +1032,7 @@ pub fn to_substrait_rex(
             for arg in &fun.args {
                 arguments.push(FunctionArgument {
                     arg_type: Some(ArgType::Value(to_substrait_rex(
-                        ctx,
+                        state,
                         arg,
                         schema,
                         col_ref_offset,
@@ -1055,11 +1061,11 @@ pub fn to_substrait_rex(
             if *negated {
                 // `expr NOT BETWEEN low AND high` can be translated into (expr < low OR high < expr)
                 let substrait_expr =
-                    to_substrait_rex(ctx, expr, schema, col_ref_offset, 
extensions)?;
+                    to_substrait_rex(state, expr, schema, col_ref_offset, 
extensions)?;
                 let substrait_low =
-                    to_substrait_rex(ctx, low, schema, col_ref_offset, 
extensions)?;
+                    to_substrait_rex(state, low, schema, col_ref_offset, 
extensions)?;
                 let substrait_high =
-                    to_substrait_rex(ctx, high, schema, col_ref_offset, 
extensions)?;
+                    to_substrait_rex(state, high, schema, col_ref_offset, 
extensions)?;
 
                 let l_expr = make_binary_op_scalar_func(
                     &substrait_expr,
@@ -1083,11 +1089,11 @@ pub fn to_substrait_rex(
             } else {
                 // `expr BETWEEN low AND high` can be translated into (low <= expr AND expr <= high)
                 let substrait_expr =
-                    to_substrait_rex(ctx, expr, schema, col_ref_offset, 
extensions)?;
+                    to_substrait_rex(state, expr, schema, col_ref_offset, 
extensions)?;
                 let substrait_low =
-                    to_substrait_rex(ctx, low, schema, col_ref_offset, 
extensions)?;
+                    to_substrait_rex(state, low, schema, col_ref_offset, 
extensions)?;
                 let substrait_high =
-                    to_substrait_rex(ctx, high, schema, col_ref_offset, 
extensions)?;
+                    to_substrait_rex(state, high, schema, col_ref_offset, 
extensions)?;
 
                 let l_expr = make_binary_op_scalar_func(
                     &substrait_low,
@@ -1115,8 +1121,8 @@ pub fn to_substrait_rex(
             substrait_field_ref(index + col_ref_offset)
         }
         Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
-            let l = to_substrait_rex(ctx, left, schema, col_ref_offset, 
extensions)?;
-            let r = to_substrait_rex(ctx, right, schema, col_ref_offset, 
extensions)?;
+            let l = to_substrait_rex(state, left, schema, col_ref_offset, 
extensions)?;
+            let r = to_substrait_rex(state, right, schema, col_ref_offset, 
extensions)?;
 
             Ok(make_binary_op_scalar_func(&l, &r, *op, extensions))
         }
@@ -1131,7 +1137,7 @@ pub fn to_substrait_rex(
                 // Base expression exists
                 ifs.push(IfClause {
                     r#if: Some(to_substrait_rex(
-                        ctx,
+                        state,
                         e,
                         schema,
                         col_ref_offset,
@@ -1144,14 +1150,14 @@ pub fn to_substrait_rex(
             for (r#if, then) in when_then_expr {
                 ifs.push(IfClause {
                     r#if: Some(to_substrait_rex(
-                        ctx,
+                        state,
                         r#if,
                         schema,
                         col_ref_offset,
                         extensions,
                     )?),
                     then: Some(to_substrait_rex(
-                        ctx,
+                        state,
                         then,
                         schema,
                         col_ref_offset,
@@ -1163,7 +1169,7 @@ pub fn to_substrait_rex(
             // Parse outer `else`
             let r#else: Option<Box<Expression>> = match else_expr {
                 Some(e) => Some(Box::new(to_substrait_rex(
-                    ctx,
+                    state,
                     e,
                     schema,
                     col_ref_offset,
@@ -1182,7 +1188,7 @@ pub fn to_substrait_rex(
                     substrait::proto::expression::Cast {
                         r#type: Some(to_substrait_type(data_type, true)?),
                         input: Some(Box::new(to_substrait_rex(
-                            ctx,
+                            state,
                             expr,
                             schema,
                             col_ref_offset,
@@ -1195,7 +1201,7 @@ pub fn to_substrait_rex(
         }
         Expr::Literal(value) => to_substrait_literal_expr(value, extensions),
         Expr::Alias(Alias { expr, .. }) => {
-            to_substrait_rex(ctx, expr, schema, col_ref_offset, extensions)
+            to_substrait_rex(state, expr, schema, col_ref_offset, extensions)
         }
         Expr::WindowFunction(WindowFunction {
             fun,
@@ -1212,7 +1218,7 @@ pub fn to_substrait_rex(
             for arg in args {
                 arguments.push(FunctionArgument {
                     arg_type: Some(ArgType::Value(to_substrait_rex(
-                        ctx,
+                        state,
                         arg,
                         schema,
                         col_ref_offset,
@@ -1223,12 +1229,12 @@ pub fn to_substrait_rex(
             // partition by expressions
             let partition_by = partition_by
                 .iter()
-                .map(|e| to_substrait_rex(ctx, e, schema, col_ref_offset, 
extensions))
+                .map(|e| to_substrait_rex(state, e, schema, col_ref_offset, 
extensions))
                 .collect::<Result<Vec<_>>>()?;
             // order by expressions
             let order_by = order_by
                 .iter()
-                .map(|e| substrait_sort_field(ctx, e, schema, extensions))
+                .map(|e| substrait_sort_field(state, e, schema, extensions))
                 .collect::<Result<Vec<_>>>()?;
             // window frame
             let bounds = to_substrait_bounds(window_frame)?;
@@ -1249,7 +1255,7 @@ pub fn to_substrait_rex(
             escape_char,
             case_insensitive,
         }) => make_substrait_like_expr(
-            ctx,
+            state,
             *case_insensitive,
             *negated,
             expr,
@@ -1265,10 +1271,10 @@ pub fn to_substrait_rex(
             negated,
         }) => {
             let substrait_expr =
-                to_substrait_rex(ctx, expr, schema, col_ref_offset, 
extensions)?;
+                to_substrait_rex(state, expr, schema, col_ref_offset, 
extensions)?;
 
             let subquery_plan =
-                to_substrait_rel(subquery.subquery.as_ref(), ctx, extensions)?;
+                to_substrait_rel(subquery.subquery.as_ref(), state, 
extensions)?;
 
             let substrait_subquery = Expression {
                 rex_type: Some(RexType::Subquery(Box::new(Subquery {
@@ -1301,7 +1307,7 @@ pub fn to_substrait_rex(
             }
         }
         Expr::Not(arg) => to_substrait_unary_scalar_fn(
-            ctx,
+            state,
             "not",
             arg,
             schema,
@@ -1309,7 +1315,7 @@ pub fn to_substrait_rex(
             extensions,
         ),
         Expr::IsNull(arg) => to_substrait_unary_scalar_fn(
-            ctx,
+            state,
             "is_null",
             arg,
             schema,
@@ -1317,7 +1323,7 @@ pub fn to_substrait_rex(
             extensions,
         ),
         Expr::IsNotNull(arg) => to_substrait_unary_scalar_fn(
-            ctx,
+            state,
             "is_not_null",
             arg,
             schema,
@@ -1325,7 +1331,7 @@ pub fn to_substrait_rex(
             extensions,
         ),
         Expr::IsTrue(arg) => to_substrait_unary_scalar_fn(
-            ctx,
+            state,
             "is_true",
             arg,
             schema,
@@ -1333,7 +1339,7 @@ pub fn to_substrait_rex(
             extensions,
         ),
         Expr::IsFalse(arg) => to_substrait_unary_scalar_fn(
-            ctx,
+            state,
             "is_false",
             arg,
             schema,
@@ -1341,7 +1347,7 @@ pub fn to_substrait_rex(
             extensions,
         ),
         Expr::IsUnknown(arg) => to_substrait_unary_scalar_fn(
-            ctx,
+            state,
             "is_unknown",
             arg,
             schema,
@@ -1349,7 +1355,7 @@ pub fn to_substrait_rex(
             extensions,
         ),
         Expr::IsNotTrue(arg) => to_substrait_unary_scalar_fn(
-            ctx,
+            state,
             "is_not_true",
             arg,
             schema,
@@ -1357,7 +1363,7 @@ pub fn to_substrait_rex(
             extensions,
         ),
         Expr::IsNotFalse(arg) => to_substrait_unary_scalar_fn(
-            ctx,
+            state,
             "is_not_false",
             arg,
             schema,
@@ -1365,7 +1371,7 @@ pub fn to_substrait_rex(
             extensions,
         ),
         Expr::IsNotUnknown(arg) => to_substrait_unary_scalar_fn(
-            ctx,
+            state,
             "is_not_unknown",
             arg,
             schema,
@@ -1373,7 +1379,7 @@ pub fn to_substrait_rex(
             extensions,
         ),
         Expr::Negative(arg) => to_substrait_unary_scalar_fn(
-            ctx,
+            state,
             "negate",
             arg,
             schema,
@@ -1674,7 +1680,7 @@ fn make_substrait_window_function(
 #[allow(deprecated)]
 #[allow(clippy::too_many_arguments)]
 fn make_substrait_like_expr(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     ignore_case: bool,
     negated: bool,
     expr: &Expr,
@@ -1689,8 +1695,8 @@ fn make_substrait_like_expr(
     } else {
         extensions.register_function("like".to_string())
     };
-    let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, 
extensions)?;
-    let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, 
extensions)?;
+    let expr = to_substrait_rex(state, expr, schema, col_ref_offset, 
extensions)?;
+    let pattern = to_substrait_rex(state, pattern, schema, col_ref_offset, 
extensions)?;
     let escape_char = to_substrait_literal_expr(
         &ScalarValue::Utf8(escape_char.map(|c| c.to_string())),
         extensions,
@@ -2088,7 +2094,7 @@ fn to_substrait_literal_expr(
 
 /// Util to generate substrait [RexType::ScalarFunction] with one argument
 fn to_substrait_unary_scalar_fn(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     fn_name: &str,
     arg: &Expr,
     schema: &DFSchemaRef,
@@ -2096,7 +2102,8 @@ fn to_substrait_unary_scalar_fn(
     extensions: &mut Extensions,
 ) -> Result<Expression> {
     let function_anchor = extensions.register_function(fn_name.to_string());
-    let substrait_expr = to_substrait_rex(ctx, arg, schema, col_ref_offset, 
extensions)?;
+    let substrait_expr =
+        to_substrait_rex(state, arg, schema, col_ref_offset, extensions)?;
 
     Ok(Expression {
         rex_type: Some(RexType::ScalarFunction(ScalarFunction {
@@ -2137,7 +2144,7 @@ fn try_to_substrait_field_reference(
 }
 
 fn substrait_sort_field(
-    ctx: &SessionContext,
+    state: &dyn SubstraitPlanningState,
     sort: &Sort,
     schema: &DFSchemaRef,
     extensions: &mut Extensions,
@@ -2147,7 +2154,7 @@ fn substrait_sort_field(
         asc,
         nulls_first,
     } = sort;
-    let e = to_substrait_rex(ctx, expr, schema, 0, extensions)?;
+    let e = to_substrait_rex(state, expr, schema, 0, extensions)?;
     let d = match (asc, nulls_first) {
         (true, true) => SortDirection::AscNullsFirst,
         (true, false) => SortDirection::AscNullsLast,
@@ -2190,6 +2197,7 @@ mod test {
     use datafusion::arrow::datatypes::{Field, Fields, Schema};
     use datafusion::common::scalar::ScalarStructBuilder;
     use datafusion::common::DFSchema;
+    use datafusion::execution::SessionStateBuilder;
 
     #[test]
     fn round_trip_literals() -> Result<()> {
@@ -2433,15 +2441,15 @@ mod test {
 
     #[tokio::test]
     async fn extended_expressions() -> Result<()> {
-        let ctx = SessionContext::new();
+        let state = SessionStateBuilder::default().build();
 
         // One expression, empty input schema
         let expr = Expr::Literal(ScalarValue::Int32(Some(42)));
         let field = Field::new("out", DataType::Int32, false);
         let empty_schema = DFSchemaRef::new(DFSchema::empty());
         let substrait =
-            to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, 
&ctx)?;
-        let roundtrip_expr = from_substrait_extended_expr(&ctx, 
&substrait).await?;
+            to_substrait_extended_expr(&[(&expr, &field)], &empty_schema, 
&state)?;
+        let roundtrip_expr = from_substrait_extended_expr(&state, 
&substrait).await?;
 
         assert_eq!(roundtrip_expr.input_schema, empty_schema);
         assert_eq!(roundtrip_expr.exprs.len(), 1);
@@ -2463,9 +2471,9 @@ mod test {
         let substrait = to_substrait_extended_expr(
             &[(&expr1, &out1), (&expr2, &out2)],
             &input_schema,
-            &ctx,
+            &state,
         )?;
-        let roundtrip_expr = from_substrait_extended_expr(&ctx, 
&substrait).await?;
+        let roundtrip_expr = from_substrait_extended_expr(&state, 
&substrait).await?;
 
         assert_eq!(roundtrip_expr.input_schema, input_schema);
         assert_eq!(roundtrip_expr.exprs.len(), 2);
@@ -2485,14 +2493,14 @@ mod test {
 
     #[tokio::test]
     async fn invalid_extended_expression() {
-        let ctx = SessionContext::new();
+        let state = SessionStateBuilder::default().build();
 
         // Not ok if input schema is missing field referenced by expr
         let expr = Expr::Column("missing".into());
         let field = Field::new("out", DataType::Int32, false);
         let empty_schema = DFSchemaRef::new(DFSchema::empty());
 
-        let err = to_substrait_extended_expr(&[(&expr, &field)], 
&empty_schema, &ctx);
+        let err = to_substrait_extended_expr(&[(&expr, &field)], 
&empty_schema, &state);
 
         assert!(matches!(err, Err(DataFusionError::SchemaError(_, _))));
     }
diff --git a/datafusion/substrait/src/logical_plan/state.rs 
b/datafusion/substrait/src/logical_plan/state.rs
new file mode 100644
index 0000000000..0bd749c110
--- /dev/null
+++ b/datafusion/substrait/src/logical_plan/state.rs
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use datafusion::{
+    catalog::TableProvider,
+    error::{DataFusionError, Result},
+    execution::{registry::SerializerRegistry, FunctionRegistry, SessionState},
+    sql::TableReference,
+};
+
+/// This trait provides the context needed to transform a substrait plan into a
+/// [`datafusion::logical_expr::LogicalPlan`] (via 
[`super::consumer::from_substrait_plan`])
+/// and back again into a substrait plan (via 
[`super::producer::to_substrait_plan`]).
+///
+/// The context is declared as a trait to decouple the substrait plan encoder /
+/// decoder from the [`SessionState`], potentially allowing users to define
+/// their own slimmer context just for serializing and deserializing substrait.
+///
+/// [`SessionState`] implements this trait.
+#[async_trait]
+pub trait SubstraitPlanningState: Sync + Send + FunctionRegistry {
+    /// Return [SerializerRegistry] for extensions
+    fn serializer_registry(&self) -> &Arc<dyn SerializerRegistry>;
+
+    async fn table(
+        &self,
+        reference: &TableReference,
+    ) -> Result<Option<Arc<dyn TableProvider>>>;
+}
+
+#[async_trait]
+impl SubstraitPlanningState for SessionState {
+    fn serializer_registry(&self) -> &Arc<dyn SerializerRegistry> {
+        self.serializer_registry()
+    }
+
+    async fn table(
+        &self,
+        reference: &TableReference,
+    ) -> Result<Option<Arc<dyn TableProvider>>, DataFusionError> {
+        let table = reference.table().to_string();
+        let schema = self.schema_for_ref(reference.clone())?;
+        let table_provider = schema.table(&table).await?;
+        Ok(table_provider)
+    }
+}
diff --git a/datafusion/substrait/src/serializer.rs 
b/datafusion/substrait/src/serializer.rs
index 6b81e33dfc..4278671777 100644
--- a/datafusion/substrait/src/serializer.rs
+++ b/datafusion/substrait/src/serializer.rs
@@ -38,7 +38,7 @@ pub async fn serialize(sql: &str, ctx: &SessionContext, path: 
&str) -> Result<()
 pub async fn serialize_bytes(sql: &str, ctx: &SessionContext) -> 
Result<Vec<u8>> {
     let df = ctx.sql(sql).await?;
     let plan = df.into_optimized_plan()?;
-    let proto = producer::to_substrait_plan(&plan, ctx)?;
+    let proto = producer::to_substrait_plan(&plan, &ctx.state())?;
 
     let mut protobuf_out = Vec::<u8>::new();
     proto.encode(&mut protobuf_out).map_err(|e| {
diff --git a/datafusion/substrait/tests/cases/consumer_integration.rs 
b/datafusion/substrait/tests/cases/consumer_integration.rs
index bc38ef8297..219f656bb4 100644
--- a/datafusion/substrait/tests/cases/consumer_integration.rs
+++ b/datafusion/substrait/tests/cases/consumer_integration.rs
@@ -41,7 +41,7 @@ mod tests {
         .expect("failed to parse json");
 
         let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?;
-        let plan = from_substrait_plan(&ctx, &proto).await?;
+        let plan = from_substrait_plan(&ctx.state(), &proto).await?;
         Ok(format!("{}", plan))
     }
 
diff --git a/datafusion/substrait/tests/cases/emit_kind_tests.rs 
b/datafusion/substrait/tests/cases/emit_kind_tests.rs
index ac66177ed7..08537d0d11 100644
--- a/datafusion/substrait/tests/cases/emit_kind_tests.rs
+++ b/datafusion/substrait/tests/cases/emit_kind_tests.rs
@@ -33,7 +33,7 @@ mod tests {
             
"tests/testdata/test_plans/emit_kind/direct_on_project.substrait.json",
         );
         let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
-        let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+        let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
 
         let plan_str = format!("{}", plan);
 
@@ -51,7 +51,7 @@ mod tests {
             
"tests/testdata/test_plans/emit_kind/emit_on_filter.substrait.json",
         );
         let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
-        let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+        let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
 
         let plan_str = format!("{}", plan);
 
@@ -91,8 +91,8 @@ mod tests {
              \n  TableScan: data"
         );
 
-        let proto = to_substrait_plan(&plan, &ctx)?;
-        let plan2 = from_substrait_plan(&ctx, &proto).await?;
+        let proto = to_substrait_plan(&plan, &ctx.state())?;
+        let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
         // note how the Projections are not flattened
         assert_eq!(
             format!("{}", plan2),
@@ -115,8 +115,8 @@ mod tests {
              \n  TableScan: data"
         );
 
-        let proto = to_substrait_plan(&plan, &ctx)?;
-        let plan2 = from_substrait_plan(&ctx, &proto).await?;
+        let proto = to_substrait_plan(&plan, &ctx.state())?;
+        let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
 
         let plan1str = format!("{plan}");
         let plan2str = format!("{plan2}");
diff --git a/datafusion/substrait/tests/cases/function_test.rs 
b/datafusion/substrait/tests/cases/function_test.rs
index b136b0af19..0438084561 100644
--- a/datafusion/substrait/tests/cases/function_test.rs
+++ b/datafusion/substrait/tests/cases/function_test.rs
@@ -29,7 +29,7 @@ mod tests {
     async fn contains_function_test() -> Result<()> {
         let proto_plan = 
read_json("tests/testdata/contains_plan.substrait.json");
         let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
-        let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+        let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
 
         let plan_str = format!("{}", plan);
 
diff --git a/datafusion/substrait/tests/cases/logical_plans.rs 
b/datafusion/substrait/tests/cases/logical_plans.rs
index f4e34af35d..65f404bbda 100644
--- a/datafusion/substrait/tests/cases/logical_plans.rs
+++ b/datafusion/substrait/tests/cases/logical_plans.rs
@@ -38,7 +38,7 @@ mod tests {
         let proto_plan =
             
read_json("tests/testdata/test_plans/select_not_bool.substrait.json");
         let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
-        let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+        let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
 
         assert_eq!(
             format!("{}", plan),
@@ -63,7 +63,7 @@ mod tests {
         let proto_plan =
             
read_json("tests/testdata/test_plans/select_window.substrait.json");
         let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
-        let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+        let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
 
         assert_eq!(
             format!("{}", plan),
@@ -82,7 +82,7 @@ mod tests {
         let proto_plan =
             
read_json("tests/testdata/test_plans/non_nullable_lists.substrait.json");
         let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto_plan)?;
-        let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+        let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
 
         assert_eq!(format!("{}", &plan), "Values: (List([1, 2]))");
 
diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs 
b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
index d4e2d48885..d03ab51820 100644
--- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
+++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
@@ -979,8 +979,8 @@ async fn extension_logical_plan() -> Result<()> {
         }),
     });
 
-    let proto = to_substrait_plan(&ext_plan, &ctx)?;
-    let plan2 = from_substrait_plan(&ctx, &proto).await?;
+    let proto = to_substrait_plan(&ext_plan, &ctx.state())?;
+    let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
 
     let plan1str = format!("{ext_plan}");
     let plan2str = format!("{plan2}");
@@ -1081,8 +1081,8 @@ async fn roundtrip_repartition_roundrobin() -> Result<()> 
{
         partitioning_scheme: Partitioning::RoundRobinBatch(8),
     });
 
-    let proto = to_substrait_plan(&plan, &ctx)?;
-    let plan2 = from_substrait_plan(&ctx, &proto).await?;
+    let proto = to_substrait_plan(&plan, &ctx.state())?;
+    let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
     let plan2 = ctx.state().optimize(&plan2)?;
 
     assert_eq!(format!("{plan}"), format!("{plan2}"));
@@ -1098,8 +1098,8 @@ async fn roundtrip_repartition_hash() -> Result<()> {
         partitioning_scheme: Partitioning::Hash(vec![col("data.a")], 8),
     });
 
-    let proto = to_substrait_plan(&plan, &ctx)?;
-    let plan2 = from_substrait_plan(&ctx, &proto).await?;
+    let proto = to_substrait_plan(&plan, &ctx.state())?;
+    let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
     let plan2 = ctx.state().optimize(&plan2)?;
 
     assert_eq!(format!("{plan}"), format!("{plan2}"));
@@ -1199,8 +1199,8 @@ async fn assert_expected_plan_unoptimized(
     let ctx = create_context().await?;
     let df = ctx.sql(sql).await?;
     let plan = df.into_unoptimized_plan();
-    let proto = to_substrait_plan(&plan, &ctx)?;
-    let plan2 = from_substrait_plan(&ctx, &proto).await?;
+    let proto = to_substrait_plan(&plan, &ctx.state())?;
+    let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
 
     println!("{plan}");
     println!("{plan2}");
@@ -1225,8 +1225,8 @@ async fn assert_expected_plan(
     let ctx = create_context().await?;
     let df = ctx.sql(sql).await?;
     let plan = df.into_optimized_plan()?;
-    let proto = to_substrait_plan(&plan, &ctx)?;
-    let plan2 = from_substrait_plan(&ctx, &proto).await?;
+    let proto = to_substrait_plan(&plan, &ctx.state())?;
+    let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
     let plan2 = ctx.state().optimize(&plan2)?;
 
     println!("{plan}");
@@ -1250,7 +1250,7 @@ async fn assert_expected_plan_substrait(
 ) -> Result<()> {
     let ctx = create_context().await?;
 
-    let plan = from_substrait_plan(&ctx, &substrait_plan).await?;
+    let plan = from_substrait_plan(&ctx.state(), &substrait_plan).await?;
 
     let plan = ctx.state().optimize(&plan)?;
 
@@ -1265,7 +1265,7 @@ async fn assert_substrait_sql(substrait_plan: Plan, sql: 
&str) -> Result<()> {
 
     let expected = ctx.sql(sql).await?.into_optimized_plan()?;
 
-    let plan = from_substrait_plan(&ctx, &substrait_plan).await?;
+    let plan = from_substrait_plan(&ctx.state(), &substrait_plan).await?;
 
     let plan = ctx.state().optimize(&plan)?;
 
@@ -1280,8 +1280,8 @@ async fn roundtrip_fill_na(sql: &str) -> Result<()> {
     let ctx = create_context().await?;
     let df = ctx.sql(sql).await?;
     let plan = df.into_optimized_plan()?;
-    let proto = to_substrait_plan(&plan, &ctx)?;
-    let plan2 = from_substrait_plan(&ctx, &proto).await?;
+    let proto = to_substrait_plan(&plan, &ctx.state())?;
+    let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
     let plan2 = ctx.state().optimize(&plan2)?;
 
     // Format plan string and replace all None's with 0
@@ -1301,12 +1301,12 @@ async fn test_alias(sql_with_alias: &str, sql_no_alias: 
&str) -> Result<()> {
     let ctx = create_context().await?;
 
     let df_a = ctx.sql(sql_with_alias).await?;
-    let proto_a = to_substrait_plan(&df_a.into_optimized_plan()?, &ctx)?;
-    let plan_with_alias = from_substrait_plan(&ctx, &proto_a).await?;
+    let proto_a = to_substrait_plan(&df_a.into_optimized_plan()?, 
&ctx.state())?;
+    let plan_with_alias = from_substrait_plan(&ctx.state(), &proto_a).await?;
 
     let df = ctx.sql(sql_no_alias).await?;
-    let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx)?;
-    let plan = from_substrait_plan(&ctx, &proto).await?;
+    let proto = to_substrait_plan(&df.into_optimized_plan()?, &ctx.state())?;
+    let plan = from_substrait_plan(&ctx.state(), &proto).await?;
 
     println!("{plan_with_alias}");
     println!("{plan}");
@@ -1323,8 +1323,8 @@ async fn roundtrip_logical_plan_with_ctx(
     plan: LogicalPlan,
     ctx: SessionContext,
 ) -> Result<Box<Plan>> {
-    let proto = to_substrait_plan(&plan, &ctx)?;
-    let plan2 = from_substrait_plan(&ctx, &proto).await?;
+    let proto = to_substrait_plan(&plan, &ctx.state())?;
+    let plan2 = from_substrait_plan(&ctx.state(), &proto).await?;
     let plan2 = ctx.state().optimize(&plan2)?;
 
     println!("{plan}");
diff --git a/datafusion/substrait/tests/cases/serialize.rs 
b/datafusion/substrait/tests/cases/serialize.rs
index 54d55d1b6f..e28c633127 100644
--- a/datafusion/substrait/tests/cases/serialize.rs
+++ b/datafusion/substrait/tests/cases/serialize.rs
@@ -45,7 +45,7 @@ mod tests {
         // Read substrait plan from file
         let proto = serializer::deserialize(path).await?;
         // Check plan equality
-        let plan = from_substrait_plan(&ctx, &proto).await?;
+        let plan = from_substrait_plan(&ctx.state(), &proto).await?;
         let plan_str_ref = format!("{plan_ref}");
         let plan_str = format!("{plan}");
         assert_eq!(plan_str_ref, plan_str);
@@ -60,7 +60,7 @@ mod tests {
         let ctx = create_context().await?;
         let table = provider_as_source(ctx.table_provider("data").await?);
         let table_scan = LogicalPlanBuilder::scan("data", table, 
None)?.build()?;
-        let convert_result = to_substrait_plan(&table_scan, &ctx);
+        let convert_result = to_substrait_plan(&table_scan, &ctx.state());
         assert!(convert_result.is_ok());
 
         Ok(())
@@ -78,7 +78,9 @@ mod tests {
             \n  TableScan: data projection=[a, b]",
         );
 
-        let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone();
+        let plan = to_substrait_plan(&datafusion_plan, &ctx.state())?
+            .as_ref()
+            .clone();
 
         let relation = plan.relations.first().unwrap().rel_type.as_ref();
         let root_rel = match relation {
@@ -121,7 +123,9 @@ mod tests {
             \n    TableScan: data projection=[a, b, c]",
         );
 
-        let plan = to_substrait_plan(&datafusion_plan, &ctx)?.as_ref().clone();
+        let plan = to_substrait_plan(&datafusion_plan, &ctx.state())?
+            .as_ref()
+            .clone();
 
         let relation = plan.relations.first().unwrap().rel_type.as_ref();
         let root_rel = match relation {
diff --git a/datafusion/substrait/tests/cases/substrait_validations.rs 
b/datafusion/substrait/tests/cases/substrait_validations.rs
index 5ae586afe5..c77bf1489f 100644
--- a/datafusion/substrait/tests/cases/substrait_validations.rs
+++ b/datafusion/substrait/tests/cases/substrait_validations.rs
@@ -65,7 +65,7 @@ mod tests {
                 vec![("a", DataType::Int32, false), ("b", DataType::Int32, 
true)];
 
             let ctx = generate_context_with_table("DATA", df_schema)?;
-            let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+            let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
 
             assert_eq!(
                 format!("{}", plan),
@@ -86,7 +86,7 @@ mod tests {
                 ("c", DataType::Int32, false),
             ];
             let ctx = generate_context_with_table("DATA", df_schema)?;
-            let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+            let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
 
             assert_eq!(
                 format!("{}", plan),
@@ -109,7 +109,7 @@ mod tests {
                 ("b", DataType::Int32, false),
             ];
             let ctx = generate_context_with_table("DATA", df_schema)?;
-            let plan = from_substrait_plan(&ctx, &proto_plan).await?;
+            let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;
 
             assert_eq!(
                 format!("{}", plan),
@@ -128,7 +128,7 @@ mod tests {
                 vec![("a", DataType::Int32, false), ("c", DataType::Int32, 
true)];
 
             let ctx = generate_context_with_table("DATA", df_schema)?;
-            let res = from_substrait_plan(&ctx, &proto_plan).await;
+            let res = from_substrait_plan(&ctx.state(), &proto_plan).await;
             assert!(res.is_err());
             Ok(())
         }
@@ -140,7 +140,7 @@ mod tests {
 
             let ctx =
                 generate_context_with_table("DATA", vec![("a", 
DataType::Date32, true)])?;
-            let res = from_substrait_plan(&ctx, &proto_plan).await;
+            let res = from_substrait_plan(&ctx.state(), &proto_plan).await;
             assert!(res.is_err());
             Ok(())
         }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to