NGA-TRAN commented on code in PR #17843: URL: https://github.com/apache/datafusion/pull/17843#discussion_r2528925254
########## datafusion/core/tests/user_defined/relation_planner.rs: ########## @@ -0,0 +1,353 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::array::{Int64Array, RecordBatch, StringArray}; +use arrow::datatypes::{DataType, Field, Schema}; +use datafusion::catalog::memory::MemTable; +use datafusion::common::test_util::batches_to_string; +use datafusion::prelude::*; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::logical_plan::builder::LogicalPlanBuilder; +use datafusion_expr::planner::{ + PlannedRelation, RelationPlanner, RelationPlannerContext, RelationPlanning, +}; +use datafusion_expr::Expr; +use datafusion_sql::sqlparser::ast::TableFactor; + +/// A planner that creates an in-memory table with custom values +#[derive(Debug)] +struct CustomValuesPlanner; + +impl RelationPlanner for CustomValuesPlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result<RelationPlanning> { + match relation { + TableFactor::Table { name, alias, .. 
} + if name.to_string().eq_ignore_ascii_case("custom_values") => + { + let plan = LogicalPlanBuilder::values(vec![ + vec![Expr::Literal(ScalarValue::Int64(Some(1)), None)], + vec![Expr::Literal(ScalarValue::Int64(Some(2)), None)], + vec![Expr::Literal(ScalarValue::Int64(Some(3)), None)], + ])? + .build()?; + Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + } + other => Ok(RelationPlanning::Original(other)), + } + } +} + +/// A planner that handles string-based tables +#[derive(Debug)] +struct StringTablePlanner; + +impl RelationPlanner for StringTablePlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result<RelationPlanning> { + match relation { + TableFactor::Table { name, alias, .. } + if name.to_string().eq_ignore_ascii_case("colors") => + { + let plan = LogicalPlanBuilder::values(vec![ + vec![Expr::Literal(ScalarValue::Utf8(Some("red".into())), None)], + vec![Expr::Literal(ScalarValue::Utf8(Some("green".into())), None)], + vec![Expr::Literal(ScalarValue::Utf8(Some("blue".into())), None)], + ])? + .build()?; + Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + } + other => Ok(RelationPlanning::Original(other)), + } + } +} + +/// A planner that intercepts nested joins and plans them recursively +#[derive(Debug)] +struct RecursiveJoinPlanner; + +impl RelationPlanner for RecursiveJoinPlanner { + fn plan_relation( + &self, + relation: TableFactor, + context: &mut dyn RelationPlannerContext, + ) -> Result<RelationPlanning> { + match relation { + TableFactor::NestedJoin { + table_with_joins, + alias, + .. 
+ } if table_with_joins.joins.len() == 1 => { + // Recursively plan both sides using context.plan() + let left = context.plan(table_with_joins.relation.clone())?; + let right = context.plan(table_with_joins.joins[0].relation.clone())?; + + // Create a cross join + let plan = LogicalPlanBuilder::from(left).cross_join(right)?.build()?; + Ok(RelationPlanning::Planned(PlannedRelation::new(plan, alias))) + } + other => Ok(RelationPlanning::Original(other)), + } + } +} + +/// A planner that always returns None to test delegation +#[derive(Debug)] +struct PassThroughPlanner; + +impl RelationPlanner for PassThroughPlanner { + fn plan_relation( + &self, + relation: TableFactor, + _context: &mut dyn RelationPlannerContext, + ) -> Result<RelationPlanning> { + // Always return Original - delegates to next planner or default + Ok(RelationPlanning::Original(relation)) + } +} + +async fn collect_sql(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> { + ctx.sql(sql).await.unwrap().collect().await.unwrap() +} + +#[tokio::test] +async fn test_custom_planner_handles_relation() { + let ctx = SessionContext::new(); + ctx.register_relation_planner(Arc::new(CustomValuesPlanner)) + .unwrap(); + + let results = collect_sql(&ctx, "SELECT * FROM custom_values").await; + + let expected = "\ ++---------+ +| column1 | ++---------+ +| 1 | +| 2 | +| 3 | ++---------+"; + assert_eq!(batches_to_string(&results), expected); +} + +#[tokio::test] +async fn test_multiple_planners_first_wins() { + let ctx = SessionContext::new(); + + // Register multiple planners - first one wins + ctx.register_relation_planner(Arc::new(CustomValuesPlanner)) + .unwrap(); + ctx.register_relation_planner(Arc::new(StringTablePlanner)) + .unwrap(); + + // CustomValuesPlanner handles this + let results = collect_sql(&ctx, "SELECT * FROM custom_values").await; + let expected = "\ ++---------+ +| column1 | ++---------+ +| 1 | +| 2 | +| 3 | ++---------+"; + assert_eq!(batches_to_string(&results), expected); + + // 
StringTablePlanner handles this + let results = collect_sql(&ctx, "SELECT * FROM colors").await; + let expected = "\ ++---------+ +| column1 | ++---------+ +| red | +| green | +| blue | ++---------+"; + assert_eq!(batches_to_string(&results), expected); +} + +#[tokio::test] +async fn test_planner_delegates_to_default() { + let ctx = SessionContext::new(); + + // Register a planner that always returns None + ctx.register_relation_planner(Arc::new(PassThroughPlanner)) + .unwrap(); + + // Also register a real table + let schema = Arc::new(Schema::new(vec![Field::new( + "value", + DataType::Int64, + true, + )])); + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(Int64Array::from(vec![42]))]) + .unwrap(); + let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap(); + ctx.register_table("real_table", Arc::new(table)).unwrap(); + + // PassThroughPlanner returns None, so it delegates to the default planner + let results = collect_sql(&ctx, "SELECT * FROM real_table").await; Review Comment: I know this is just a test for your UDF, but I want to understand its purpose. So, because PassThroughPlanner is registered first, it is used first and then delegates to real_table? Would the behavior be the same if we registered real_table first, since it would not have to go through the pass-through planner? Can you add a test with only the pass-through planner registered, to make sure we get the right behavior? Does nothing happen, or do we get an error message? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
