theirix commented on code in PR #17633: URL: https://github.com/apache/datafusion/pull/17633#discussion_r2373542298
########## datafusion-examples/examples/table_sample.rs: ########## @@ -0,0 +1,1353 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#![allow(unused_imports)] + +use datafusion::common::{ + arrow_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, + DFSchema, DFSchemaRef, ResolvedTableReference, Statistics, TableReference, +}; +use datafusion::error::Result; +use datafusion::logical_expr::sqlparser::ast::{ + Query, SetExpr, Statement, TableFactor, TableSample, TableSampleMethod, + TableSampleQuantity, TableSampleUnit, +}; +use datafusion::logical_expr::{ + AggregateUDF, Extension, Filter, LogicalPlan, LogicalPlanBuilder, Projection, + ScalarUDF, TableSource, UserDefinedLogicalNode, UserDefinedLogicalNodeCore, + WindowUDF, +}; +use std::any::Any; +use std::cmp::Ordering; +use std::collections::{BTreeMap, HashMap}; + +use arrow::util::pretty::{pretty_format_batches, pretty_format_batches_with_schema}; +use datafusion::common::tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRewriter, +}; + +use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef, TimeUnit}; +use async_trait::async_trait; +use datafusion::catalog::cte_worktable::CteWorkTable; +use datafusion::common::file_options::file_type::FileType; +use datafusion::config::ConfigOptions; +use datafusion::datasource::file_format::format_as_file_type; +use datafusion::datasource::{provider_as_source, DefaultTableSource, TableProvider}; +use datafusion::error::DataFusionError; +use datafusion::execution::{ + FunctionRegistry, SendableRecordBatchStream, SessionState, SessionStateBuilder, + TaskContext, +}; +use datafusion::logical_expr::planner::{ + ContextProvider, ExprPlanner, PlannerResult, RawBinaryExpr, TypePlanner, +}; +use datafusion::logical_expr::sqlparser::dialect::PostgreSqlDialect; +use datafusion::logical_expr::sqlparser::parser::Parser; +use datafusion::logical_expr::var_provider::is_system_variables; +use datafusion::optimizer::simplify_expressions::ExprSimplifier; +use datafusion::optimizer::AnalyzerRule; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::physical_plan::metrics::{ + BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, +}; +use datafusion::physical_plan::{ + displayable, DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, + RecordBatchStream, +}; +use datafusion::physical_planner::{ + DefaultPhysicalPlanner, ExtensionPlanner, PhysicalPlanner, +}; +use datafusion::prelude::*; +use datafusion::sql::planner::{ParserOptions, PlannerContext, SqlToRel}; +use datafusion::sql::sqlparser::ast::{TableSampleKind, TableSampleModifier}; +use datafusion::sql::unparser::ast::{ + DerivedRelationBuilder, QueryBuilder, RelationBuilder, SelectBuilder, +}; +use datafusion::sql::unparser::dialect::CustomDialectBuilder; +use datafusion::sql::unparser::expr_to_sql; +use datafusion::sql::unparser::extension_unparser::UserDefinedLogicalNodeUnparser; +use datafusion::sql::unparser::extension_unparser::{ + UnparseToStatementResult, UnparseWithinStatementResult, +}; +use datafusion::sql::unparser::{plan_to_sql, Unparser}; +use datafusion::variable::VarType; +use log::{debug, info}; +use std::fmt; +use std::fmt::{Debug, Formatter}; +use std::hash::{Hash, Hasher}; +use std::pin::Pin; +use std::str::FromStr; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use arrow::array::UInt32Array; +use arrow::compute; +use arrow::record_batch::RecordBatch; + +use datafusion::execution::context::QueryPlanner; +use datafusion::sql::sqlparser; +use datafusion::sql::sqlparser::ast; +use futures::stream::{Stream, StreamExt}; +use futures::TryStreamExt; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use rand_distr::{Distribution, Poisson}; + +/// This example demonstrates the table sample support. + +#[derive(Debug, Clone)] +struct TableSamplePlanNode { + inner_plan: LogicalPlan, + + lower_bound: f64, + upper_bound: f64, + with_replacement: bool, + seed: u64, +} + +impl Hash for TableSamplePlanNode { + fn hash<H: Hasher>(&self, state: &mut H) { + self.inner_plan.hash(state); + self.lower_bound.to_bits().hash(state); + self.upper_bound.to_bits().hash(state); + self.with_replacement.hash(state); + self.seed.hash(state); + } +} + +impl PartialEq for TableSamplePlanNode { + fn eq(&self, other: &Self) -> bool { + self.inner_plan == other.inner_plan + && (self.lower_bound - other.lower_bound).abs() < f64::EPSILON + && (self.upper_bound - other.upper_bound).abs() < f64::EPSILON + && self.with_replacement == other.with_replacement + && self.seed == other.seed + } +} + +impl Eq for TableSamplePlanNode {} + +impl PartialOrd for TableSamplePlanNode { + fn partial_cmp(&self, other: &Self) -> Option<Ordering> { + self.inner_plan + .partial_cmp(&other.inner_plan) + .and_then(|ord| { + if ord != Ordering::Equal { + Some(ord) + } else { + self.lower_bound + .partial_cmp(&other.lower_bound) + .and_then(|ord| { + if ord != Ordering::Equal { + Some(ord) + } else { + self.upper_bound.partial_cmp(&other.upper_bound).and_then( + |ord| { + if ord != Ordering::Equal { + Some(ord) + } else { + self.with_replacement + .partial_cmp(&other.with_replacement) + .and_then(|ord| { + if ord != Ordering::Equal { + Some(ord) + } else { + self.seed.partial_cmp(&other.seed) + } + }) + } + }, + ) + } + }) + } + }) + } +} + +impl UserDefinedLogicalNodeCore for TableSamplePlanNode { + fn name(&self) -> &str { + "TableSample" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.inner_plan] + } + + fn schema(&self) -> &DFSchemaRef { + self.inner_plan.schema() + } + + fn expressions(&self) -> Vec<Expr> { + vec![] + } + + fn fmt_for_explain(&self, f: &mut Formatter) -> fmt::Result { + f.write_fmt(format_args!( + "Sample: {:?} {:?} {:?}", + self.lower_bound, self.upper_bound, self.seed + )) + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec<Expr>, + inputs: Vec<LogicalPlan>, + ) -> Result<Self> { + let input = inputs + .first() + .ok_or(DataFusionError::Plan("Should have input".into()))?; + Ok(Self { + inner_plan: input.clone(), + lower_bound: self.lower_bound, + upper_bound: self.upper_bound, + with_replacement: self.with_replacement, + seed: self.seed, + }) + } +} + +/// Execution planner with `SampleExec` for `TableSamplePlanNode` +struct TableSampleExtensionPlanner {} + +impl TableSampleExtensionPlanner { + fn build_execution_plan( + &self, + specific_node: &TableSamplePlanNode, + physical_input: Arc<dyn ExecutionPlan>, + ) -> Result<Arc<dyn ExecutionPlan>> { + Ok(Arc::new(SampleExec { + input: physical_input.clone(), + lower_bound: 0.0, + upper_bound: specific_node.upper_bound, + with_replacement: specific_node.with_replacement, + seed: specific_node.seed, + metrics: Default::default(), + cache: SampleExec::compute_properties(&physical_input), + })) + } +} + +#[async_trait] +impl ExtensionPlanner for TableSampleExtensionPlanner { + /// Create a physical plan for an extension node + async fn plan_extension( + &self, + _planner: &dyn PhysicalPlanner, + node: &dyn UserDefinedLogicalNode, + logical_inputs: &[&LogicalPlan], + physical_inputs: &[Arc<dyn ExecutionPlan>], + _session_state: &SessionState, + ) -> Result<Option<Arc<dyn ExecutionPlan>>> { + if let Some(specific_node) = node.as_any().downcast_ref::<TableSamplePlanNode>() { + info!("Extension planner plan_extension: {:?}", &logical_inputs); + assert_eq!(logical_inputs.len(), 1, "Inconsistent number of inputs"); + assert_eq!(physical_inputs.len(), 1, "Inconsistent number of inputs"); + + let exec_plan = + self.build_execution_plan(specific_node, physical_inputs[0].clone())?; + Ok(Some(exec_plan)) + } else { + Ok(None) + } + } +} + +/// Query planner supporting a `TableSampleExtensionPlanner` +#[derive(Debug)] +struct TableSampleQueryPlanner {} + +#[async_trait] +impl QueryPlanner for TableSampleQueryPlanner { + /// Given a `LogicalPlan` created from above, create an + /// `ExecutionPlan` suitable for execution + async fn create_physical_plan( + &self, + logical_plan: &LogicalPlan, + session_state: &SessionState, + ) -> Result<Arc<dyn ExecutionPlan>> { + // Additional extension for table sample node + let physical_planner = + DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new( + TableSampleExtensionPlanner {}, + )]); + // Delegate most work of physical planning to the default physical planner + physical_planner + .create_physical_plan(logical_plan, session_state) + .await + } +} + +/// Physical plan implementation +trait Sampler: Send + Sync { + fn sample(&mut self, batch: &RecordBatch) -> Result<RecordBatch>; +} + +struct BernoulliSampler { + lower_bound: f64, + upper_bound: f64, + rng: StdRng, +} + +impl BernoulliSampler { + fn new(lower_bound: f64, upper_bound: f64, seed: u64) -> Self { + Self { + lower_bound, + upper_bound, + rng: StdRng::seed_from_u64(seed), + } + } +} + +impl Sampler for BernoulliSampler { + fn sample(&mut self, batch: &RecordBatch) -> Result<RecordBatch> { + if self.upper_bound <= self.lower_bound { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + let mut indices = Vec::new(); + + for i in 0..batch.num_rows() { + let rnd: f64 = self.rng.random(); + + if rnd >= self.lower_bound && rnd < self.upper_bound { + indices.push(i as u32); + } + } + + if indices.is_empty() { + return Ok(RecordBatch::new_empty(batch.schema())); + } + let indices = UInt32Array::from(indices); + compute::take_record_batch(batch, &indices).map_err(|e| e.into()) + } +} + +struct PoissonSampler { + ratio: f64, + poisson: Poisson<f64>, + rng: StdRng, +} + +impl PoissonSampler { + fn try_new(ratio: f64, seed: u64) -> Result<Self> { + let poisson = Poisson::new(ratio).map_err(|e| plan_datafusion_err!("{}", e))?; + Ok(Self { + ratio, + poisson, + rng: StdRng::seed_from_u64(seed), + }) + } +} + +impl Sampler for PoissonSampler { + fn sample(&mut self, batch: &RecordBatch) -> Result<RecordBatch> { + if self.ratio <= 0.0 { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + let mut indices = Vec::new(); + + for i in 0..batch.num_rows() { + let k = self.poisson.sample(&mut self.rng) as i32; + for _ in 0..k { + indices.push(i as u32); + } + } + + if indices.is_empty() { + return Ok(RecordBatch::new_empty(batch.schema())); + } + + let indices = UInt32Array::from(indices); + compute::take_record_batch(batch, &indices).map_err(|e| e.into()) + } +} + +/// SampleExec samples rows from its input based on a sampling method. +/// This is used to implement SQL `SAMPLE` clause. +#[derive(Debug, Clone)] +pub struct SampleExec { + /// The input plan + input: Arc<dyn ExecutionPlan>, + /// The lower bound of the sampling ratio + lower_bound: f64, + /// The upper bound of the sampling ratio + upper_bound: f64, + /// Whether to sample with replacement + with_replacement: bool, + /// Random seed for reproducible sampling + seed: u64, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, + /// Properties equivalence properties, partitioning, etc. + cache: PlanProperties, +} + +impl SampleExec { + /// Create a new SampleExec with a custom sampling method + pub fn try_new( + input: Arc<dyn ExecutionPlan>, + lower_bound: f64, + upper_bound: f64, + with_replacement: bool, + seed: u64, + ) -> Result<Self> { + if lower_bound < 0.0 || upper_bound > 1.0 || lower_bound > upper_bound { + return internal_err!( + "Sampling bounds must be between 0.0 and 1.0, and lower_bound <= upper_bound, got [{}, {}]", + lower_bound, upper_bound + ); + } + + let cache = Self::compute_properties(&input); + + Ok(Self { + input, + lower_bound, + upper_bound, + with_replacement, + seed, + metrics: ExecutionPlanMetricsSet::new(), + cache, + }) + } + + fn create_sampler(&self, partition: usize) -> Result<Box<dyn Sampler>> { + if self.with_replacement { + Ok(Box::new(PoissonSampler::try_new( + self.upper_bound - self.lower_bound, + self.seed + partition as u64, + )?)) + } else { + Ok(Box::new(BernoulliSampler::new( + self.lower_bound, + self.upper_bound, + self.seed + partition as u64, + ))) + } + } + + /// Whether to sample with replacement + pub fn with_replacement(&self) -> bool { + self.with_replacement + } + + /// The lower bound of the sampling ratio + pub fn lower_bound(&self) -> f64 { + self.lower_bound + } + + /// The upper bound of the sampling ratio + pub fn upper_bound(&self) -> f64 { + self.upper_bound + } + + /// The random seed + pub fn seed(&self) -> u64 { + self.seed + } + + /// The input plan + pub fn input(&self) -> &Arc<dyn ExecutionPlan> { + &self.input + } + + /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc. + fn compute_properties(input: &Arc<dyn ExecutionPlan>) -> PlanProperties { + input + .properties() + .clone() + .with_eq_properties(EquivalenceProperties::new(input.schema())) + } +} + +impl DisplayAs for SampleExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!( + f, + "SampleExec: lower_bound={}, upper_bound={}, with_replacement={}, seed={}", + self.lower_bound, self.upper_bound, self.with_replacement, self.seed + ) + } + DisplayFormatType::TreeRender => { + write!( + f, + "SampleExec: lower_bound={}, upper_bound={}, with_replacement={}, seed={}", + self.lower_bound, self.upper_bound, self.with_replacement, self.seed + ) + } + } + } +} + +impl ExecutionPlan for SampleExec { + fn name(&self) -> &'static str { + "SampleExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.cache + } + + fn maintains_input_order(&self) -> Vec<bool> { + vec![false] // Sampling does not maintain input order + } + + fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { + vec![&self.input] + } + + fn with_new_children( + self: Arc<Self>, + children: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + Ok(Arc::new(SampleExec::try_new( + Arc::clone(&children[0]), + self.lower_bound, + self.upper_bound, + self.with_replacement, + self.seed, + )?)) + } + + fn execute( + &self, + partition: usize, + context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream> { + let input_stream = self.input.execute(partition, context)?; + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + + Ok(Box::pin(SampleExecStream { + input: input_stream, + sampler: self.create_sampler(partition)?, + baseline_metrics, + })) + } + + fn metrics(&self) -> Option<MetricsSet> { + Some(self.metrics.clone_inner()) + } + + fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> { + let input_stats = self.input.partition_statistics(partition)?; + + // Apply sampling ratio to statistics + let mut stats = input_stats; + let ratio = self.upper_bound - self.lower_bound; + + stats.num_rows = stats + .num_rows + .map(|nr| (nr as f64 * ratio) as usize) + .to_inexact(); + stats.total_byte_size = stats + .total_byte_size + .map(|tb| (tb as f64 * ratio) as usize) + .to_inexact(); + + Ok(stats) + } +} + +/// Stream for the SampleExec operator +struct SampleExecStream { + /// The input stream + input: SendableRecordBatchStream, + /// The sampling method + sampler: Box<dyn Sampler>, + /// Runtime metrics recording + baseline_metrics: BaselineMetrics, +} + +impl Stream for SampleExecStream { + type Item = Result<RecordBatch>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Self::Item>> { + let poll = self.input.poll_next_unpin(cx); + let baseline_metrics = &mut self.baseline_metrics; + + match poll { + Poll::Ready(Some(Ok(batch))) => { + let start = baseline_metrics.elapsed_compute().clone(); + let result = self.sampler.sample(&batch); + let _timer = start.timer(); + Poll::Ready(Some(result)) + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } Review Comment: It's really useful, thank you. Switched to it and recorded additional metrics -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
