This is an automated email from the ASF dual-hosted git repository. jiayuliu pushed a commit to branch move-udf-udaf-expr in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit 685ea46926f4b0c2812e026229c0f55036c36928 Author: Jiayu Liu <ji...@hey.com> AuthorDate: Sun Feb 6 13:29:20 2022 +0800 pyarrow --- datafusion-expr/Cargo.toml | 1 + datafusion-expr/src/expr.rs | 703 +++++++++++++++++++++++++++++ datafusion-expr/src/lib.rs | 34 ++ datafusion-expr/src/operator.rs | 72 +++ datafusion-expr/src/udaf.rs | 92 ++++ datafusion-expr/src/udf.rs | 93 ++++ datafusion/src/logical_plan/expr.rs | 684 +--------------------------- datafusion/src/logical_plan/mod.rs | 3 +- datafusion/src/logical_plan/operators.rs | 71 --- datafusion/src/physical_plan/aggregates.rs | 9 - datafusion/src/physical_plan/functions.rs | 17 +- datafusion/src/physical_plan/udaf.rs | 73 +-- datafusion/src/physical_plan/udf.rs | 69 +-- 13 files changed, 1009 insertions(+), 912 deletions(-) diff --git a/datafusion-expr/Cargo.toml b/datafusion-expr/Cargo.toml index 73a5fcd..a6dad52 100644 --- a/datafusion-expr/Cargo.toml +++ b/datafusion-expr/Cargo.toml @@ -38,3 +38,4 @@ path = "src/lib.rs" datafusion-common = { path = "../datafusion-common", version = "6.0.0" } arrow = { version = "8.0.0", features = ["prettyprint"] } sqlparser = "0.13" +ahash = { version = "0.7", default-features = false } diff --git a/datafusion-expr/src/expr.rs b/datafusion-expr/src/expr.rs new file mode 100644 index 0000000..c4c7197 --- /dev/null +++ b/datafusion-expr/src/expr.rs @@ -0,0 +1,703 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregate_function; +use crate::built_in_function; +use crate::window_frame; +use crate::window_function; +use crate::AggregateUDF; +use crate::Operator; +use crate::ScalarUDF; +use arrow::datatypes::DataType; +use datafusion_common::Column; +use datafusion_common::{DFSchema, Result}; +use datafusion_common::{DataFusionError, ScalarValue}; +use std::fmt; +use std::hash::{BuildHasher, Hash, Hasher}; +use std::ops::Not; +use std::sync::Arc; + +/// Return a new binary expression `l <op> r` +pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr { + Expr::BinaryExpr { + left: Box::new(l), + op, + right: Box::new(r), + } +} + +/// `Expr` is a central struct of DataFusion's query API, and +/// represents logical expressions such as `A + 1`, or `CAST(c1 AS +/// int)`. +/// +/// An `Expr` can compute its [DataType](arrow::datatypes::DataType) +/// and nullability, and has functions for building up complex +/// expressions.
+/// +/// # Examples +/// These examples use `col`/`lit` from the `datafusion` crate, which is not a dependency of this crate, so they are marked `ignore`. +/// ## Create an expression `c1` referring to column named "c1" +/// ```ignore +/// # use datafusion::logical_plan::*; +/// let expr = col("c1"); +/// assert_eq!(expr, Expr::Column(Column::from_name("c1"))); +/// ``` +/// +/// ## Create the expression `c1 + c2` to add columns "c1" and "c2" together +/// ```ignore +/// # use datafusion::logical_plan::*; +/// let expr = col("c1") + col("c2"); +/// +/// assert!(matches!(expr, Expr::BinaryExpr { ..} )); +/// if let Expr::BinaryExpr { left, right, op } = expr { +/// assert_eq!(*left, col("c1")); +/// assert_eq!(*right, col("c2")); +/// assert_eq!(op, Operator::Plus); +/// } +/// ``` +/// +/// ## Create expression `c1 = 42` to compare the value in column "c1" to the literal value `42` +/// ```ignore +/// # use datafusion::logical_plan::*; +/// # use datafusion::scalar::*; +/// let expr = col("c1").eq(lit(42)); +/// +/// assert!(matches!(expr, Expr::BinaryExpr { ..} )); +/// if let Expr::BinaryExpr { left, right, op } = expr { +/// assert_eq!(*left, col("c1")); +/// let scalar = ScalarValue::Int32(Some(42)); +/// assert_eq!(*right, Expr::Literal(scalar)); +/// assert_eq!(op, Operator::Eq); +/// } +/// ``` +#[derive(Clone, PartialEq, Hash)] +pub enum Expr { + /// An expression with a specific name. + Alias(Box<Expr>, String), + /// A named reference to a qualified field in a schema. + Column(Column), + /// A named reference to a variable in a registry. + ScalarVariable(Vec<String>), + /// A constant value. + Literal(ScalarValue), + /// A binary expression such as "age > 21" + BinaryExpr { + /// Left-hand side of the expression + left: Box<Expr>, + /// The binary operator + op: Operator, + /// Right-hand side of the expression + right: Box<Expr>, + }, + /// Negation of an expression. The expression's type must be a boolean to make sense. + Not(Box<Expr>), + /// Whether an expression is not Null. This expression is never null. + IsNotNull(Box<Expr>), + /// Whether an expression is Null.
This expression is never null. + IsNull(Box<Expr>), + /// arithmetic negation of an expression, the operand must be of a signed numeric data type + Negative(Box<Expr>), + /// Returns the field of a `ListArray` or `StructArray` by key + GetIndexedField { + /// the expression to take the field from + expr: Box<Expr>, + /// The name of the field to take + key: ScalarValue, + }, + /// Whether an expression is between a given range. + Between { + /// The value to compare + expr: Box<Expr>, + /// Whether the expression is negated + negated: bool, + /// The low end of the range + low: Box<Expr>, + /// The high end of the range + high: Box<Expr>, + }, + /// The CASE expression is similar to a series of nested if/else and there are two forms that + /// can be used. The first form consists of a series of boolean "when" expressions with + /// corresponding "then" expressions, and an optional "else" expression. + /// + /// CASE WHEN condition THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + /// + /// The second form uses a base expression and then a series of "when" clauses that match on a + /// literal value. + /// + /// CASE expression + /// WHEN value THEN result + /// [WHEN ...] + /// [ELSE result] + /// END + Case { + /// Optional base expression that can be compared to literal values in the "when" expressions + expr: Option<Box<Expr>>, + /// One or more when/then expressions + when_then_expr: Vec<(Box<Expr>, Box<Expr>)>, + /// Optional "else" expression + else_expr: Option<Box<Expr>>, + }, + /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. + /// This expression is guaranteed to have a fixed type. + Cast { + /// The expression being cast + expr: Box<Expr>, + /// The `DataType` the expression will yield + data_type: DataType, + }, + /// Casts the expression to a given type and will return a null value if the expression cannot be cast. + /// This expression is guaranteed to have a fixed type.
+ TryCast { + /// The expression being cast + expr: Box<Expr>, + /// The `DataType` the expression will yield + data_type: DataType, + }, + /// A sort expression that can be used to sort values. + Sort { + /// The expression to sort on + expr: Box<Expr>, + /// The direction of the sort: `true` for ascending, `false` for descending + asc: bool, + /// Whether to put Nulls before all other data values + nulls_first: bool, + }, + /// Represents the call of a built-in scalar function with a set of arguments. + ScalarFunction { + /// The function + fun: built_in_function::BuiltinScalarFunction, + /// List of expressions to feed to the functions as arguments + args: Vec<Expr>, + }, + /// Represents the call of a user-defined scalar function with arguments. + ScalarUDF { + /// The function + fun: Arc<ScalarUDF>, + /// List of expressions to feed to the functions as arguments + args: Vec<Expr>, + }, + /// Represents the call of an aggregate built-in function with arguments. + AggregateFunction { + /// Name of the function + fun: aggregate_function::AggregateFunction, + /// List of expressions to feed to the functions as arguments + args: Vec<Expr>, + /// Whether this is a DISTINCT aggregation or not + distinct: bool, + }, + /// Represents the call of a window function with arguments. + WindowFunction { + /// Name of the function + fun: window_function::WindowFunction, + /// List of expressions to feed to the functions as arguments + args: Vec<Expr>, + /// List of partition by expressions + partition_by: Vec<Expr>, + /// List of order by expressions + order_by: Vec<Expr>, + /// Window frame + window_frame: Option<window_frame::WindowFrame>, + }, + /// Represents the call of a user-defined aggregate function with arguments. + AggregateUDF { + /// The function + fun: Arc<AggregateUDF>, + /// List of expressions to feed to the functions as arguments + args: Vec<Expr>, + }, + /// Returns whether the list contains the expr value.
+ InList { + /// The expression to compare + expr: Box<Expr>, + /// A list of values to compare against + list: Vec<Expr>, + /// Whether the expression is negated + negated: bool, + }, + /// Represents a reference to all fields in a schema. + Wildcard, +} + +/// Fixed hash seed so the hash-based `PartialOrd` is deterministic; colliding exprs compare `Equal` +const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0); + +impl PartialOrd for Expr { + fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { + let mut hasher = SEED.build_hasher(); + self.hash(&mut hasher); + let s = hasher.finish(); + + let mut hasher = SEED.build_hasher(); + other.hash(&mut hasher); + let o = hasher.finish(); + + Some(s.cmp(&o)) + } +} + +impl Expr { + /// Returns the name of this expression based on [`DFSchema`]. + /// + /// This represents how a column with this expression is named when no alias is chosen + pub fn name(&self, input_schema: &DFSchema) -> Result<String> { + create_name(self, input_schema) + } + + /// Return `self == other` + pub fn eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::Eq, other) + } + + /// Return `self != other` + pub fn not_eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::NotEq, other) + } + + /// Return `self > other` + pub fn gt(self, other: Expr) -> Expr { + binary_expr(self, Operator::Gt, other) + } + + /// Return `self >= other` + pub fn gt_eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::GtEq, other) + } + + /// Return `self < other` + pub fn lt(self, other: Expr) -> Expr { + binary_expr(self, Operator::Lt, other) + } + + /// Return `self <= other` + pub fn lt_eq(self, other: Expr) -> Expr { + binary_expr(self, Operator::LtEq, other) + } + + /// Return `self && other` + pub fn and(self, other: Expr) -> Expr { + binary_expr(self, Operator::And, other) + } + + /// Return `self || other` + pub fn or(self, other: Expr) -> Expr { + binary_expr(self, Operator::Or, other) + } + + /// Return `!self` +
#[allow(clippy::should_implement_trait)] + pub fn not(self) -> Expr { + !self + } + + /// Calculate the modulus of two expressions. + /// Return `self % other` + pub fn modulus(self, other: Expr) -> Expr { + binary_expr(self, Operator::Modulo, other) + } + + /// Return `self LIKE other` + pub fn like(self, other: Expr) -> Expr { + binary_expr(self, Operator::Like, other) + } + + /// Return `self NOT LIKE other` + pub fn not_like(self, other: Expr) -> Expr { + binary_expr(self, Operator::NotLike, other) + } + + /// Return `self AS name` alias expression + pub fn alias(self, name: &str) -> Expr { + Expr::Alias(Box::new(self), name.to_owned()) + } + + /// Return `self IN <list>` if `negated` is false, otherwise + /// return `self NOT IN <list>`. + pub fn in_list(self, list: Vec<Expr>, negated: bool) -> Expr { + Expr::InList { + expr: Box::new(self), + list, + negated, + } + } + + /// Return `IsNull(Box(self))` + #[allow(clippy::wrong_self_convention)] + pub fn is_null(self) -> Expr { + Expr::IsNull(Box::new(self)) + } + + /// Return `IsNotNull(Box(self))` + #[allow(clippy::wrong_self_convention)] + pub fn is_not_null(self) -> Expr { + Expr::IsNotNull(Box::new(self)) + } + + /// Create a sort expression from an existing expression.
+ /// + /// ```ignore + /// # use datafusion::logical_plan::col; + /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST + /// ``` + pub fn sort(self, asc: bool, nulls_first: bool) -> Expr { + Expr::Sort { + expr: Box::new(self), + asc, + nulls_first, + } + } +} + +impl Not for Expr { + type Output = Self; + + fn not(self) -> Self::Output { + Expr::Not(Box::new(self)) + } +} + +impl std::fmt::Display for Expr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Expr::BinaryExpr { + ref left, + ref right, + ref op, + } => write!(f, "{} {} {}", left, op, right), + Expr::AggregateFunction { + // Name of the function + ref fun, + // List of expressions to feed to the functions as arguments + ref args, + // Whether this is a DISTINCT aggregation or not + ref distinct, + } => fmt_function(f, &fun.to_string(), *distinct, args, true), + Expr::ScalarFunction { + // Name of the function + ref fun, + // List of expressions to feed to the functions as arguments + ref args, + } => fmt_function(f, &fun.to_string(), false, args, true), + _ => write!(f, "{:?}", self), + } + } +} + +impl fmt::Debug for Expr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias), + Expr::Column(c) => write!(f, "{}", c), + Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")), + Expr::Literal(v) => write!(f, "{:?}", v), + Expr::Case { + expr, + when_then_expr, + else_expr, + ..
+ } => { + write!(f, "CASE ")?; + if let Some(e) = expr { + write!(f, "{:?} ", e)?; + } + for (w, t) in when_then_expr { + write!(f, "WHEN {:?} THEN {:?} ", w, t)?; + } + if let Some(e) = else_expr { + write!(f, "ELSE {:?} ", e)?; + } + write!(f, "END") + } + Expr::Cast { expr, data_type } => { + write!(f, "CAST({:?} AS {:?})", expr, data_type) + } + Expr::TryCast { expr, data_type } => { + write!(f, "TRY_CAST({:?} AS {:?})", expr, data_type) + } + Expr::Not(expr) => write!(f, "NOT {:?}", expr), + Expr::Negative(expr) => write!(f, "(- {:?})", expr), + Expr::IsNull(expr) => write!(f, "{:?} IS NULL", expr), + Expr::IsNotNull(expr) => write!(f, "{:?} IS NOT NULL", expr), + Expr::BinaryExpr { left, op, right } => { + write!(f, "{:?} {} {:?}", left, op, right) + } + Expr::Sort { + expr, + asc, + nulls_first, + } => { + if *asc { + write!(f, "{:?} ASC", expr)?; + } else { + write!(f, "{:?} DESC", expr)?; + } + if *nulls_first { + write!(f, " NULLS FIRST") + } else { + write!(f, " NULLS LAST") + } + } + Expr::ScalarFunction { fun, args, .. } => { + fmt_function(f, &fun.to_string(), false, args, false) + } + Expr::ScalarUDF { fun, ref args, .. } => { + fmt_function(f, &fun.name, false, args, false) + } + Expr::WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + } => { + fmt_function(f, &fun.to_string(), false, args, false)?; + if !partition_by.is_empty() { + write!(f, " PARTITION BY {:?}", partition_by)?; + } + if !order_by.is_empty() { + write!(f, " ORDER BY {:?}", order_by)?; + } + if let Some(window_frame) = window_frame { + write!( + f, + " {} BETWEEN {} AND {}", + window_frame.units, window_frame.start_bound, window_frame.end_bound + )?; + } + Ok(()) + } + Expr::AggregateFunction { + fun, + distinct, + ref args, + .. + } => fmt_function(f, &fun.to_string(), *distinct, args, true), + Expr::AggregateUDF { fun, ref args, ..
} => { + fmt_function(f, &fun.name, false, args, false) + } + Expr::Between { + expr, + negated, + low, + high, + } => { + if *negated { + write!(f, "{:?} NOT BETWEEN {:?} AND {:?}", expr, low, high) + } else { + write!(f, "{:?} BETWEEN {:?} AND {:?}", expr, low, high) + } + } + Expr::InList { + expr, + list, + negated, + } => { + if *negated { + write!(f, "{:?} NOT IN ({:?})", expr, list) + } else { + write!(f, "{:?} IN ({:?})", expr, list) + } + } + Expr::Wildcard => write!(f, "*"), + Expr::GetIndexedField { ref expr, key } => { + write!(f, "({:?})[{}]", expr, key) + } + } + } +} + +fn fmt_function( + f: &mut fmt::Formatter, + fun: &str, + distinct: bool, + args: &[Expr], + display: bool, +) -> fmt::Result { + let args: Vec<String> = match display { + true => args.iter().map(|arg| format!("{}", arg)).collect(), + false => args.iter().map(|arg| format!("{:?}", arg)).collect(), + }; + + // `distinct` adds a `DISTINCT ` prefix inside the parentheses, e.g. `COUNT(DISTINCT a)` + let distinct_str = match distinct { + true => "DISTINCT ", + false => "", + }; + write!(f, "{}({}{})", fun, distinct_str, args.join(", ")) +} + +fn create_function_name( + fun: &str, + distinct: bool, + args: &[Expr], + input_schema: &DFSchema, +) -> Result<String> { + let names: Vec<String> = args + .iter() + .map(|e| create_name(e, input_schema)) + .collect::<Result<_>>()?; + let distinct_str = match distinct { + true => "DISTINCT ", + false => "", + }; + Ok(format!("{}({}{})", fun, distinct_str, names.join(","))) +} + +/// Returns a readable name of an expression based on the input schema. +/// This function recursively traverses the expression for names such as "CAST(a AS Int64)".
+fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> { + match e { + Expr::Alias(_, name) => Ok(name.clone()), + Expr::Column(c) => Ok(c.flat_name()), + Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")), + Expr::Literal(value) => Ok(format!("{:?}", value)), + Expr::BinaryExpr { left, op, right } => { + let left = create_name(left, input_schema)?; + let right = create_name(right, input_schema)?; + Ok(format!("{} {} {}", left, op, right)) + } + Expr::Case { + expr, + when_then_expr, + else_expr, + } => { + let mut name = "CASE ".to_string(); + if let Some(e) = expr { + let e = create_name(e, input_schema)?; + name += &format!("{} ", e); + } + for (w, t) in when_then_expr { + let when = create_name(w, input_schema)?; + let then = create_name(t, input_schema)?; + name += &format!("WHEN {} THEN {} ", when, then); + } + if let Some(e) = else_expr { + let e = create_name(e, input_schema)?; + name += &format!("ELSE {} ", e); + } + name += "END"; + Ok(name) + } + Expr::Cast { expr, data_type } => { + let expr = create_name(expr, input_schema)?; + Ok(format!("CAST({} AS {:?})", expr, data_type)) + } + Expr::TryCast { expr, data_type } => { + let expr = create_name(expr, input_schema)?; + Ok(format!("TRY_CAST({} AS {:?})", expr, data_type)) + } + Expr::Not(expr) => { + let expr = create_name(expr, input_schema)?; + Ok(format!("NOT {}", expr)) + } + Expr::Negative(expr) => { + let expr = create_name(expr, input_schema)?; + Ok(format!("(- {})", expr)) + } + Expr::IsNull(expr) => { + let expr = create_name(expr, input_schema)?; + Ok(format!("{} IS NULL", expr)) + } + Expr::IsNotNull(expr) => { + let expr = create_name(expr, input_schema)?; + Ok(format!("{} IS NOT NULL", expr)) + } + Expr::GetIndexedField { expr, key } => { + let expr = create_name(expr, input_schema)?; + Ok(format!("{}[{}]", expr, key)) + } + Expr::ScalarFunction { fun, args, ..
} => { + create_function_name(&fun.to_string(), false, args, input_schema) + } + Expr::ScalarUDF { fun, args, .. } => { + create_function_name(&fun.name, false, args, input_schema) + } + Expr::WindowFunction { + fun, + args, + window_frame, + partition_by, + order_by, + } => { + let mut parts: Vec<String> = vec![create_function_name( + &fun.to_string(), + false, + args, + input_schema, + )?]; + if !partition_by.is_empty() { + parts.push(format!("PARTITION BY {:?}", partition_by)); + } + if !order_by.is_empty() { + parts.push(format!("ORDER BY {:?}", order_by)); + } + if let Some(window_frame) = window_frame { + parts.push(window_frame.to_string()); + } + Ok(parts.join(" ")) + } + Expr::AggregateFunction { + fun, + distinct, + args, + .. + } => create_function_name(&fun.to_string(), *distinct, args, input_schema), + Expr::AggregateUDF { fun, args } => { + create_function_name(&fun.name, false, args, input_schema) + } + Expr::InList { + expr, + list, + negated, + } => { + let expr = create_name(expr, input_schema)?; + let list = list // name each element, propagating naming errors + .iter() + .map(|expr| create_name(expr, input_schema)) + .collect::<Result<Vec<_>>>()? + .join(", "); + if *negated { + Ok(format!("{} NOT IN ({})", expr, list)) + } else { + Ok(format!("{} IN ({})", expr, list)) + } + } + Expr::Between { + expr, + negated, + low, + high, + } => { + let expr = create_name(expr, input_schema)?; + let low = create_name(low, input_schema)?; + let high = create_name(high, input_schema)?; + if *negated { + Ok(format!("{} NOT BETWEEN {} AND {}", expr, low, high)) + } else { + Ok(format!("{} BETWEEN {} AND {}", expr, low, high)) + } + } + Expr::Sort { ..
} => Err(DataFusionError::Internal( + "Create name does not support sort expression".to_string(), + )), + Expr::Wildcard => Err(DataFusionError::Internal( + "Create name does not support wildcard".to_string(), + )), + } +} diff --git a/datafusion-expr/src/lib.rs b/datafusion-expr/src/lib.rs index 2491fcf..fe68276 100644 --- a/datafusion-expr/src/lib.rs +++ b/datafusion-expr/src/lib.rs @@ -19,16 +19,50 @@ mod accumulator; mod aggregate_function; mod built_in_function; mod columnar_value; +mod expr; mod operator; mod signature; +mod udaf; +mod udf; mod window_frame; mod window_function; +use arrow::datatypes::DataType; +use datafusion_common::Result; +use std::sync::Arc; + +/// Scalar function +/// +/// The Fn param is the wrapped function but be aware that the function will +/// be passed with the slice / vec of columnar values (either scalar or array) +/// with the exception of zero param function, where a singular element vec +/// will be passed. In that case the single element is a null array to indicate +/// the batch's row count (so that the generative zero-argument function can know +/// the result array size). +pub type ScalarFunctionImplementation = + Arc<dyn Fn(&[ColumnarValue]) -> Result<ColumnarValue> + Send + Sync>; + +/// Computes a function's return type from its argument types +pub type ReturnTypeFunction = + Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>; + +/// Factory creating a new [`Accumulator`], i.e. the implementation of an aggregate function +pub type AccumulatorFunctionImplementation = + Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync>; + +/// This signature corresponds to which types an aggregator serializes +/// its state, given its return datatype.
+pub type StateTypeFunction = + Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>; + pub use accumulator::Accumulator; pub use aggregate_function::AggregateFunction; pub use built_in_function::BuiltinScalarFunction; pub use columnar_value::{ColumnarValue, NullColumnarValue}; +pub use expr::Expr; pub use operator::Operator; pub use signature::{Signature, TypeSignature, Volatility}; +pub use udaf::AggregateUDF; +pub use udf::ScalarUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; pub use window_function::{BuiltInWindowFunction, WindowFunction}; diff --git a/datafusion-expr/src/operator.rs b/datafusion-expr/src/operator.rs index e6b7e35..11a9a77 100644 --- a/datafusion-expr/src/operator.rs +++ b/datafusion-expr/src/operator.rs @@ -15,7 +15,10 @@ // specific language governing permissions and limitations // under the License. +use crate::expr::binary_expr; +use crate::Expr; use std::fmt; +use std::ops; /// Operators applied to expressions #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)] @@ -95,3 +98,78 @@ impl fmt::Display for Operator { write!(f, "{}", display) } } + +impl ops::Add for Expr { + type Output = Self; + + fn add(self, rhs: Self) -> Self { + binary_expr(self, Operator::Plus, rhs) + } +} + +impl ops::Sub for Expr { + type Output = Self; + + fn sub(self, rhs: Self) -> Self { + binary_expr(self, Operator::Minus, rhs) + } +} + +impl ops::Mul for Expr { + type Output = Self; + + fn mul(self, rhs: Self) -> Self { + binary_expr(self, Operator::Multiply, rhs) + } +} + +impl ops::Div for Expr { + type Output = Self; + + fn div(self, rhs: Self) -> Self { + binary_expr(self, Operator::Divide, rhs) + } +} + +impl ops::Rem for Expr { + type Output = Self; + + fn rem(self, rhs: Self) -> Self { + binary_expr(self, Operator::Modulo, rhs) + } +} + +#[cfg(test)] +mod tests { + use crate::Expr; + use datafusion_common::ScalarValue; + + /// `datafusion-expr` has no `lit` helper, so build `UInt32` literal expressions directly. + fn lit(value: u32) -> Expr { + Expr::Literal(ScalarValue::UInt32(Some(value))) + } + + #[test] + fn test_operators() { + assert_eq!( + format!("{:?}", lit(1u32) + lit(2u32)), + "UInt32(1) +
 UInt32(2)" + ); + assert_eq!( + format!("{:?}", lit(1u32) - lit(2u32)), + "UInt32(1) - UInt32(2)" + ); + assert_eq!( + format!("{:?}", lit(1u32) * lit(2u32)), + "UInt32(1) * UInt32(2)" + ); + assert_eq!( + format!("{:?}", lit(1u32) / lit(2u32)), + "UInt32(1) / UInt32(2)" + ); + assert_eq!( + format!("{:?}", lit(1u32) % lit(2u32)), + "UInt32(1) % UInt32(2)" + ); + } +} diff --git a/datafusion-expr/src/udaf.rs b/datafusion-expr/src/udaf.rs new file mode 100644 index 0000000..142cfe1 --- /dev/null +++ b/datafusion-expr/src/udaf.rs @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains functions and structs supporting user-defined aggregate functions. + +use crate::Expr; +use crate::{ + AccumulatorFunctionImplementation, ReturnTypeFunction, Signature, StateTypeFunction, +}; +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; + +/// Logical representation of a user-defined aggregate function (UDAF). +/// A UDAF is different from a UDF in that it is stateful across batches.
