This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 5cf090a13 Fix Decimal and Floating type coerce rule (#4038)
5cf090a13 is described below

commit 5cf090a13391501c0ce7707ac7a1e50e18517b79
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Oct 31 08:08:09 2022 -0700

    Fix Decimal and Floating type coerce rule (#4038)
---
 benchmarks/expected-plans/q11.txt                  |   4 +-
 benchmarks/expected-plans/q14.txt                  |   2 +-
 benchmarks/expected-plans/q20.txt                  |   4 +-
 datafusion/core/tests/sql/decimal.rs               |  34 ++++++
 datafusion/core/tests/sql/subqueries.rs            |   8 +-
 datafusion/expr/src/logical_plan/plan.rs           |   1 +
 datafusion/expr/src/type_coercion/binary.rs        |   2 +
 datafusion/physical-expr/src/expressions/binary.rs | 136 ++++++++++++++++++++-
 8 files changed, 180 insertions(+), 11 deletions(-)

diff --git a/benchmarks/expected-plans/q11.txt b/benchmarks/expected-plans/q11.txt
index b408340a3..0e886e2e7 100644
--- a/benchmarks/expected-plans/q11.txt
+++ b/benchmarks/expected-plans/q11.txt
@@ -1,6 +1,6 @@
 Sort: value DESC NULLS FIRST
   Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * 
partsupp.ps_availqty) AS value
-    Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS 
Decimal128(38, 17)) > __sq_1.__value
+    Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS 
Decimal128(38, 15)) > CAST(__sq_1.__value AS Decimal128(38, 15))
       CrossJoin:
         Aggregate: groupBy=[[partsupp.ps_partkey]], 
aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * 
CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
           Inner Join: supplier.s_nationkey = nation.n_nationkey
@@ -9,7 +9,7 @@ Sort: value DESC NULLS FIRST
               TableScan: supplier projection=[s_suppkey, s_nationkey]
             Filter: nation.n_name = Utf8("GERMANY")
               TableScan: nation projection=[n_nationkey, n_name]
-        Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS 
Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, 
alias=__sq_1
+        Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS 
Float64) * Float64(0.0001) AS __value, alias=__sq_1
           Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS 
Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
             Inner Join: supplier.s_nationkey = nation.n_nationkey
               Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
diff --git a/benchmarks/expected-plans/q14.txt b/benchmarks/expected-plans/q14.txt
index c410363a5..edafe4608 100644
--- a/benchmarks/expected-plans/q14.txt
+++ b/benchmarks/expected-plans/q14.txt
@@ -1,4 +1,4 @@
-Projection: CAST(Decimal128(Some(1000000000000000000000),38,19) * 
CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%")  THEN 
lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) END) AS 
Decimal128(38, 19)) AS Decimal128(38, 38)) / CAST(SUM(lineitem.l_extendedprice 
* Int64(1) - lineitem.l_discount) AS Decimal128(38, 38)) AS promo_revenue
+Projection: Float64(100) * CAST(SUM(CASE WHEN part.p_type LIKE Utf8("PROMO%")  
THEN lineitem.l_extendedprice * Int64(1) - lineitem.l_discount ELSE Int64(0) 
END) AS Float64) / CAST(SUM(lineitem.l_extendedprice * Int64(1) - 
lineitem.l_discount) AS Float64) AS promo_revenue
   Aggregate: groupBy=[[]], aggr=[[SUM(CASE WHEN part.p_type LIKE 
Utf8("PROMO%") THEN CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * 
CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 
2)) AS Decimal128(38, 4))CAST(Decimal128(Some(100),23,2) - 
CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 
4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 
2))CAST(lineitem.l_discount AS Decimal128(23, 2))lineitem.l_discountDecimal12 
[...]
     Projection: CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * 
CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 
2)) AS Decimal128(38, 4)) AS CAST(lineitem.l_extendedprice AS Decimal128(38, 
4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS 
Decimal128(23, 2)) AS Decimal128(38, 4))CAST(Decimal128(Some(100),23,2) - 
CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 
4))Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS D [...]
       Inner Join: lineitem.l_partkey = part.p_partkey
