This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new c4fd7545ba Add catalog::resolve_table_references (#10876)
c4fd7545ba is described below

commit c4fd7545ba7719d6d12473694fcdf6f34d25b8cb
Author: Leonardo Yvens <[email protected]>
AuthorDate: Mon Jun 17 12:17:58 2024 +0100

    Add catalog::resolve_table_references (#10876)
    
    * resolve information_schema references only when necessary
    
    * add `catalog::resolve_table_references` as a public utility
    
    * collect CTEs separately in resolve_table_references
    
    * test CTE name shadowing
    
    * handle CTE name shadowing in resolve_table_references
    
    * handle unions, recursive and nested CTEs in resolve_table_references
---
 datafusion/core/src/catalog/mod.rs             | 239 ++++++++++++++++++++++++-
 datafusion/core/src/execution/session_state.rs |  96 ++--------
 datafusion/sqllogictest/test_files/cte.slt     |   7 +
 3 files changed, 256 insertions(+), 86 deletions(-)

diff --git a/datafusion/core/src/catalog/mod.rs b/datafusion/core/src/catalog/mod.rs
index 209d9b2af2..53b1333399 100644
--- a/datafusion/core/src/catalog/mod.rs
+++ b/datafusion/core/src/catalog/mod.rs
@@ -27,6 +27,8 @@ use crate::catalog::schema::SchemaProvider;
 use dashmap::DashMap;
 use datafusion_common::{exec_err, not_impl_err, Result};
 use std::any::Any;
+use std::collections::BTreeSet;
+use std::ops::ControlFlow;
 use std::sync::Arc;
 
 /// Represent a list of named [`CatalogProvider`]s.
@@ -157,11 +159,11 @@ impl CatalogProviderList for MemoryCatalogProviderList {
 /// access required to read table details (e.g. statistics).
 ///
 /// The pattern that DataFusion itself uses to plan SQL queries is to walk over
-/// the query to [find all schema / table references in an `async` function],
+/// the query to [find all table references],
 /// performing required remote catalog in parallel, and then plans the query
 /// using that snapshot.
 ///
-/// [find all schema / table references in an `async` function]: crate::execution::context::SessionState::resolve_table_references
+/// [find all table references]: resolve_table_references
 ///
 /// # Example Catalog Implementations
 ///
@@ -295,6 +297,182 @@ impl CatalogProvider for MemoryCatalogProvider {
     }
 }
 
+/// Collects all tables and views referenced in the SQL statement. CTEs are collected separately.
+/// This can be used to determine which tables need to be in the catalog for a query to be planned.
+///
+/// # Returns
+///
+/// A `(table_refs, ctes)` tuple, the first element contains table and view references and the second
+/// element contains any CTE aliases that were defined and possibly referenced.
+///
+/// ## Example
+///
+/// ```
+/// # use datafusion_sql::parser::DFParser;
+/// # use datafusion::catalog::resolve_table_references;
+/// let query = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
+/// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
+/// let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
+/// assert_eq!(table_refs.len(), 2);
+/// assert_eq!(table_refs[0].to_string(), "bar");
+/// assert_eq!(table_refs[1].to_string(), "foo");
+/// assert_eq!(ctes.len(), 0);
+/// ```
+///
+/// ## Example with CTEs  
+///  
+/// ```  
+/// # use datafusion_sql::parser::DFParser;  
+/// # use datafusion::catalog::resolve_table_references;  
+/// let query = "with my_cte as (values (1), (2)) SELECT * from my_cte;";  
+/// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();  
+/// let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();  
+/// assert_eq!(table_refs.len(), 0);
+/// assert_eq!(ctes.len(), 1);  
+/// assert_eq!(ctes[0].to_string(), "my_cte");  
+/// ```
+pub fn resolve_table_references(
+    statement: &datafusion_sql::parser::Statement,
+    enable_ident_normalization: bool,
+) -> datafusion_common::Result<(Vec<TableReference>, Vec<TableReference>)> {
+    use crate::sql::planner::object_name_to_table_reference;
+    use datafusion_sql::parser::{
+        CopyToSource, CopyToStatement, Statement as DFStatement,
+    };
+    use information_schema::INFORMATION_SCHEMA;
+    use information_schema::INFORMATION_SCHEMA_TABLES;
+    use sqlparser::ast::*;
+
+    struct RelationVisitor {
+        relations: BTreeSet<ObjectName>,
+        all_ctes: BTreeSet<ObjectName>,
+        ctes_in_scope: Vec<ObjectName>,
+    }
+
+    impl RelationVisitor {
+        /// Record the reference to `relation`, if it's not a CTE reference.
+        fn insert_relation(&mut self, relation: &ObjectName) {
+            if !self.relations.contains(relation)
+                && !self.ctes_in_scope.contains(relation)
+            {
+                self.relations.insert(relation.clone());
+            }
+        }
+    }
+
+    impl Visitor for RelationVisitor {
+        type Break = ();
+
+        fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> {
+            self.insert_relation(relation);
+            ControlFlow::Continue(())
+        }
+
+        fn pre_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
+            if let Some(with) = &q.with {
+                for cte in &with.cte_tables {
+                    // The non-recursive CTE name is not in scope when evaluating the CTE itself, so this is valid:
+                    // `WITH t AS (SELECT * FROM t) SELECT * FROM t`
+                    // Where the first `t` refers to a predefined table. So we are careful here
+                    // to visit the CTE first, before putting it in scope.
+                    if !with.recursive {
+                        // This is a bit hackish as the CTE will be visited again as part of visiting `q`,
+                        // but thankfully `insert_relation` is idempotent.
+                        cte.visit(self);
+                    }
+                    self.ctes_in_scope
+                        .push(ObjectName(vec![cte.alias.name.clone()]));
+                }
+            }
+            ControlFlow::Continue(())
+        }
+
+        fn post_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
+            if let Some(with) = &q.with {
+                for _ in &with.cte_tables {
+                    // Unwrap: We just pushed these in `pre_visit_query`
+                    self.all_ctes.insert(self.ctes_in_scope.pop().unwrap());
+                }
+            }
+            ControlFlow::Continue(())
+        }
+
+        fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> {
+            if let Statement::ShowCreate {
+                obj_type: ShowCreateObject::Table | ShowCreateObject::View,
+                obj_name,
+            } = statement
+            {
+                self.insert_relation(obj_name)
+            }
+
+            // SHOW statements will later be rewritten into a SELECT from the information_schema
+            let requires_information_schema = matches!(
+                statement,
+                Statement::ShowFunctions { .. }
+                    | Statement::ShowVariable { .. }
+                    | Statement::ShowStatus { .. }
+                    | Statement::ShowVariables { .. }
+                    | Statement::ShowCreate { .. }
+                    | Statement::ShowColumns { .. }
+                    | Statement::ShowTables { .. }
+                    | Statement::ShowCollation { .. }
+            );
+            if requires_information_schema {
+                for s in INFORMATION_SCHEMA_TABLES {
+                    self.relations.insert(ObjectName(vec![
+                        Ident::new(INFORMATION_SCHEMA),
+                        Ident::new(*s),
+                    ]));
+                }
+            }
+            ControlFlow::Continue(())
+        }
+    }
+
+    let mut visitor = RelationVisitor {
+        relations: BTreeSet::new(),
+        all_ctes: BTreeSet::new(),
+        ctes_in_scope: vec![],
+    };
+
+    fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) {
+        match statement {
+            DFStatement::Statement(s) => {
+                let _ = s.as_ref().visit(visitor);
+            }
+            DFStatement::CreateExternalTable(table) => {
+                visitor
+                    .relations
+                    .insert(ObjectName(vec![Ident::from(table.name.as_str())]));
+            }
+            DFStatement::CopyTo(CopyToStatement { source, .. }) => match source {
+                CopyToSource::Relation(table_name) => {
+                    visitor.insert_relation(table_name);
+                }
+                CopyToSource::Query(query) => {
+                    query.visit(visitor);
+                }
+            },
+            DFStatement::Explain(explain) => visit_statement(&explain.statement, visitor),
+        }
+    }
+
+    visit_statement(statement, &mut visitor);
+
+    let table_refs = visitor
+        .relations
+        .into_iter()
+        .map(|x| object_name_to_table_reference(x, enable_ident_normalization))
+        .collect::<datafusion_common::Result<_>>()?;
+    let ctes = visitor
+        .all_ctes
+        .into_iter()
+        .map(|x| object_name_to_table_reference(x, enable_ident_normalization))
+        .collect::<datafusion_common::Result<_>>()?;
+    Ok((table_refs, ctes))
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -363,4 +541,61 @@ mod tests {
         let cat = Arc::new(MemoryCatalogProvider::new()) as Arc<dyn CatalogProvider>;
         assert!(cat.deregister_schema("foo", false).unwrap().is_none());
     }