+#[derive(Clone)] +pub struct AggregateUDF { + /// Name of the function; together with `signature` it defines equality and hashing + pub name: String, + /// Signature of the function + pub signature: Signature, + /// Computes the return type from the argument types + pub return_type: ReturnTypeFunction, + /// Factory that creates the accumulator implementing the aggregation + pub accumulator: AccumulatorFunctionImplementation, + /// Describes the accumulator's state types as a function of the return type + pub state_type: StateTypeFunction, +} + +impl Debug for AggregateUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("AggregateUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"<FUNC>") // `Arc<dyn Fn>` is not `Debug`; print a placeholder + .finish() + } +} + +impl PartialEq for AggregateUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.signature == other.signature + } +} + +impl std::hash::Hash for AggregateUDF { + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + self.name.hash(state); + self.signature.hash(state); + } +} + +impl AggregateUDF { + /// Create a new AggregateUDF, cloning each borrowed argument + pub fn new( + name: &str, + signature: &Signature, + return_type: &ReturnTypeFunction, + accumulator: &AccumulatorFunctionImplementation, + state_type: &StateTypeFunction, + ) -> Self { + Self { + name: name.to_owned(), + signature: signature.clone(), + return_type: return_type.clone(), + accumulator: accumulator.clone(), + state_type: state_type.clone(), + } + } + + /// Creates a logical expression that calls this UDAF with `args`. + /// This utility allows using the UDAF without requiring access to the registry. + pub fn call(&self, args: Vec<Expr>) -> Expr { + Expr::AggregateUDF { + fun: Arc::new(self.clone()), + args, + } + } +} diff --git a/datafusion-expr/src/udf.rs b/datafusion-expr/src/udf.rs new file mode 100644 index 0000000..247d6a2 --- /dev/null +++ b/datafusion-expr/src/udf.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Logical representation of user-defined scalar functions (UDFs). + +use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +use std::fmt; +use std::fmt::Debug; +use std::fmt::Formatter; +use std::sync::Arc; + +/// Logical representation of a UDF. +#[derive(Clone)] +pub struct ScalarUDF { + /// Name of the function; together with `signature` it defines equality and hashing + pub name: String, + /// Signature of the function + pub signature: Signature, + /// Computes the return type from the argument types + pub return_type: ReturnTypeFunction, + /// The actual implementation + /// + /// The fn param is the wrapped function but be aware that the function will + /// be passed with the slice / vec of columnar values (either scalar or array) + /// with the exception of zero param function, where a singular element vec + /// will be passed. In that case the single element is a null array to indicate + /// the batch's row count (so that the generative zero-argument function can know + /// the result array size).
+ pub fun: ScalarFunctionImplementation, +} + +impl Debug for ScalarUDF { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("ScalarUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"<FUNC>") // `Arc<dyn Fn>` is not `Debug`; print a placeholder + .finish() + } +} + +impl PartialEq for ScalarUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.signature == other.signature + } +} + +impl std::hash::Hash for ScalarUDF { + fn hash<H: std::hash::Hasher>(&self, state: &mut H) { + self.name.hash(state); + self.signature.hash(state); + } +} + +impl ScalarUDF { + /// Create a new ScalarUDF, cloning each borrowed argument + pub fn new( + name: &str, + signature: &Signature, + return_type: &ReturnTypeFunction, + fun: &ScalarFunctionImplementation, + ) -> Self { + Self { + name: name.to_owned(), + signature: signature.clone(), + return_type: return_type.clone(), + fun: fun.clone(), + } + } + + /// Creates a logical expression that calls this UDF with `args`. + /// This utility allows using the UDF without requiring access to the registry.
+ pub fn call(&self, args: Vec<Expr>) -> Expr { + Expr::ScalarUDF { + fun: Arc::new(self.clone()), + args, + } + } +} diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index f19e9d8..cc269be 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -22,377 +22,21 @@ pub use super::Operator; use crate::error::{DataFusionError, Result}; use crate::logical_plan::ExprSchemable; use crate::logical_plan::{window_frames, DFField, DFSchema}; -use crate::physical_plan::functions::Volatility; -use crate::physical_plan::{aggregates, functions, udf::ScalarUDF, window_functions}; +use crate::physical_plan::{aggregates, functions, udf::ScalarUDF}; use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; -use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; use arrow::datatypes::DataType; pub use datafusion_common::{Column, ExprSchema}; -use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +use datafusion_expr::AccumulatorFunctionImplementation; +use datafusion_expr::StateTypeFunction; +use datafusion_expr::{ + ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility, +}; use std::collections::HashSet; use std::fmt; use std::hash::{BuildHasher, Hash, Hasher}; -use std::ops::Not; use std::sync::Arc; -/// `Expr` is a central struct of DataFusion's query API, and -/// represent logical expressions such as `A + 1`, or `CAST(c1 AS -/// int)`. -/// -/// An `Expr` can compute its [DataType](arrow::datatypes::DataType) -/// and nullability, and has functions for building up complex -/// expressions.
-/// -/// # Examples -/// -/// ## Create an expression `c1` referring to column named "c1" -/// ``` -/// # use datafusion::logical_plan::*; -/// let expr = col("c1"); -/// assert_eq!(expr, Expr::Column(Column::from_name("c1"))); -/// ``` -/// -/// ## Create the expression `c1 + c2` to add columns "c1" and "c2" together -/// ``` -/// # use datafusion::logical_plan::*; -/// let expr = col("c1") + col("c2"); -/// -/// assert!(matches!(expr, Expr::BinaryExpr { ..} )); -/// if let Expr::BinaryExpr { left, right, op } = expr { -/// assert_eq!(*left, col("c1")); -/// assert_eq!(*right, col("c2")); -/// assert_eq!(op, Operator::Plus); -/// } -/// ``` -/// -/// ## Create expression `c1 = 42` to compare the value in coumn "c1" to the literal value `42` -/// ``` -/// # use datafusion::logical_plan::*; -/// # use datafusion::scalar::*; -/// let expr = col("c1").eq(lit(42)); -/// -/// assert!(matches!(expr, Expr::BinaryExpr { ..} )); -/// if let Expr::BinaryExpr { left, right, op } = expr { -/// assert_eq!(*left, col("c1")); -/// let scalar = ScalarValue::Int32(Some(42)); -/// assert_eq!(*right, Expr::Literal(scalar)); -/// assert_eq!(op, Operator::Eq); -/// } -/// ``` -#[derive(Clone, PartialEq, Hash)] -pub enum Expr { - /// An expression with a specific name. - Alias(Box<Expr>, String), - /// A named reference to a qualified filed in a schema. - Column(Column), - /// A named reference to a variable in a registry. - ScalarVariable(Vec<String>), - /// A constant value. - Literal(ScalarValue), - /// A binary expression such as "age > 21" - BinaryExpr { - /// Left-hand side of the expression - left: Box<Expr>, - /// The comparison operator - op: Operator, - /// Right-hand side of the expression - right: Box<Expr>, - }, - /// Negation of an expression. The expression's type must be a boolean to make sense. - Not(Box<Expr>), - /// Whether an expression is not Null. This expression is never null. - IsNotNull(Box<Expr>), - /// Whether an expression is Null. 
This expression is never null. - IsNull(Box<Expr>), - /// arithmetic negation of an expression, the operand must be of a signed numeric data type - Negative(Box<Expr>), - /// Returns the field of a [`ListArray`] or [`StructArray`] by key - GetIndexedField { - /// the expression to take the field from - expr: Box<Expr>, - /// The name of the field to take - key: ScalarValue, - }, - /// Whether an expression is between a given range. - Between { - /// The value to compare - expr: Box<Expr>, - /// Whether the expression is negated - negated: bool, - /// The low end of the range - low: Box<Expr>, - /// The high end of the range - high: Box<Expr>, - }, - /// The CASE expression is similar to a series of nested if/else and there are two forms that - /// can be used. The first form consists of a series of boolean "when" expressions with - /// corresponding "then" expressions, and an optional "else" expression. - /// - /// CASE WHEN condition THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - /// - /// The second form uses a base expression and then a series of "when" clauses that match on a - /// literal value. - /// - /// CASE expression - /// WHEN value THEN result - /// [WHEN ...] - /// [ELSE result] - /// END - Case { - /// Optional base expression that can be compared to literal values in the "when" expressions - expr: Option<Box<Expr>>, - /// One or more when/then expressions - when_then_expr: Vec<(Box<Expr>, Box<Expr>)>, - /// Optional "else" expression - else_expr: Option<Box<Expr>>, - }, - /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast. - /// This expression is guaranteed to have a fixed type. - Cast { - /// The expression being cast - expr: Box<Expr>, - /// The `DataType` the expression will yield - data_type: DataType, - }, - /// Casts the expression to a given type and will return a null value if the expression cannot be cast. - /// This expression is guaranteed to have a fixed type. 
- TryCast { - /// The expression being cast - expr: Box<Expr>, - /// The `DataType` the expression will yield - data_type: DataType, - }, - /// A sort expression, that can be used to sort values. - Sort { - /// The expression to sort on - expr: Box<Expr>, - /// The direction of the sort - asc: bool, - /// Whether to put Nulls before all other data values - nulls_first: bool, - }, - /// Represents the call of a built-in scalar function with a set of arguments. - ScalarFunction { - /// The function - fun: functions::BuiltinScalarFunction, - /// List of expressions to feed to the functions as arguments - args: Vec<Expr>, - }, - /// Represents the call of a user-defined scalar function with arguments. - ScalarUDF { - /// The function - fun: Arc<ScalarUDF>, - /// List of expressions to feed to the functions as arguments - args: Vec<Expr>, - }, - /// Represents the call of an aggregate built-in function with arguments. - AggregateFunction { - /// Name of the function - fun: aggregates::AggregateFunction, - /// List of expressions to feed to the functions as arguments - args: Vec<Expr>, - /// Whether this is a DISTINCT aggregation or not - distinct: bool, - }, - /// Represents the call of a window function with arguments. - WindowFunction { - /// Name of the function - fun: window_functions::WindowFunction, - /// List of expressions to feed to the functions as arguments - args: Vec<Expr>, - /// List of partition by expressions - partition_by: Vec<Expr>, - /// List of order by expressions - order_by: Vec<Expr>, - /// Window frame - window_frame: Option<window_frames::WindowFrame>, - }, - /// aggregate function - AggregateUDF { - /// The function - fun: Arc<AggregateUDF>, - /// List of expressions to feed to the functions as arguments - args: Vec<Expr>, - }, - /// Returns whether the list contains the expr value. 
- InList { - /// The expression to compare - expr: Box<Expr>, - /// A list of values to compare against - list: Vec<Expr>, - /// Whether the expression is negated - negated: bool, - }, - /// Represents a reference to all fields in a schema. - Wildcard, -} - -/// Fixed seed for the hashing so that Ords are consistent across runs -const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0); - -impl PartialOrd for Expr { - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - let mut hasher = SEED.build_hasher(); - self.hash(&mut hasher); - let s = hasher.finish(); - - let mut hasher = SEED.build_hasher(); - other.hash(&mut hasher); - let o = hasher.finish(); - - Some(s.cmp(&o)) - } -} - -impl Expr { - /// Returns the name of this expression based on [crate::logical_plan::DFSchema]. - /// - /// This represents how a column with this expression is named when no alias is chosen - pub fn name(&self, input_schema: &DFSchema) -> Result<String> { - create_name(self, input_schema) - } - - /// Return `self == other` - pub fn eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::Eq, other) - } - - /// Return `self != other` - pub fn not_eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::NotEq, other) - } - - /// Return `self > other` - pub fn gt(self, other: Expr) -> Expr { - binary_expr(self, Operator::Gt, other) - } - - /// Return `self >= other` - pub fn gt_eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::GtEq, other) - } - - /// Return `self < other` - pub fn lt(self, other: Expr) -> Expr { - binary_expr(self, Operator::Lt, other) - } - - /// Return `self <= other` - pub fn lt_eq(self, other: Expr) -> Expr { - binary_expr(self, Operator::LtEq, other) - } - - /// Return `self && other` - pub fn and(self, other: Expr) -> Expr { - binary_expr(self, Operator::And, other) - } - - /// Return `self || other` - pub fn or(self, other: Expr) -> Expr { - binary_expr(self, Operator::Or, other) - } - - /// Return `!self` - 
#[allow(clippy::should_implement_trait)] - pub fn not(self) -> Expr { - !self - } - - /// Calculate the modulus of two expressions. - /// Return `self % other` - pub fn modulus(self, other: Expr) -> Expr { - binary_expr(self, Operator::Modulo, other) - } - - /// Return `self LIKE other` - pub fn like(self, other: Expr) -> Expr { - binary_expr(self, Operator::Like, other) - } - - /// Return `self NOT LIKE other` - pub fn not_like(self, other: Expr) -> Expr { - binary_expr(self, Operator::NotLike, other) - } - - /// Return `self AS name` alias expression - pub fn alias(self, name: &str) -> Expr { - Expr::Alias(Box::new(self), name.to_owned()) - } - - /// Return `self IN <list>` if `negated` is false, otherwise - /// return `self NOT IN <list>`.a - pub fn in_list(self, list: Vec<Expr>, negated: bool) -> Expr { - Expr::InList { - expr: Box::new(self), - list, - negated, - } - } - - /// Return `IsNull(Box(self)) - #[allow(clippy::wrong_self_convention)] - pub fn is_null(self) -> Expr { - Expr::IsNull(Box::new(self)) - } - - /// Return `IsNotNull(Box(self)) - #[allow(clippy::wrong_self_convention)] - pub fn is_not_null(self) -> Expr { - Expr::IsNotNull(Box::new(self)) - } - - /// Create a sort expression from an existing expression. 
- /// - /// ``` - /// # use datafusion::logical_plan::col; - /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST - /// ``` - pub fn sort(self, asc: bool, nulls_first: bool) -> Expr { - Expr::Sort { - expr: Box::new(self), - asc, - nulls_first, - } - } -} - -impl Not for Expr { - type Output = Self; - - fn not(self) -> Self::Output { - Expr::Not(Box::new(self)) - } -} - -impl std::fmt::Display for Expr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - Expr::BinaryExpr { - ref left, - ref right, - ref op, - } => write!(f, "{} {} {}", left, op, right), - Expr::AggregateFunction { - /// Name of the function - ref fun, - /// List of expressions to feed to the functions as arguments - ref args, - /// Whether this is a DISTINCT aggregation or not - ref distinct, - } => fmt_function(f, &fun.to_string(), *distinct, args, true), - Expr::ScalarFunction { - /// Name of the function - ref fun, - /// List of expressions to feed to the functions as arguments - ref args, - } => fmt_function(f, &fun.to_string(), false, args, true), - _ => write!(f, "{:?}", self), - } - } -} +pub use datafusion_expr::Expr; /// Helper struct for building [Expr::Case] pub struct CaseBuilder { @@ -484,15 +128,6 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder { } } -/// return a new expression l <op> r -pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr { - Expr::BinaryExpr { - left: Box::new(l), - op, - right: Box::new(r), - } -} - /// return a new expression with a logical AND pub fn and(left: Expr, right: Expr) -> Expr { Expr::BinaryExpr { @@ -934,311 +569,6 @@ pub fn create_udaf( ) } -fn fmt_function( - f: &mut fmt::Formatter, - fun: &str, - distinct: bool, - args: &[Expr], - display: bool, -) -> fmt::Result { - let args: Vec<String> = match display { - true => args.iter().map(|arg| format!("{}", arg)).collect(), - false => args.iter().map(|arg| format!("{:?}", arg)).collect(), - }; - - // let args: Vec<String> = 
args.iter().map(|arg| format!("{:?}", arg)).collect(); - let distinct_str = match distinct { - true => "DISTINCT ", - false => "", - }; - write!(f, "{}({}{})", fun, distinct_str, args.join(", ")) -} - -impl fmt::Debug for Expr { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias), - Expr::Column(c) => write!(f, "{}", c), - Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")), - Expr::Literal(v) => write!(f, "{:?}", v), - Expr::Case { - expr, - when_then_expr, - else_expr, - .. - } => { - write!(f, "CASE ")?; - if let Some(e) = expr { - write!(f, "{:?} ", e)?; - } - for (w, t) in when_then_expr { - write!(f, "WHEN {:?} THEN {:?} ", w, t)?; - } - if let Some(e) = else_expr { - write!(f, "ELSE {:?} ", e)?; - } - write!(f, "END") - } - Expr::Cast { expr, data_type } => { - write!(f, "CAST({:?} AS {:?})", expr, data_type) - } - Expr::TryCast { expr, data_type } => { - write!(f, "TRY_CAST({:?} AS {:?})", expr, data_type) - } - Expr::Not(expr) => write!(f, "NOT {:?}", expr), - Expr::Negative(expr) => write!(f, "(- {:?})", expr), - Expr::IsNull(expr) => write!(f, "{:?} IS NULL", expr), - Expr::IsNotNull(expr) => write!(f, "{:?} IS NOT NULL", expr), - Expr::BinaryExpr { left, op, right } => { - write!(f, "{:?} {} {:?}", left, op, right) - } - Expr::Sort { - expr, - asc, - nulls_first, - } => { - if *asc { - write!(f, "{:?} ASC", expr)?; - } else { - write!(f, "{:?} DESC", expr)?; - } - if *nulls_first { - write!(f, " NULLS FIRST") - } else { - write!(f, " NULLS LAST") - } - } - Expr::ScalarFunction { fun, args, .. } => { - fmt_function(f, &fun.to_string(), false, args, false) - } - Expr::ScalarUDF { fun, ref args, .. 
} => { - fmt_function(f, &fun.name, false, args, false) - } - Expr::WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - } => { - fmt_function(f, &fun.to_string(), false, args, false)?; - if !partition_by.is_empty() { - write!(f, " PARTITION BY {:?}", partition_by)?; - } - if !order_by.is_empty() { - write!(f, " ORDER BY {:?}", order_by)?; - } - if let Some(window_frame) = window_frame { - write!( - f, - " {} BETWEEN {} AND {}", - window_frame.units, - window_frame.start_bound, - window_frame.end_bound - )?; - } - Ok(()) - } - Expr::AggregateFunction { - fun, - distinct, - ref args, - .. - } => fmt_function(f, &fun.to_string(), *distinct, args, true), - Expr::AggregateUDF { fun, ref args, .. } => { - fmt_function(f, &fun.name, false, args, false) - } - Expr::Between { - expr, - negated, - low, - high, - } => { - if *negated { - write!(f, "{:?} NOT BETWEEN {:?} AND {:?}", expr, low, high) - } else { - write!(f, "{:?} BETWEEN {:?} AND {:?}", expr, low, high) - } - } - Expr::InList { - expr, - list, - negated, - } => { - if *negated { - write!(f, "{:?} NOT IN ({:?})", expr, list) - } else { - write!(f, "{:?} IN ({:?})", expr, list) - } - } - Expr::Wildcard => write!(f, "*"), - Expr::GetIndexedField { ref expr, key } => { - write!(f, "({:?})[{}]", expr, key) - } - } - } -} - -fn create_function_name( - fun: &str, - distinct: bool, - args: &[Expr], - input_schema: &DFSchema, -) -> Result<String> { - let names: Vec<String> = args - .iter() - .map(|e| create_name(e, input_schema)) - .collect::<Result<_>>()?; - let distinct_str = match distinct { - true => "DISTINCT ", - false => "", - }; - Ok(format!("{}({}{})", fun, distinct_str, names.join(","))) -} - -/// Returns a readable name of an expression based on the input schema. -/// This function recursively transverses the expression for names such as "CAST(a > 2)". 
-fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> { - match e { - Expr::Alias(_, name) => Ok(name.clone()), - Expr::Column(c) => Ok(c.flat_name()), - Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")), - Expr::Literal(value) => Ok(format!("{:?}", value)), - Expr::BinaryExpr { left, op, right } => { - let left = create_name(left, input_schema)?; - let right = create_name(right, input_schema)?; - Ok(format!("{} {} {}", left, op, right)) - } - Expr::Case { - expr, - when_then_expr, - else_expr, - } => { - let mut name = "CASE ".to_string(); - if let Some(e) = expr { - let e = create_name(e, input_schema)?; - name += &format!("{} ", e); - } - for (w, t) in when_then_expr { - let when = create_name(w, input_schema)?; - let then = create_name(t, input_schema)?; - name += &format!("WHEN {} THEN {} ", when, then); - } - if let Some(e) = else_expr { - let e = create_name(e, input_schema)?; - name += &format!("ELSE {} ", e); - } - name += "END"; - Ok(name) - } - Expr::Cast { expr, data_type } => { - let expr = create_name(expr, input_schema)?; - Ok(format!("CAST({} AS {:?})", expr, data_type)) - } - Expr::TryCast { expr, data_type } => { - let expr = create_name(expr, input_schema)?; - Ok(format!("TRY_CAST({} AS {:?})", expr, data_type)) - } - Expr::Not(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("NOT {}", expr)) - } - Expr::Negative(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("(- {})", expr)) - } - Expr::IsNull(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("{} IS NULL", expr)) - } - Expr::IsNotNull(expr) => { - let expr = create_name(expr, input_schema)?; - Ok(format!("{} IS NOT NULL", expr)) - } - Expr::GetIndexedField { expr, key } => { - let expr = create_name(expr, input_schema)?; - Ok(format!("{}[{}]", expr, key)) - } - Expr::ScalarFunction { fun, args, .. 
} => { - create_function_name(&fun.to_string(), false, args, input_schema) - } - Expr::ScalarUDF { fun, args, .. } => { - create_function_name(&fun.name, false, args, input_schema) - } - Expr::WindowFunction { - fun, - args, - window_frame, - partition_by, - order_by, - } => { - let mut parts: Vec<String> = vec![create_function_name( - &fun.to_string(), - false, - args, - input_schema, - )?]; - if !partition_by.is_empty() { - parts.push(format!("PARTITION BY {:?}", partition_by)); - } - if !order_by.is_empty() { - parts.push(format!("ORDER BY {:?}", order_by)); - } - if let Some(window_frame) = window_frame { - parts.push(format!("{}", window_frame)); - } - Ok(parts.join(" ")) - } - Expr::AggregateFunction { - fun, - distinct, - args, - .. - } => create_function_name(&fun.to_string(), *distinct, args, input_schema), - Expr::AggregateUDF { fun, args } => { - let mut names = Vec::with_capacity(args.len()); - for e in args { - names.push(create_name(e, input_schema)?); - } - Ok(format!("{}({})", fun.name, names.join(","))) - } - Expr::InList { - expr, - list, - negated, - } => { - let expr = create_name(expr, input_schema)?; - let list = list.iter().map(|expr| create_name(expr, input_schema)); - if *negated { - Ok(format!("{} NOT IN ({:?})", expr, list)) - } else { - Ok(format!("{} IN ({:?})", expr, list)) - } - } - Expr::Between { - expr, - negated, - low, - high, - } => { - let expr = create_name(expr, input_schema)?; - let low = create_name(low, input_schema)?; - let high = create_name(high, input_schema)?; - if *negated { - Ok(format!("{} NOT BETWEEN {} AND {}", expr, low, high)) - } else { - Ok(format!("{} BETWEEN {} AND {}", expr, low, high)) - } - } - Expr::Sort { .. 
} => Err(DataFusionError::Internal( - "Create name does not support sort expression".to_string(), - )), - Expr::Wildcard => Err(DataFusionError::Internal( - "Create name does not support wildcard".to_string(), - )), - } -} - /// Create field meta-data from an expression, for use in a result set schema pub fn exprlist_to_fields<'a>( expr: impl IntoIterator<Item = &'a Expr>, diff --git a/datafusion/src/logical_plan/mod.rs b/datafusion/src/logical_plan/mod.rs index f2ecb0f..57e34c8 100644 --- a/datafusion/src/logical_plan/mod.rs +++ b/datafusion/src/logical_plan/mod.rs @@ -37,11 +37,12 @@ pub mod window_frames; pub use builder::{ build_join_schema, union_with_alias, LogicalPlanBuilder, UNNAMED_TABLE, }; +pub use datafusion_expr::expr::binary_expr; pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema}; pub use display::display_schema; pub use expr::{ abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan, - avg, binary_expr, bit_length, btrim, call_fn, case, ceil, character_length, chr, col, + avg, bit_length, btrim, call_fn, case, ceil, character_length, chr, col, columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, diff --git a/datafusion/src/logical_plan/operators.rs b/datafusion/src/logical_plan/operators.rs index 813f7e0..132f8a8 100644 --- a/datafusion/src/logical_plan/operators.rs +++ b/datafusion/src/logical_plan/operators.rs @@ -15,75 +15,4 @@ // specific language governing permissions and limitations // under the License. 
-use super::{binary_expr, Expr}; pub use datafusion_expr::Operator; -use std::ops; - -impl ops::Add for Expr { - type Output = Self; - - fn add(self, rhs: Self) -> Self { - binary_expr(self, Operator::Plus, rhs) - } -} - -impl ops::Sub for Expr { - type Output = Self; - - fn sub(self, rhs: Self) -> Self { - binary_expr(self, Operator::Minus, rhs) - } -} - -impl ops::Mul for Expr { - type Output = Self; - - fn mul(self, rhs: Self) -> Self { - binary_expr(self, Operator::Multiply, rhs) - } -} - -impl ops::Div for Expr { - type Output = Self; - - fn div(self, rhs: Self) -> Self { - binary_expr(self, Operator::Divide, rhs) - } -} - -impl ops::Rem for Expr { - type Output = Self; - - fn rem(self, rhs: Self) -> Self { - binary_expr(self, Operator::Modulo, rhs) - } -} - -#[cfg(test)] -mod tests { - use crate::prelude::lit; - - #[test] - fn test_operators() { - assert_eq!( - format!("{:?}", lit(1u32) + lit(2u32)), - "UInt32(1) + UInt32(2)" - ); - assert_eq!( - format!("{:?}", lit(1u32) - lit(2u32)), - "UInt32(1) - UInt32(2)" - ); - assert_eq!( - format!("{:?}", lit(1u32) * lit(2u32)), - "UInt32(1) * UInt32(2)" - ); - assert_eq!( - format!("{:?}", lit(1u32) / lit(2u32)), - "UInt32(1) / UInt32(2)" - ); - assert_eq!( - format!("{:?}", lit(1u32) % lit(2u32)), - "UInt32(1) % UInt32(2)" - ); - } -} diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index a1531d4..656b9c8 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -40,15 +40,6 @@ use expressions::{ }; use std::sync::Arc; -/// the implementation of an aggregate function -pub type AccumulatorFunctionImplementation = - Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync>; - -/// This signature corresponds to which types an aggregator serializes -/// its state, given its return datatype. 
-pub type StateTypeFunction = - Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>; - pub use datafusion_expr::AggregateFunction; /// Returns the datatype of the aggregate function. diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index bf0aee9..f054e5c 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -53,25 +53,12 @@ use arrow::{ record_batch::RecordBatch, }; pub use datafusion_expr::NullColumnarValue; +use datafusion_expr::ScalarFunctionImplementation; pub use datafusion_expr::{BuiltinScalarFunction, Signature, TypeSignature, Volatility}; use fmt::{Debug, Formatter}; +use std::convert::From; use std::{any::Any, fmt, sync::Arc}; -/// Scalar function -/// -/// The Fn param is the wrapped function but be aware that the function will -/// be passed with the slice / vec of columnar values (either scalar or array) -/// with the exception of zero param function, where a singular element vec -/// will be passed. In that case the single element is a null array to indicate -/// the batch's row count (so that the generative zero-argument function can know -/// the result array size). -pub type ScalarFunctionImplementation = - Arc<dyn Fn(&[ColumnarValue]) -> Result<ColumnarValue> + Send + Sync>; - -/// A function's return type -pub type ReturnTypeFunction = - Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>; - macro_rules! 
make_utf8_to_return_type { ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => { fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> { diff --git a/datafusion/src/physical_plan/udaf.rs b/datafusion/src/physical_plan/udaf.rs index 0de696d..e7b4058 100644 --- a/datafusion/src/physical_plan/udaf.rs +++ b/datafusion/src/physical_plan/udaf.rs @@ -28,83 +28,14 @@ use arrow::{ use crate::physical_plan::PhysicalExpr; use crate::{error::Result, logical_plan::Expr}; - +use datafusion_expr::{StateTypeFunction,ReturnTypeFunction,Signature,AccumulatorFunctionImplementation} use super::{ - aggregates::AccumulatorFunctionImplementation, - aggregates::StateTypeFunction, expressions::format_state_name, - functions::{ReturnTypeFunction, Signature}, type_coercion::coerce, Accumulator, AggregateExpr, }; use std::sync::Arc; - -/// Logical representation of a user-defined aggregate function (UDAF) -/// A UDAF is different from a UDF in that it is stateful across batches. -#[derive(Clone)] -pub struct AggregateUDF { - /// name - pub name: String, - /// signature - pub signature: Signature, - /// Return type - pub return_type: ReturnTypeFunction, - /// actual implementation - pub accumulator: AccumulatorFunctionImplementation, - /// the accumulator's state's description as a function of the return type - pub state_type: StateTypeFunction, -} - -impl Debug for AggregateUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("AggregateUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"<FUNC>") - .finish() - } -} - -impl PartialEq for AggregateUDF { - fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature - } -} - -impl std::hash::Hash for AggregateUDF { - fn hash<H: std::hash::Hasher>(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); - } -} - -impl AggregateUDF { - /// Create a new AggregateUDF - pub fn new( - name: &str, - signature: &Signature, 
- return_type: &ReturnTypeFunction, - accumulator: &AccumulatorFunctionImplementation, - state_type: &StateTypeFunction, - ) -> Self { - Self { - name: name.to_owned(), - signature: signature.clone(), - return_type: return_type.clone(), - accumulator: accumulator.clone(), - state_type: state_type.clone(), - } - } - - /// creates a logical expression with a call of the UDAF - /// This utility allows using the UDAF without requiring access to the registry. - pub fn call(&self, args: Vec<Expr>) -> Expr { - Expr::AggregateUDF { - fun: Arc::new(self.clone()), - args, - } - } -} +pub use datafusion_expr::AggregateUDF; /// Creates a physical expression of the UDAF, that includes all necessary type coercion. /// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion/src/physical_plan/udf.rs index 7355746..85e6b02 100644 --- a/datafusion/src/physical_plan/udf.rs +++ b/datafusion/src/physical_plan/udf.rs @@ -33,74 +33,7 @@ use super::{ }; use std::sync::Arc; -/// Logical representation of a UDF. -#[derive(Clone)] -pub struct ScalarUDF { - /// name - pub name: String, - /// signature - pub signature: Signature, - /// Return type - pub return_type: ReturnTypeFunction, - /// actual implementation - /// - /// The fn param is the wrapped function but be aware that the function will - /// be passed with the slice / vec of columnar values (either scalar or array) - /// with the exception of zero param function, where a singular element vec - /// will be passed. In that case the single element is a null array to indicate - /// the batch's row count (so that the generative zero-argument function can know - /// the result array size). 
- pub fun: ScalarFunctionImplementation, -} - -impl Debug for ScalarUDF { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - f.debug_struct("ScalarUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("fun", &"<FUNC>") - .finish() - } -} - -impl PartialEq for ScalarUDF { - fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.signature == other.signature - } -} - -impl std::hash::Hash for ScalarUDF { - fn hash<H: std::hash::Hasher>(&self, state: &mut H) { - self.name.hash(state); - self.signature.hash(state); - } -} - -impl ScalarUDF { - /// Create a new ScalarUDF - pub fn new( - name: &str, - signature: &Signature, - return_type: &ReturnTypeFunction, - fun: &ScalarFunctionImplementation, - ) -> Self { - Self { - name: name.to_owned(), - signature: signature.clone(), - return_type: return_type.clone(), - fun: fun.clone(), - } - } - - /// creates a logical expression with a call of the UDF - /// This utility allows using the UDF without requiring access to the registry. - pub fn call(&self, args: Vec<Expr>) -> Expr { - Expr::ScalarUDF { - fun: Arc::new(self.clone()), - args, - } - } -} +pub use datafusion_expr::ScalarUDF; /// Create a physical expression of the UDF. /// This function errors when `args`' can't be coerced to a valid argument type of the UDF.