diff --git a/benchmarks/expected-plans/q20.txt b/benchmarks/expected-plans/q20.txt
index e5398325e..0d095a735 100644
--- a/benchmarks/expected-plans/q20.txt
+++ b/benchmarks/expected-plans/q20.txt
@@ -6,14 +6,14 @@ Sort: supplier.s_name ASC NULLS LAST
         Filter: nation.n_name = Utf8("CANADA")
           TableScan: nation projection=[n_nationkey, n_name]
       Projection: partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
-        Filter: CAST(partsupp.ps_availqty AS Decimal128(38, 17)) > 
__sq_3.__value
+        Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value
           Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, 
partsupp.ps_suppkey = __sq_3.l_suppkey
             LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey
               TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_availqty]
               Projection: part.p_partkey AS p_partkey, alias=__sq_1
                 Filter: part.p_name LIKE Utf8("forest%")
                   TableScan: part projection=[p_partkey, p_name]
-            Projection: lineitem.l_partkey, lineitem.l_suppkey, 
Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS 
Decimal128(38, 17)) AS __value, alias=__sq_3
+            Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) * 
CAST(SUM(lineitem.l_quantity) AS Float64) AS __value, alias=__sq_3
               Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], 
aggr=[[SUM(lineitem.l_quantity)]]
                 Filter: lineitem.l_shipdate >= Date32("8766") AND 
lineitem.l_shipdate < Date32("9131")
                   TableScan: lineitem projection=[l_partkey, l_suppkey, 
l_quantity, l_shipdate]
\ No newline at end of file
diff --git a/datafusion/core/tests/sql/decimal.rs b/datafusion/core/tests/sql/decimal.rs
index 2e3e3d2ab..e0c2c1773 100644
--- a/datafusion/core/tests/sql/decimal.rs
+++ b/datafusion/core/tests/sql/decimal.rs
@@ -879,3 +879,37 @@ async fn decimal_null_array_scalar_comparison() -> Result<()> {
     assert_eq!(&DataType::Boolean, actual[0].column(0).data_type());
     Ok(())
 }
+
+#[tokio::test]
+async fn decimal_multiply_float() -> Result<()> {
+    let ctx = SessionContext::new();
+    let sql = "select cast(400420638.54 as decimal(12,2));";
+    let actual = execute_to_batches(&ctx, sql).await;
+
+    assert_eq!(
+        &DataType::Decimal128(12, 2),
+        actual[0].schema().field(0).data_type()
+    );
+    let expected = vec![
+        "+-----------------------+",
+        "| Float64(400420638.54) |",
+        "+-----------------------+",
+        "| 400420638.54          |",
+        "+-----------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    let sql = "select cast(400420638.54 as decimal(12,2)) * 1.0;";
+    let actual = execute_to_batches(&ctx, sql).await;
+    assert_eq!(&DataType::Float64, actual[0].schema().field(0).data_type());
+    let expected = vec![
+        "+------------------------------------+",
+        "| Float64(400420638.54) * Float64(1) |",
+        "+------------------------------------+",
+        "| 400420638.54                       |",
+        "+------------------------------------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    Ok(())
+}
diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs
index ed65d4391..4fb97d5eb 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -328,14 +328,14 @@ order by s_name;
         Filter: nation.n_name = Utf8("CANADA")
           TableScan: nation projection=[n_nationkey, n_name], 
partial_filters=[nation.n_name = Utf8("CANADA")]
       Projection: partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
-        Filter: CAST(partsupp.ps_availqty AS Decimal128(38, 17)) > 
__sq_3.__value
+        Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value
           Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, 
partsupp.ps_suppkey = __sq_3.l_suppkey
             LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey
               TableScan: partsupp projection=[ps_partkey, ps_suppkey, 
ps_availqty]
               Projection: part.p_partkey AS p_partkey, alias=__sq_1
                 Filter: part.p_name LIKE Utf8("forest%")
                   TableScan: part projection=[p_partkey, p_name], 
partial_filters=[part.p_name LIKE Utf8("forest%")]
-            Projection: lineitem.l_partkey, lineitem.l_suppkey, 
Decimal128(Some(50000000000000000),38,17) * CAST(SUM(lineitem.l_quantity) AS 
Decimal128(38, 17)) AS __value, alias=__sq_3
+            Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) * 
CAST(SUM(lineitem.l_quantity) AS Float64) AS __value, alias=__sq_3
               Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], 