+
+    #[test]
+    fn resolve_table_references_shadowed_cte() {
+        use datafusion_sql::parser::DFParser;
+
+        // An interesting edge case where the `t` name is used both as an ordinary table reference
+        // and as a CTE reference.
+        let query = "WITH t AS (SELECT * FROM t) SELECT * FROM t";
+        let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
+        let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
+        assert_eq!(table_refs.len(), 1);
+        assert_eq!(ctes.len(), 1);
+        assert_eq!(ctes[0].to_string(), "t");
+        assert_eq!(table_refs[0].to_string(), "t");
+
+        // UNION is a special case where the CTE is not in scope for the second branch.
+        let query = "(with t as (select 1) select * from t) union (select * from t)";
+        let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
+        let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
+        assert_eq!(table_refs.len(), 1);
+        assert_eq!(ctes.len(), 1);
+        assert_eq!(ctes[0].to_string(), "t");
+        assert_eq!(table_refs[0].to_string(), "t");
+
+        // Nested CTEs are also handled.
+        // Here the first `u` is a CTE, but the second `u` is a table reference.
+        // While `t` is always a CTE.
+        let query = "(with t as (with u as (select 1) select * from u) select * from u cross join t)";
+        let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
+        let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
+        assert_eq!(table_refs.len(), 1);
+        assert_eq!(ctes.len(), 2);
+        assert_eq!(ctes[0].to_string(), "t");
+        assert_eq!(ctes[1].to_string(), "u");
+        assert_eq!(table_refs[0].to_string(), "u");
+    }
+
+    #[test]
+    fn resolve_table_references_recursive_cte() {
+        use datafusion_sql::parser::DFParser;
+
+        let query = "
+            WITH RECURSIVE nodes AS ( 
+                SELECT 1 as id
+                UNION ALL 
+                SELECT id + 1 as id 
+                FROM nodes
+                WHERE id < 10
+            )
+            SELECT * FROM nodes
+        ";
+        let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
+        let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
+        assert_eq!(table_refs.len(), 0);
+        assert_eq!(ctes.len(), 1);
+        assert_eq!(ctes[0].to_string(), "nodes");
+    }
 }
diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs
index fed101bd23..1df77a1f9e 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -66,15 +66,12 @@ use datafusion_optimizer::{
 use datafusion_physical_expr::create_physical_expr;
 use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
 use datafusion_physical_plan::ExecutionPlan;
-use datafusion_sql::parser::{CopyToSource, CopyToStatement, DFParser, Statement};
-use datafusion_sql::planner::{
-    object_name_to_table_reference, ContextProvider, ParserOptions, SqlToRel,
-};
+use datafusion_sql::parser::{DFParser, Statement};
+use datafusion_sql::planner::{ContextProvider, ParserOptions, SqlToRel};
 use sqlparser::dialect::dialect_from_str;
 use std::collections::hash_map::Entry;
 use std::collections::{HashMap, HashSet};
 use std::fmt::Debug;
-use std::ops::ControlFlow;
 use std::sync::Arc;
 use url::Url;
 use uuid::Uuid;
@@ -493,91 +490,22 @@ impl SessionState {
         Ok(statement)
     }
 
-    /// Resolve all table references in the SQL statement.
+    /// Resolve all table references in the SQL statement. Does not include CTE references.
+    ///
+    /// See [`catalog::resolve_table_references`] for more information.
+    ///
+    /// [`catalog::resolve_table_references`]: crate::catalog::resolve_table_references
     pub fn resolve_table_references(
         &self,
         statement: &datafusion_sql::parser::Statement,
     ) -> datafusion_common::Result<Vec<TableReference>> {
-        use crate::catalog::information_schema::INFORMATION_SCHEMA_TABLES;
-        use datafusion_sql::parser::Statement as DFStatement;
-        use sqlparser::ast::*;
-
-        // Getting `TableProviders` is async but planing is not -- thus pre-fetch
-        // table providers for all relations referenced in this query
-        let mut relations = hashbrown::HashSet::with_capacity(10);
-
-        struct RelationVisitor<'a>(&'a mut hashbrown::HashSet<ObjectName>);
-
-        impl<'a> RelationVisitor<'a> {
-            /// Record that `relation` was used in this statement
-            fn insert(&mut self, relation: &ObjectName) {
-                self.0.get_or_insert_with(relation, |_| relation.clone());
-            }
-        }
-
-        impl<'a> Visitor for RelationVisitor<'a> {
-            type Break = ();
-
-            fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> {
-                self.insert(relation);
-                ControlFlow::Continue(())
-            }
-
-            fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> {
-                if let Statement::ShowCreate {
-                    obj_type: ShowCreateObject::Table | ShowCreateObject::View,
-                    obj_name,
-                } = statement
-                {
-                    self.insert(obj_name)
-                }
-                ControlFlow::Continue(())
-            }
-        }
-
-        let mut visitor = RelationVisitor(&mut relations);
-        fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor<'_>) {
-            match statement {
-                DFStatement::Statement(s) => {
-                    let _ = s.as_ref().visit(visitor);
-                }
-                DFStatement::CreateExternalTable(table) => {
-                    visitor
-                        .0
-                        .insert(ObjectName(vec![Ident::from(table.name.as_str())]));
-                }
-                DFStatement::CopyTo(CopyToStatement { source, .. }) => match source {
-                    CopyToSource::Relation(table_name) => {
-                        visitor.insert(table_name);
-                    }
-                    CopyToSource::Query(query) => {
-                        query.visit(visitor);
-                    }
-                },
-                DFStatement::Explain(explain) => {
-                    visit_statement(&explain.statement, visitor)
-                }
-            }
-        }
-
-        visit_statement(statement, &mut visitor);
-
-        // Always include information_schema if available
-        if self.config.information_schema() {
-            for s in INFORMATION_SCHEMA_TABLES {
-                relations.insert(ObjectName(vec![
-                    Ident::new(INFORMATION_SCHEMA),
-                    Ident::new(*s),
-                ]));
-            }
-        }
-
         let enable_ident_normalization =
             self.config.options().sql_parser.enable_ident_normalization;
-        relations
-            .into_iter()
-            .map(|x| object_name_to_table_reference(x, enable_ident_normalization))
-            .collect::<datafusion_common::Result<_>>()
+        let (table_refs, _) = crate::catalog::resolve_table_references(
+            statement,
+            enable_ident_normalization,
+        )?;
+        Ok(table_refs)
     }
 
     /// Convert an AST Statement into a LogicalPlan
diff --git a/datafusion/sqllogictest/test_files/cte.slt b/datafusion/sqllogictest/test_files/cte.slt
index 1ff108cf6c..d8eaa51fc8 100644
--- a/datafusion/sqllogictest/test_files/cte.slt
+++ b/datafusion/sqllogictest/test_files/cte.slt
@@ -828,3 +828,10 @@ SELECT * FROM non_recursive_cte, recursive_cte;
 ----
 1 1
 1 3
+
+# Name shadowing:
+# The first `t` refers to the table, the second to the CTE.
+query I
+WITH t AS (SELECT * FROM t where t.a < 2) SELECT * FROM t
+----
+1
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to