This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 5cb1917f8c Add example for writing SQL analysis using DataFusion
structures (#10938)
5cb1917f8c is described below
commit 5cb1917f8c911184b030599f0a3cb86e2d602031
Author: Lorrens Pantelis <[email protected]>
AuthorDate: Wed Jun 19 01:46:54 2024 +0200
Add example for writing SQL analysis using DataFusion structures (#10938)
* sql analysis example
* update examples readme
* update comments
* Update datafusion-examples/examples/sql_analysis.rs
Co-authored-by: Andrew Lamb <[email protected]>
* apply feedback
* Run taplo to fix Cargo.toml formatting
* Tweak comments
---------
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion-examples/Cargo.toml | 1 +
datafusion-examples/README.md | 1 +
datafusion-examples/examples/sql_analysis.rs | 309 +++++++++++++++++++++++++++
3 files changed, 311 insertions(+)
diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml
index 0bcf7c1afc..c96aa7ae39 100644
--- a/datafusion-examples/Cargo.toml
+++ b/datafusion-examples/Cargo.toml
@@ -76,6 +76,7 @@ prost-derive = { version = "0.12", default-features = false }
serde = { version = "1.0.136", features = ["derive"] }
serde_json = { workspace = true }
tempfile = { workspace = true }
+test-utils = { path = "../test-utils" }
tokio = { workspace = true, features = ["rt-multi-thread", "parking_lot"] }
tonic = "0.11"
url = { workspace = true }
diff --git a/datafusion-examples/README.md b/datafusion-examples/README.md
index c34f706adb..6150c551c9 100644
--- a/datafusion-examples/README.md
+++ b/datafusion-examples/README.md
@@ -74,6 +74,7 @@ cargo run --example csv_sql
- [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User
Defined Aggregate Function (UDAF)
- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined
Scalar Function (UDF)
- [`simple_udfw.rs`](examples/simple_udwf.rs): Define and invoke a User
Defined Window Function (UDWF)
+- [`sql_analysis.rs`](examples/sql_analysis.rs): Analyse SQL queries with
DataFusion structures
- [`sql_dialect.rs`](examples/sql_dialect.rs): Example of implementing a
custom SQL dialect on top of `DFParser`
- [`to_char.rs`](examples/to_char.rs): Examples of using the to_char function
- [`to_timestamp.rs`](examples/to_timestamp.rs): Examples of using
to_timestamp functions
diff --git a/datafusion-examples/examples/sql_analysis.rs
b/datafusion-examples/examples/sql_analysis.rs
new file mode 100644
index 0000000000..3995988751
--- /dev/null
+++ b/datafusion-examples/examples/sql_analysis.rs
@@ -0,0 +1,309 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! This example shows how to use the structures that DataFusion provides to
perform
+//! analysis on SQL queries and their plans.
+//!
+//! As a motivating example, we show how to count the number of JOINs in a
query
+//! as well as how many join trees there are, each with its respective join count
+
+use std::sync::Arc;
+
+use datafusion::common::Result;
+use datafusion::{
+ datasource::MemTable,
+ execution::context::{SessionConfig, SessionContext},
+};
+use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
+use datafusion_expr::LogicalPlan;
+use test_utils::tpcds::tpcds_schemas;
+
+/// Counts the total number of joins in a plan
+fn total_join_count(plan: &LogicalPlan) -> usize {
+ let mut total = 0;
+
+ // We can use the TreeNode API to walk over a LogicalPlan.
+ plan.apply(|node| {
+ // if we encounter a join we update the running count
+ if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) {
+ total += 1;
+ }
+ Ok(TreeNodeRecursion::Continue)
+ })
+ .unwrap();
+
+ total
+}
+
+/// Counts the total number of joins in a plan and collects every join tree in
+/// the plan with their respective join count.
+///
+/// Join Tree Definition: the largest subtree consisting entirely of joins
+///
+/// For example, this plan:
+///
+/// ```text
+/// JOIN
+/// / \
+/// A JOIN
+/// / \
+/// B C
+/// ```
+///
+/// has a single join tree `(A-B-C)` which will result in `(2, [2])`
+///
+/// This plan:
+///
+/// ```text
+/// JOIN
+/// / \
+/// A GROUP
+/// |
+/// JOIN
+/// / \
+/// B C
+/// ```
+///
+/// Has two join trees `(A-GROUP, B-C)` which will result in `(2, [1, 1])`
+fn count_trees(plan: &LogicalPlan) -> (usize, Vec<usize>) {
+ // this works the same way as `total_join_count`, but now when we encounter a
Join
+ // we try to collect its entire tree
+ let mut to_visit = vec![plan];
+ let mut total = 0;
+ let mut groups = vec![];
+
+ while let Some(node) = to_visit.pop() {
+ // if we encounter a join, we know we're at the root of the tree
+ // count this tree and recurse on its inputs
+ if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) {
+ let (group_count, inputs) = count_tree(node);
+ total += group_count;
+ groups.push(group_count);
+ to_visit.extend(inputs);
+ } else {
+ to_visit.extend(node.inputs());
+ }
+ }
+
+ (total, groups)
+}
+
+/// Count the entire join tree and return its inputs using TreeNode API
+///
+/// For example, if this function receives the following plan:
+///
+/// ```text
+/// JOIN
+/// / \
+/// A GROUP
+/// |
+/// JOIN
+/// / \
+/// B C
+/// ```
+///
+/// It will return `(1, [A, GROUP])`
+fn count_tree(join: &LogicalPlan) -> (usize, Vec<&LogicalPlan>) {
+ let mut inputs = Vec::new();
+ let mut total = 0;
+
+ join.apply(|node| {
+ // Some extra knowledge:
+ //
+ // optimized plans have their projections pushed down as far as
+ // possible, which sometimes results in a projection going in between 2
+ // subsequent joins giving the illusion these joins are not "related",
+ // when in fact they are.
+ //
+ // This plan:
+ // JOIN
+ // / \
+ // A PROJECTION
+ // |
+ // JOIN
+ // / \
+ // B C
+ //
+ // is the same as:
+ //
+ // JOIN
+ // / \
+ // A JOIN
+ // / \
+ // B C
+ // we can continue the recursion in this case
+ if let LogicalPlan::Projection(_) = node {
+ return Ok(TreeNodeRecursion::Continue);
+ }
+
+ // any join we count
+ if matches!(node, LogicalPlan::Join(_) | LogicalPlan::CrossJoin(_)) {
+ total += 1;
+ Ok(TreeNodeRecursion::Continue)
+ } else {
+ inputs.push(node);
+ // skip children of input node
+ Ok(TreeNodeRecursion::Jump)
+ }
+ })
+ .unwrap();
+
+ (total, inputs)
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+ // To show how we can count the joins in a sql query we'll be using query
88
+ // from the TPC-DS benchmark.
+ //
+ // q88 has many joins, cross-joins and multiple join-trees, perfect for our
+ // example:
+
+ let tpcds_query_88 = "
+select *
+from
+ (select count(*) h8_30_to_9
+ from store_sales, household_demographics , time_dim, store
+ where ss_sold_time_sk = time_dim.t_time_sk
+ and ss_hdemo_sk = household_demographics.hd_demo_sk
+ and ss_store_sk = s_store_sk
+ and time_dim.t_hour = 8
+ and time_dim.t_minute >= 30
+ and ((household_demographics.hd_dep_count = 3 and
household_demographics.hd_vehicle_count<=3+2) or
+ (household_demographics.hd_dep_count = 0 and
household_demographics.hd_vehicle_count<=0+2) or
+ (household_demographics.hd_dep_count = 1 and
household_demographics.hd_vehicle_count<=1+2))
+ and store.s_store_name = 'ese') s1,
+ (select count(*) h9_to_9_30
+ from store_sales, household_demographics , time_dim, store
+ where ss_sold_time_sk = time_dim.t_time_sk
+ and ss_hdemo_sk = household_demographics.hd_demo_sk
+ and ss_store_sk = s_store_sk
+ and time_dim.t_hour = 9
+ and time_dim.t_minute < 30
+ and ((household_demographics.hd_dep_count = 3 and
household_demographics.hd_vehicle_count<=3+2) or
+ (household_demographics.hd_dep_count = 0 and
household_demographics.hd_vehicle_count<=0+2) or
+ (household_demographics.hd_dep_count = 1 and
household_demographics.hd_vehicle_count<=1+2))
+ and store.s_store_name = 'ese') s2,
+ (select count(*) h9_30_to_10
+ from store_sales, household_demographics , time_dim, store
+ where ss_sold_time_sk = time_dim.t_time_sk
+ and ss_hdemo_sk = household_demographics.hd_demo_sk
+ and ss_store_sk = s_store_sk
+ and time_dim.t_hour = 9
+ and time_dim.t_minute >= 30
+ and ((household_demographics.hd_dep_count = 3 and
household_demographics.hd_vehicle_count<=3+2) or
+ (household_demographics.hd_dep_count = 0 and
household_demographics.hd_vehicle_count<=0+2) or
+ (household_demographics.hd_dep_count = 1 and
household_demographics.hd_vehicle_count<=1+2))
+ and store.s_store_name = 'ese') s3,
+ (select count(*) h10_to_10_30
+ from store_sales, household_demographics , time_dim, store
+ where ss_sold_time_sk = time_dim.t_time_sk
+ and ss_hdemo_sk = household_demographics.hd_demo_sk
+ and ss_store_sk = s_store_sk
+ and time_dim.t_hour = 10
+ and time_dim.t_minute < 30
+ and ((household_demographics.hd_dep_count = 3 and
household_demographics.hd_vehicle_count<=3+2) or
+ (household_demographics.hd_dep_count = 0 and
household_demographics.hd_vehicle_count<=0+2) or
+ (household_demographics.hd_dep_count = 1 and
household_demographics.hd_vehicle_count<=1+2))
+ and store.s_store_name = 'ese') s4,
+ (select count(*) h10_30_to_11
+ from store_sales, household_demographics , time_dim, store
+ where ss_sold_time_sk = time_dim.t_time_sk
+ and ss_hdemo_sk = household_demographics.hd_demo_sk
+ and ss_store_sk = s_store_sk
+ and time_dim.t_hour = 10
+ and time_dim.t_minute >= 30
+ and ((household_demographics.hd_dep_count = 3 and
household_demographics.hd_vehicle_count<=3+2) or
+ (household_demographics.hd_dep_count = 0 and
household_demographics.hd_vehicle_count<=0+2) or
+ (household_demographics.hd_dep_count = 1 and
household_demographics.hd_vehicle_count<=1+2))
+ and store.s_store_name = 'ese') s5,
+ (select count(*) h11_to_11_30
+ from store_sales, household_demographics , time_dim, store
+ where ss_sold_time_sk = time_dim.t_time_sk
+ and ss_hdemo_sk = household_demographics.hd_demo_sk
+ and ss_store_sk = s_store_sk
+ and time_dim.t_hour = 11
+ and time_dim.t_minute < 30
+ and ((household_demographics.hd_dep_count = 3 and
household_demographics.hd_vehicle_count<=3+2) or
+ (household_demographics.hd_dep_count = 0 and
household_demographics.hd_vehicle_count<=0+2) or
+ (household_demographics.hd_dep_count = 1 and
household_demographics.hd_vehicle_count<=1+2))
+ and store.s_store_name = 'ese') s6,
+ (select count(*) h11_30_to_12
+ from store_sales, household_demographics , time_dim, store
+ where ss_sold_time_sk = time_dim.t_time_sk
+ and ss_hdemo_sk = household_demographics.hd_demo_sk
+ and ss_store_sk = s_store_sk
+ and time_dim.t_hour = 11
+ and time_dim.t_minute >= 30
+ and ((household_demographics.hd_dep_count = 3 and
household_demographics.hd_vehicle_count<=3+2) or
+ (household_demographics.hd_dep_count = 0 and
household_demographics.hd_vehicle_count<=0+2) or
+ (household_demographics.hd_dep_count = 1 and
household_demographics.hd_vehicle_count<=1+2))
+ and store.s_store_name = 'ese') s7,
+ (select count(*) h12_to_12_30
+ from store_sales, household_demographics , time_dim, store
+ where ss_sold_time_sk = time_dim.t_time_sk
+ and ss_hdemo_sk = household_demographics.hd_demo_sk
+ and ss_store_sk = s_store_sk
+ and time_dim.t_hour = 12
+ and time_dim.t_minute < 30
+ and ((household_demographics.hd_dep_count = 3 and
household_demographics.hd_vehicle_count<=3+2) or
+ (household_demographics.hd_dep_count = 0 and
household_demographics.hd_vehicle_count<=0+2) or
+ (household_demographics.hd_dep_count = 1 and
household_demographics.hd_vehicle_count<=1+2))
+ and store.s_store_name = 'ese') s8;";
+
+ // first set up the config
+ let config = SessionConfig::default();
+ let ctx = SessionContext::new_with_config(config);
+
+ // register the tables of the TPC-DS query
+ let tables = tpcds_schemas();
+ for table in tables {
+ ctx.register_table(
+ table.name,
+ Arc::new(MemTable::try_new(Arc::new(table.schema.clone()),
vec![])?),
+ )?;
+ }
+ // We can create a LogicalPlan from a SQL query like this
+ let logical_plan = ctx.sql(tpcds_query_88).await?.into_optimized_plan()?;
+
+ println!(
+ "Optimized Logical Plan:\n\n{}\n",
+ logical_plan.display_indent()
+ );
+ // we can get the total count (query 88 has 31 joins: 7 CROSS joins and 24
INNER joins => 40 input relations)
+ let total_join_count = total_join_count(&logical_plan);
+ assert_eq!(31, total_join_count);
+
+ println!("The plan has {total_join_count} joins.");
+
+ // Furthermore, the 24 inner joins form 8 groups of 3 joins, with the 7
+ // cross-joins combining them. We can get these groups using the
+ // `count_trees` function
+ let (total_join_count, trees) = count_trees(&logical_plan);
+ assert_eq!(
+ (total_join_count, &trees),
+ // query 88 is very straightforward, we know the cross-join group is at
+ // the top of the plan followed by the INNER joins
+ (31, &vec![7, 3, 3, 3, 3, 3, 3, 3, 3])
+ );
+
+ println!(
+ "And following join-trees (number represents join amount in tree):
{trees:?}"
+ );
+
+ Ok(())
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]