This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new f05df7beaf functions: support trunc() function with one or two args (#6942)
f05df7beaf is described below
commit f05df7beaf09486be2544ab6e397a480651a2778
Author: Syleechan <[email protected]>
AuthorDate: Fri Jul 21 01:10:07 2023 +0800
functions: support trunc() function with one or two args (#6942)
* functions: support trunc() function with one or two args
* format code style
* modify truncate method
* adjust code format
* format code
* fix sql test error
---
.../core/tests/sqllogictests/test_files/scalar.slt | 6 +
datafusion/expr/src/built_in_function.rs | 10 +-
datafusion/expr/src/expr_fn.rs | 9 +-
datafusion/physical-expr/src/functions.rs | 4 +-
datafusion/physical-expr/src/math_expressions.rs | 144 ++++++++++++++++++++-
datafusion/proto/src/logical_plan/from_proto.rs | 7 +-
6 files changed, 173 insertions(+), 7 deletions(-)
diff --git a/datafusion/core/tests/sqllogictests/test_files/scalar.slt b/datafusion/core/tests/sqllogictests/test_files/scalar.slt
index 8c5c399c39..6e563a671d 100644
--- a/datafusion/core/tests/sqllogictests/test_files/scalar.slt
+++ b/datafusion/core/tests/sqllogictests/test_files/scalar.slt
@@ -912,6 +912,12 @@ select trunc(a), trunc(b), trunc(c) from small_floats;
0 0 0
0 0 1
+# trunc with precision
+query RRRRR rowsort
+select trunc(4.267, 3), trunc(1.1234, 2), trunc(-1.1231, 6), trunc(1.2837284, 2), trunc(1.1, 0);
+----
+4.267 1.12 -1.1231 1.28 1
+
## bitwise and
# bitwise and with column and scalar
diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs
index 74561d9fd7..66c20d362e 100644
--- a/datafusion/expr/src/built_in_function.rs
+++ b/datafusion/expr/src/built_in_function.rs
@@ -1072,6 +1072,15 @@ impl BuiltinScalarFunction {
],
self.volatility(),
),
+ BuiltinScalarFunction::Trunc => Signature::one_of(
+ vec![
+ Exact(vec![Float32, Int64]),
+ Exact(vec![Float64, Int64]),
+ Exact(vec![Float64]),
+ Exact(vec![Float32]),
+ ],
+ self.volatility(),
+ ),
BuiltinScalarFunction::Atan2 => Signature::one_of(
vec![Exact(vec![Float32, Float32]), Exact(vec![Float64,
Float64])],
self.volatility(),
@@ -1116,7 +1125,6 @@ impl BuiltinScalarFunction {
| BuiltinScalarFunction::Sqrt
| BuiltinScalarFunction::Tan
| BuiltinScalarFunction::Tanh
- | BuiltinScalarFunction::Trunc
| BuiltinScalarFunction::Cot => {
// math expressions expect 1 argument of type f64 or f32
// priority is given to f64 because e.g. `sqrt(1i32)` is in IR
(real numbers) and thus we
diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs
index b175fc6f51..30d9580c42 100644
--- a/datafusion/expr/src/expr_fn.rs
+++ b/datafusion/expr/src/expr_fn.rs
@@ -502,7 +502,11 @@ scalar_expr!(
scalar_expr!(Degrees, degrees, num, "converts radians to degrees");
scalar_expr!(Radians, radians, num, "converts degrees to radians");
nary_scalar_expr!(Round, round, "round to nearest integer");
-scalar_expr!(Trunc, trunc, num, "truncate toward zero");
+nary_scalar_expr!(
+ Trunc,
+ trunc,
+ "truncate toward zero, with optional precision"
+);
scalar_expr!(Abs, abs, num, "absolute value");
scalar_expr!(Signum, signum, num, "sign of the argument (-1, 0, +1) ");
scalar_expr!(Exp, exp, num, "exponential");
@@ -929,7 +933,8 @@ mod test {
test_unary_scalar_expr!(Radians, radians);
test_nary_scalar_expr!(Round, round, input);
test_nary_scalar_expr!(Round, round, input, decimal_places);
- test_unary_scalar_expr!(Trunc, trunc);
+ test_nary_scalar_expr!(Trunc, trunc, num);
+ test_nary_scalar_expr!(Trunc, trunc, num, precision);
test_unary_scalar_expr!(Abs, abs);
test_unary_scalar_expr!(Signum, signum);
test_unary_scalar_expr!(Exp, exp);
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index a92d4335d4..14279d7006 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -392,7 +392,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::Cbrt => Arc::new(math_expressions::cbrt),
BuiltinScalarFunction::Tan => Arc::new(math_expressions::tan),
BuiltinScalarFunction::Tanh => Arc::new(math_expressions::tanh),
- BuiltinScalarFunction::Trunc => Arc::new(math_expressions::trunc),
+ BuiltinScalarFunction::Trunc => {
+ Arc::new(|args|
make_scalar_function(math_expressions::trunc)(args))
+ }
BuiltinScalarFunction::Pi => Arc::new(math_expressions::pi),
BuiltinScalarFunction::Power => {
Arc::new(|args|
make_scalar_function(math_expressions::power)(args))
diff --git a/datafusion/physical-expr/src/math_expressions.rs b/datafusion/physical-expr/src/math_expressions.rs
index 9a4653c8a0..883c016c04 100644
--- a/datafusion/physical-expr/src/math_expressions.rs
+++ b/datafusion/physical-expr/src/math_expressions.rs
@@ -21,7 +21,7 @@ use arrow::array::ArrayRef;
use arrow::array::{Float32Array, Float64Array, Int64Array};
use arrow::datatypes::DataType;
use datafusion_common::ScalarValue;
-use datafusion_common::ScalarValue::Float32;
+use datafusion_common::ScalarValue::{Float32, Int64};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use rand::{thread_rng, Rng};
@@ -158,7 +158,6 @@ math_unary_function!("acosh", acosh);
math_unary_function!("atanh", atanh);
math_unary_function!("floor", floor);
math_unary_function!("ceil", ceil);
-math_unary_function!("trunc", trunc);
math_unary_function!("abs", abs);
math_unary_function!("signum", signum);
math_unary_function!("exp", exp);
@@ -530,6 +529,75 @@ fn compute_cot64(x: f64) -> f64 {
1.0 / a
}
+/// SQL `trunc(num)` / `trunc(num, precision)`: truncate toward zero, optionally to `precision` digits.
+pub fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> {
+    if args.len() != 1 && args.len() != 2 {
+        return Err(DataFusionError::Internal(format!(
+            "truncate function requires one or two arguments, got {}",
+            args.len()
+        )));
+    }
+
+    // One argument: precision defaults to 0 and the element type's own `trunc` is applied.
+    // Two arguments: the precision column is applied element-wise by `compute_truncate{32,64}`.
+    let num = &args[0];
+    let precision = if args.len() == 1 {
+        ColumnarValue::Scalar(Int64(Some(0)))
+    } else {
+        ColumnarValue::Array(args[1].clone())
+    };
+
+    match args[0].data_type() {
+        DataType::Float64 => match precision {
+            ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new(
+                make_function_scalar_inputs!(num, "num", Float64Array, { f64::trunc }),
+            ) as ArrayRef),
+            ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!(
+                num,
+                precision,
+                "x",
+                "y",
+                Float64Array,
+                Int64Array,
+                { compute_truncate64 }
+            )) as ArrayRef),
+            _ => Err(DataFusionError::Internal(
+                "trunc function requires a scalar or array for precision".to_string(),
+            )),
+        },
+        DataType::Float32 => match precision {
+            ColumnarValue::Scalar(Int64(Some(0))) => Ok(Arc::new(
+                make_function_scalar_inputs!(num, "num", Float32Array, { f32::trunc }),
+            ) as ArrayRef),
+            ColumnarValue::Array(precision) => Ok(Arc::new(make_function_inputs2!(
+                num,
+                precision,
+                "x",
+                "y",
+                Float32Array,
+                Int64Array,
+                { compute_truncate32 }
+            )) as ArrayRef),
+            _ => Err(DataFusionError::Internal(
+                "trunc function requires a scalar or array for precision".to_string(),
+            )),
+        },
+        other => Err(DataFusionError::Internal(format!(
+            "Unsupported data type {other:?} for function trunc"
+        ))),
+    }
+}
+
+fn compute_truncate32(x: f32, y: i64) -> f32 {
+    let factor = 10.0_f32.powi(y as i32); // 10^y; a negative `y` truncates left of the point
+    (x * factor).trunc() / factor // toward zero, so trunc(x, 0) == trunc(x) as in the 1-arg path
+}
+
+fn compute_truncate64(x: f64, y: i64) -> f64 {
+    let factor = 10.0_f64.powi(y as i32); // 10^y; note `y as i32` wraps outside the i32 range
+    (x * factor).trunc() / factor // drop the scaled fraction toward zero (not `round`)
+}
+
#[cfg(test)]
mod tests {
@@ -818,4 +886,76 @@ mod tests {
assert!((floats.value(2) - expected.value(2)).abs() < eps);
assert!((floats.value(3) - expected.value(3)).abs() < eps);
}
+
+    #[test]
+    fn test_truncate_32() {
+        let args: Vec<ArrayRef> = vec![
+            Arc::new(Float32Array::from(vec![
+                15.9,
+                1_234.267_8,
+                1_233.123_4,
+                3.312_979_2,
+                -21.123_4,
+            ])),
+            Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])),
+        ];
+
+        let result = trunc(&args).expect("failed to initialize function truncate");
+        let floats =
+            as_float32_array(&result).expect("failed to initialize function truncate");
+
+        assert_eq!(floats.len(), 5);
+        assert_eq!(floats.value(0), 15.0);
+        assert_eq!(floats.value(1), 1_234.267);
+        assert_eq!(floats.value(2), 1_233.12);
+        assert_eq!(floats.value(3), 3.312_97);
+        assert_eq!(floats.value(4), -21.123_4);
+    }
+
+    #[test]
+    fn test_truncate_64() {
+        let args: Vec<ArrayRef> = vec![
+            Arc::new(Float64Array::from(vec![
+                5.9,
+                234.267_812_176,
+                123.123_456_789,
+                123.312_979_313_2,
+                -321.123_1,
+            ])),
+            Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])),
+        ];
+
+        let result = trunc(&args).expect("failed to initialize function truncate");
+        let floats =
+            as_float64_array(&result).expect("failed to initialize function truncate");
+
+        assert_eq!(floats.len(), 5);
+        assert_eq!(floats.value(0), 5.0);
+        assert_eq!(floats.value(1), 234.267);
+        assert_eq!(floats.value(2), 123.12);
+        assert_eq!(floats.value(3), 123.312_97);
+        assert_eq!(floats.value(4), -321.123_1);
+    }
+
+    #[test]
+    fn test_truncate_64_one_arg() {
+        let args: Vec<ArrayRef> = vec![Arc::new(Float64Array::from(vec![
+            5.0,
+            234.267_812,
+            123.123_45,
+            123.312_979_313_2,
+            -321.123,
+        ]))];
+
+        let result = trunc(&args).expect("failed to initialize function truncate");
+        let floats =
+            as_float64_array(&result).expect("failed to initialize function truncate");
+
+        assert_eq!(floats.len(), 5);
+        assert_eq!(floats.value(0), 5.0);
+        assert_eq!(floats.value(1), 234.0);
+        assert_eq!(floats.value(2), 123.0);
+        assert_eq!(floats.value(3), 123.0);
+        assert_eq!(floats.value(4), -321.0);
+    }
}
diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index 202de7df08..a3718090ed 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -1293,7 +1293,12 @@ pub fn parse_expr(
.map(|expr| parse_expr(expr, registry))
.collect::<Result<Vec<_>, _>>()?,
)),
- ScalarFunction::Trunc => Ok(trunc(parse_expr(&args[0],
registry)?)),
+ ScalarFunction::Trunc => Ok(trunc(
+ args.to_owned()
+ .iter()
+ .map(|expr| parse_expr(expr, registry))
+ .collect::<Result<Vec<_>, _>>()?,
+ )),
ScalarFunction::Abs => Ok(abs(parse_expr(&args[0],
registry)?)),
ScalarFunction::Signum => Ok(signum(parse_expr(&args[0],
registry)?)),
ScalarFunction::OctetLength => {