This is an automated email from the ASF dual-hosted git repository.
dheres pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 5cf090a13 Fix Decimal and Floating type coerce rule (#4038)
5cf090a13 is described below
commit 5cf090a13391501c0ce7707ac7a1e50e18517b79
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Oct 31 08:08:09 2022 -0700
Fix Decimal and Floating type coerce rule (#4038)
---
benchmarks/expected-plans/q11.txt | 4 +-
benchmarks/expected-plans/q14.txt | 2 +-
benchmarks/expected-plans/q20.txt | 4 +-
datafusion/core/tests/sql/decimal.rs | 34 ++++++
datafusion/core/tests/sql/subqueries.rs | 8 +-
datafusion/expr/src/logical_plan/plan.rs | 1 +
datafusion/expr/src/type_coercion/binary.rs | 2 +
datafusion/physical-expr/src/expressions/binary.rs | 136 ++++++++++++++++++++-
8 files changed, 180 insertions(+), 11 deletions(-)
diff --git a/benchmarks/expected-plans/q11.txt
b/benchmarks/expected-plans/q11.txt
index b408340a3..0e886e2e7 100644
--- a/benchmarks/expected-plans/q11.txt
+++ b/benchmarks/expected-plans/q11.txt
@@ -1,6 +1,6 @@
Sort: value DESC NULLS FIRST
Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost *
partsupp.ps_availqty) AS value
- Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS
Decimal128(38, 17)) > __sq_1.__value
+ Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS
Decimal128(38, 15)) > CAST(__sq_1.__value AS Decimal128(38, 15))
CrossJoin:
Aggregate: groupBy=[[partsupp.ps_partkey]],
aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) *
CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: supplier.s_nationkey = nation.n_nationkey
@@ -9,7 +9,7 @@ Sort: value DESC NULLS FIRST
TableScan: supplier projection=[s_suppkey, s_nationkey]
Filter: nation.n_name = Utf8("GERMANY")
TableScan: nation projection=[n_nationkey, n_name]
- Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS
Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value,
alias=__sq_1
+ Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS
Float64) * Float64(0.0001) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS
Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
diff --git a/benchmarks/expected-plans/q14.txt
b/benchmarks/expected-plans/q14.txt
index c410363a5..edafe4608 100644
--- a/benchmarks/expected-plans/q14.txt
+++ b/benchmarks/expected-plans/q14.txt
@@ -1,4 +1,4 @@
-Projection: CAST(Decimal128(Some(1000000000000000000000),38,19) *
CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%") THEN
lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS
Decimal128(38, 19)) AS Decimal128(38, 38)) / CAST(SUM(lineitem.l_extendedprice
* Int64(1) - lineitem.l_discount) AS Decimal128(38, 38)) AS promo_revenue
+Projection: Float64(100) * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%")
THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0)
END) AS Float64) / CAST(SUM(lineitem.l_extendedprice * Int64(1) -
lineitem.l_discount) AS Float64) AS promo_revenue
Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE
Utf8("PROMO%") THEN CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) *
CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23,
2)) AS Decimal128(38, 4))CAST(Decimal128(Some(100),23,2) -
CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38,
4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23,
2))CAST(lineitem.l_discount AS Decimal128(23, 2))lineitem.l_discountDecimal12
[...]
Projection: CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) *
CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23,
2)) AS Decimal128(38, 4)) AS CAST(lineitem.l_extendedprice AS Decimal128(38,
4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS
Decimal128(23, 2)) AS Decimal128(38, 4))CAST(Decimal128(Some(100),23,2) -
CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38,
4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS D [...]
Inner Join: lineitem.l_partkey = part.p_partkey
diff --git a/benchmarks/expected-plans/q20.txt
b/benchmarks/expected-plans/q20.txt
index e5398325e..0d095a735 100644
--- a/benchmarks/expected-plans/q20.txt
+++ b/benchmarks/expected-plans/q20.txt
@@ -6,14 +6,14 @@ Sort: supplier.s_name ASC NULLS LAST
Filter: nation.n_name = Utf8("CANADA")
TableScan: nation projection=[n_nationkey, n_name]
Projection: partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
- Filter: CAST(partsupp.ps_availqty AS Decimal128(38, 17)) >
__sq_3.__value
+ Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value
Inner Join: partsupp.ps_partkey = __sq_3.l_partkey,
partsupp.ps_suppkey = __sq_3.l_suppkey
LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey,
ps_availqty]
Projection: part.p_partkey AS p_partkey, alias=__sq_1
Filter: part.p_name LIKE Utf8("forest%")
TableScan: part projection=[p_partkey, p_name]
- Projection: lineitem.l_partkey, lineitem.l_suppkey,
Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS
Decimal128(38, 17)) AS __value, alias=__sq_3
+ Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) *
CAST(SUM(lineitem.l_quantity) AS Float64) AS __value, alias=__sq_3
Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]],
aggr=[[SUM(lineitem.l_quantity)]]
Filter: lineitem.l_shipdate >= Date32("8766") AND
lineitem.l_shipdate < Date32("9131")
TableScan: lineitem projection=[l_partkey, l_suppkey,
l_quantity, l_shipdate]
\ No newline at end of file
diff --git a/datafusion/core/tests/sql/decimal.rs
b/datafusion/core/tests/sql/decimal.rs
index 2e3e3d2ab..e0c2c1773 100644
--- a/datafusion/core/tests/sql/decimal.rs
+++ b/datafusion/core/tests/sql/decimal.rs
@@ -879,3 +879,37 @@ async fn decimal_null_array_scalar_comparison() ->
Result<()> {
assert_eq!(&DataType::Boolean, actual[0].column(0).data_type());
Ok(())
}
+
+#[tokio::test]
+async fn decimal_multiply_float() -> Result<()> {
+ let ctx = SessionContext::new();
+ let sql = "select cast(400420638.54 as decimal(12,2));";
+ let actual = execute_to_batches(&ctx, sql).await;
+
+ assert_eq!(
+ &DataType::Decimal128(12, 2),
+ actual[0].schema().field(0).data_type()
+ );
+ let expected = vec![
+ "+-----------------------+",
+ "| Float64(400420638.54) |",
+ "+-----------------------+",
+ "| 400420638.54 |",
+ "+-----------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ let sql = "select cast(400420638.54 as decimal(12,2)) * 1.0;";
+ let actual = execute_to_batches(&ctx, sql).await;
+ assert_eq!(&DataType::Float64, actual[0].schema().field(0).data_type());
+ let expected = vec![
+ "+------------------------------------+",
+ "| Float64(400420638.54) * Float64(1) |",
+ "+------------------------------------+",
+ "| 400420638.54 |",
+ "+------------------------------------+",
+ ];
+ assert_batches_eq!(expected, &actual);
+
+ Ok(())
+}
diff --git a/datafusion/core/tests/sql/subqueries.rs
b/datafusion/core/tests/sql/subqueries.rs
index ed65d4391..4fb97d5eb 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -328,14 +328,14 @@ order by s_name;
Filter: nation.n_name = Utf8("CANADA")
TableScan: nation projection=[n_nationkey, n_name],
partial_filters=[nation.n_name = Utf8("CANADA")]
Projection: partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
- Filter: CAST(partsupp.ps_availqty AS Decimal128(38, 17)) >
__sq_3.__value
+ Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value
Inner Join: partsupp.ps_partkey = __sq_3.l_partkey,
partsupp.ps_suppkey = __sq_3.l_suppkey
LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey,
ps_availqty]
Projection: part.p_partkey AS p_partkey, alias=__sq_1
Filter: part.p_name LIKE Utf8("forest%")
TableScan: part projection=[p_partkey, p_name],
partial_filters=[part.p_name LIKE Utf8("forest%")]
- Projection: lineitem.l_partkey, lineitem.l_suppkey,
Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS
Decimal128(38, 17)) AS __value, alias=__sq_3
+ Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) *
CAST(SUM(lineitem.l_quantity) AS Float64) AS __value, alias=__sq_3
Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]],
aggr=[[SUM(lineitem.l_quantity)]]
Filter: lineitem.l_shipdate >= Date32("8766")
TableScan: lineitem projection=[l_partkey, l_suppkey,
l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >=
Date32("8766")]"#
@@ -443,7 +443,7 @@ order by value desc;
let actual = format!("{}", plan.display_indent());
let expected = r#"Sort: value DESC NULLS FIRST
Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost *
partsupp.ps_availqty) AS value
- Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS
Decimal128(38, 17)) > __sq_1.__value
+ Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS
Decimal128(38, 15)) > CAST(__sq_1.__value AS Decimal128(38, 15))
CrossJoin:
Aggregate: groupBy=[[partsupp.ps_partkey]],
aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) *
CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: supplier.s_nationkey = nation.n_nationkey
@@ -452,7 +452,7 @@ order by value desc;
TableScan: supplier projection=[s_suppkey, s_nationkey]
Filter: nation.n_name = Utf8("GERMANY")
TableScan: nation projection=[n_nationkey, n_name],
partial_filters=[nation.n_name = Utf8("GERMANY")]
- Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS
Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value,
alias=__sq_1
+ Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS
Float64) * Float64(0.0001) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS
Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
diff --git a/datafusion/expr/src/logical_plan/plan.rs
b/datafusion/expr/src/logical_plan/plan.rs
index d65ed5228..ce169f6ec 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -1488,6 +1488,7 @@ impl Subquery {
pub fn try_from_expr(plan: &Expr) -> datafusion_common::Result<&Subquery> {
match plan {
Expr::ScalarSubquery(it) => Ok(it),
+ Expr::Cast(cast) => Subquery::try_from_expr(cast.expr.as_ref()),
_ => plan_err!("Could not coerce into ScalarSubquery!"),
}
}
diff --git a/datafusion/expr/src/type_coercion/binary.rs
b/datafusion/expr/src/type_coercion/binary.rs
index 2a125d56b..45510cb03 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -333,6 +333,8 @@ fn mathematics_numerical_coercion(
(Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _),
Null) => {
Some(dec_type.clone())
}
+ (Decimal128(_, _), Float32 | Float64) => Some(Float64),
+ (Float32 | Float64, Decimal128(_, _)) => Some(Float64),
(Decimal128(_, _), _) => {
let converted_decimal_type =
coerce_numeric_type_to_decimal(rhs_type);
match converted_decimal_type {
diff --git a/datafusion/physical-expr/src/expressions/binary.rs
b/datafusion/physical-expr/src/expressions/binary.rs
index 8b93f49c1..4aa5d5e1e 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -2574,9 +2574,22 @@ mod tests {
let right_expr = if right.data_type().eq(&op_type) {
col("b", schema)?
} else {
- try_cast(col("b", schema)?, schema, op_type)?
+ try_cast(col("b", schema)?, schema, op_type.clone())?
};
- let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
+
+ let coerced_schema = Schema::new(vec![
+ Field::new(
+ schema.field(0).name(),
+ op_type.clone(),
+ schema.field(0).is_nullable(),
+ ),
+ Field::new(
+ schema.field(1).name(),
+ op_type,
+ schema.field(1).is_nullable(),
+ ),
+ ]);
+ let arithmetic_op = binary_simple(left_expr, op, right_expr,
&coerced_schema);
let data: Vec<ArrayRef> = vec![left.clone(), right.clone()];
let batch = RecordBatch::try_new(schema.clone(), data)?;
let result =
arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
@@ -2704,6 +2717,125 @@ mod tests {
Ok(())
}
+ #[test]
+ fn arithmetic_decimal_float_expr_test() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Float64, true),
+ Field::new("b", DataType::Decimal128(10, 2), true),
+ ]));
+ let value: i128 = 123;
+ let decimal_array = Arc::new(create_decimal_array(
+ &[
+ Some(value as i128), // 1.23
+ None,
+ Some((value - 1) as i128), // 1.22
+ Some((value + 1) as i128), // 1.24
+ ],
+ 10,
+ 2,
+ )) as ArrayRef;
+ let float64_array = Arc::new(Float64Array::from(vec![
+ Some(123.0),
+ Some(122.0),
+ Some(123.0),
+ Some(124.0),
+ ])) as ArrayRef;
+
+ // add: float64 array add decimal array
+ let expect = Arc::new(Float64Array::from(vec![
+ Some(124.23),
+ None,
+ Some(124.22),
+ Some(125.24),
+ ])) as ArrayRef;
+ apply_arithmetic_op(
+ &schema,
+ &float64_array,
+ &decimal_array,
+ Operator::Plus,
+ expect,
+ )
+ .unwrap();
+
+ // subtract: float64 array subtract decimal array
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Float64, true),
+ Field::new("b", DataType::Decimal128(10, 2), true),
+ ]));
+ let expect = Arc::new(Float64Array::from(vec![
+ Some(121.77),
+ None,
+ Some(121.78),
+ Some(122.76),
+ ])) as ArrayRef;
+ apply_arithmetic_op(
+ &schema,
+ &float64_array,
+ &decimal_array,
+ Operator::Minus,
+ expect,
+ )
+ .unwrap();
+
+ // multiply: float64 array multiply decimal array
+ let expect = Arc::new(Float64Array::from(vec![
+ Some(151.29),
+ None,
+ Some(150.06),
+ Some(153.76),
+ ])) as ArrayRef;
+ apply_arithmetic_op(
+ &schema,
+ &float64_array,
+ &decimal_array,
+ Operator::Multiply,
+ expect,
+ )
+ .unwrap();
+
+ // divide: float64 array divide decimal array
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Float64, true),
+ Field::new("b", DataType::Decimal128(10, 2), true),
+ ]));
+ let expect = Arc::new(Float64Array::from(vec![
+ Some(100.0),
+ None,
+ Some(100.81967213114754),
+ Some(100.0),
+ ])) as ArrayRef;
+ apply_arithmetic_op(
+ &schema,
+ &float64_array,
+ &decimal_array,
+ Operator::Divide,
+ expect,
+ )
+ .unwrap();
+
+ // modulus: float64 array modulus decimal array
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Float64, true),
+ Field::new("b", DataType::Decimal128(10, 2), true),
+ ]));
+ let expect = Arc::new(Float64Array::from(vec![
+ Some(1.7763568394002505e-15),
+ None,
+ Some(1.0000000000000027),
+ Some(8.881784197001252e-16),
+ ])) as ArrayRef;
+ apply_arithmetic_op(
+ &schema,
+ &float64_array,
+ &decimal_array,
+ Operator::Modulo,
+ expect,
+ )
+ .unwrap();
+
+ Ok(())
+ }
+
#[test]
fn bitwise_array_test() -> Result<()> {
let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)]))
as ArrayRef;