aggr=[[SUM(lineitem.l_quantity)]]
                 Filter: lineitem.l_shipdate >= Date32("8766")
                   TableScan: lineitem projection=[l_partkey, l_suppkey, 
l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= 
Date32("8766")]"#
@@ -443,7 +443,7 @@ order by value desc;
     let actual = format!("{}", plan.display_indent());
     let expected = r#"Sort: value DESC NULLS FIRST
   Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * 
partsupp.ps_availqty) AS value
-    Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS 
Decimal128(38, 17)) > __sq_1.__value
+    Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS 
Decimal128(38, 15)) > CAST(__sq_1.__value AS Decimal128(38, 15))
       CrossJoin:
         Aggregate: groupBy=[[partsupp.ps_partkey]], 
aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * 
CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
           Inner Join: supplier.s_nationkey = nation.n_nationkey
@@ -452,7 +452,7 @@ order by value desc;
               TableScan: supplier projection=[s_suppkey, s_nationkey]
             Filter: nation.n_name = Utf8("GERMANY")
               TableScan: nation projection=[n_nationkey, n_name], 
partial_filters=[nation.n_name = Utf8("GERMANY")]
-        Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS 
Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, 
alias=__sq_1
+        Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS 
Float64) * Float64(0.0001) AS __value, alias=__sq_1
           Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS 
Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]
             Inner Join: supplier.s_nationkey = nation.n_nationkey
               Inner Join: partsupp.ps_suppkey = supplier.s_suppkey
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index d65ed5228..ce169f6ec 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -1488,6 +1488,7 @@ impl Subquery {
     pub fn try_from_expr(plan: &Expr) -> datafusion_common::Result<&Subquery> {
         match plan {
             Expr::ScalarSubquery(it) => Ok(it),
+            Expr::Cast(cast) => Subquery::try_from_expr(cast.expr.as_ref()),
             _ => plan_err!("Could not coerce into ScalarSubquery!"),
         }
     }
diff --git a/datafusion/expr/src/type_coercion/binary.rs b/datafusion/expr/src/type_coercion/binary.rs
index 2a125d56b..45510cb03 100644
--- a/datafusion/expr/src/type_coercion/binary.rs
+++ b/datafusion/expr/src/type_coercion/binary.rs
@@ -333,6 +333,8 @@ fn mathematics_numerical_coercion(
         (Null, dec_type @ Decimal128(_, _)) | (dec_type @ Decimal128(_, _), Null) => {
             Some(dec_type.clone())
         }
+        (Decimal128(_, _), Float32 | Float64) => Some(Float64),
+        (Float32 | Float64, Decimal128(_, _)) => Some(Float64),
         (Decimal128(_, _), _) => {
             let converted_decimal_type = coerce_numeric_type_to_decimal(rhs_type);
             match converted_decimal_type {
diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs
index 8b93f49c1..4aa5d5e1e 100644
--- a/datafusion/physical-expr/src/expressions/binary.rs
+++ b/datafusion/physical-expr/src/expressions/binary.rs
@@ -2574,9 +2574,22 @@ mod tests {
         let right_expr = if right.data_type().eq(&op_type) {
             col("b", schema)?
         } else {
-            try_cast(col("b", schema)?, schema, op_type)?
+            try_cast(col("b", schema)?, schema, op_type.clone())?
         };
-        let arithmetic_op = binary_simple(left_expr, op, right_expr, schema);
+
+        let coerced_schema = Schema::new(vec![
+            Field::new(
+                schema.field(0).name(),
+                op_type.clone(),
+                schema.field(0).is_nullable(),
+            ),
+            Field::new(
+                schema.field(1).name(),
+                op_type,
+                schema.field(1).is_nullable(),
+            ),
+        ]);
+        let arithmetic_op = binary_simple(left_expr, op, right_expr, &coerced_schema);
         let data: Vec<ArrayRef> = vec![left.clone(), right.clone()];
         let batch = RecordBatch::try_new(schema.clone(), data)?;
         let result = arithmetic_op.evaluate(&batch)?.into_array(batch.num_rows());
@@ -2704,6 +2717,125 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn arithmetic_decimal_float_expr_test() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Float64, true),
+            Field::new("b", DataType::Decimal128(10, 2), true),
+        ]));
+        let value: i128 = 123;
+        let decimal_array = Arc::new(create_decimal_array(
+            &[
+                Some(value as i128), // 1.23
+                None,
+                Some((value - 1) as i128), // 1.22
+                Some((value + 1) as i128), // 1.24
+            ],
+            10,
+            2,
+        )) as ArrayRef;
+        let float64_array = Arc::new(Float64Array::from(vec![
+            Some(123.0),
+            Some(122.0),
+            Some(123.0),
+            Some(124.0),
+        ])) as ArrayRef;
+
+        // add: float64 array add decimal array
+        let expect = Arc::new(Float64Array::from(vec![
+            Some(124.23),
+            None,
+            Some(124.22),
+            Some(125.24),
+        ])) as ArrayRef;
+        apply_arithmetic_op(
+            &schema,
+            &float64_array,
+            &decimal_array,
+            Operator::Plus,
+            expect,
+        )
+        .unwrap();
+
+        // subtract: float64 array subtract decimal array
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Float64, true),
+            Field::new("b", DataType::Decimal128(10, 2), true),
+        ]));
+        let expect = Arc::new(Float64Array::from(vec![
+            Some(121.77),
+            None,
+            Some(121.78),
+            Some(122.76),
+        ])) as ArrayRef;
+        apply_arithmetic_op(
+            &schema,
+            &float64_array,
+            &decimal_array,
+            Operator::Minus,
+            expect,
+        )
+        .unwrap();
+
+        // multiply: float64 array multiply decimal array
+        let expect = Arc::new(Float64Array::from(vec![
+            Some(151.29),
+            None,
+            Some(150.06),
+            Some(153.76),
+        ])) as ArrayRef;
+        apply_arithmetic_op(
+            &schema,
+            &float64_array,
+            &decimal_array,
+            Operator::Multiply,
+            expect,
+        )
+        .unwrap();
+
+        // divide: float64 array divide decimal array
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Float64, true),
+            Field::new("b", DataType::Decimal128(10, 2), true),
+        ]));
+        let expect = Arc::new(Float64Array::from(vec![
+            Some(100.0),
+            None,
+            Some(100.81967213114754),
+            Some(100.0),
+        ])) as ArrayRef;
+        apply_arithmetic_op(
+            &schema,
+            &float64_array,
+            &decimal_array,
+            Operator::Divide,
+            expect,
+        )
+        .unwrap();
+
+        // modulus: float64 array modulus decimal array
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Float64, true),
+            Field::new("b", DataType::Decimal128(10, 2), true),
+        ]));
+        let expect = Arc::new(Float64Array::from(vec![
+            Some(1.7763568394002505e-15),
+            None,
+            Some(1.0000000000000027),
+            Some(8.881784197001252e-16),
+        ])) as ArrayRef;
+        apply_arithmetic_op(
+            &schema,
+            &float64_array,
+            &decimal_array,
+            Operator::Modulo,
+            expect,
+        )
+        .unwrap();
+
+        Ok(())
+    }
+
     #[test]
     fn bitwise_array_test() -> Result<()> {
         let left = Arc::new(Int32Array::from(vec![Some(12), None, Some(11)])) as ArrayRef;

Reply via email to