This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new bec385668 Enable more benchmark verification tests (#4044)
bec385668 is described below

commit bec3856688ac3dc3fdda927f26640dc028023436
Author: Andy Grove <[email protected]>
AuthorDate: Mon Oct 31 12:16:25 2022 -0600

    Enable more benchmark verification tests (#4044)
    
    * Fix Decimal and Floating type coerce rule
    
    * Enable more queries in benchmark verification tests
    
    * update comparison_binary_numeric_coercion
    
    * revert type coercion change in comparison_binary_numeric_coercion
    
    * smaller tolerance
    
    Co-authored-by: Liang-Chi Hsieh <[email protected]>
---
 benchmarks/src/bin/tpch.rs | 35 ++++++++++++++++++++-------
 benchmarks/src/tpch.rs     | 59 ++++++++++++++++++++++++++++++----------------
 2 files changed, 66 insertions(+), 28 deletions(-)

diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index b9afe4d6a..df64537bd 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -668,7 +668,6 @@ mod tests {
     }
 
     #[cfg(feature = "ci")]
-    #[ignore] // TODO produces correct result but has rounding error
     #[tokio::test]
     async fn verify_q9() -> Result<()> {
         verify_query(9).await
@@ -681,7 +680,6 @@ mod tests {
     }
 
     #[cfg(feature = "ci")]
-    #[ignore] // https://github.com/apache/arrow-datafusion/issues/4023
     #[tokio::test]
     async fn verify_q11() -> Result<()> {
         verify_query(11).await
@@ -700,7 +698,6 @@ mod tests {
     }
 
     #[cfg(feature = "ci")]
-    #[ignore] // https://github.com/apache/arrow-datafusion/issues/4025
     #[tokio::test]
     async fn verify_q14() -> Result<()> {
         verify_query(14).await
@@ -719,7 +716,6 @@ mod tests {
     }
 
     #[cfg(feature = "ci")]
-    #[ignore] // https://github.com/apache/arrow-datafusion/issues/4026
     #[tokio::test]
     async fn verify_q17() -> Result<()> {
         verify_query(17).await
@@ -896,8 +892,8 @@ mod tests {
     #[cfg(feature = "ci")]
     async fn verify_query(n: usize) -> Result<()> {
         use datafusion::arrow::datatypes::{DataType, Field};
+        use datafusion::common::ScalarValue;
         use datafusion::logical_expr::expr::Cast;
-        use datafusion::logical_expr::Expr;
         use std::env;
 
         let path = env::var("TPCH_DATA").unwrap_or("benchmarks/data".to_string());
@@ -990,7 +986,12 @@ mod tests {
                     }
                     data_type => data_type == e.data_type(),
                 });
-        assert!(schema_matches);
+        if !schema_matches {
+            panic!(
+                "expected_fields: {:?}\ntransformed_fields: {:?}",
+                expected_fields, transformed_fields
+            )
+        }
 
         // convert both datasets to Vec<Vec<String>> for simple comparison
         let expected_vec = result_vec(&expected);
@@ -1000,8 +1001,26 @@ mod tests {
         assert_eq!(expected_vec.len(), actual_vec.len());
 
         // compare each row. this works as all TPC-H queries have deterministically ordered results
-        for i in 0..actual_vec.len() {
-            assert_eq!(expected_vec[i], actual_vec[i]);
+        for i in 0..expected_vec.len() {
+            let expected_row = &expected_vec[i];
+            let actual_row = &actual_vec[i];
+            assert_eq!(expected_row.len(), actual_row.len());
+
+            for j in 0..expected.len() {
+                match (&expected_row[j], &actual_row[j]) {
+                    (ScalarValue::Float64(Some(l)), ScalarValue::Float64(Some(r))) => {
+                        // allow for rounding errors until we move to decimal types
+                        let tolerance = 0.1;
+                        if (l - r).abs() > tolerance {
+                            panic!(
+                                "Expected: {}; Actual: {}; Tolerance: {}",
+                                l, r, tolerance
+                            )
+                        }
+                    }
+                    (l, r) => assert_eq!(format!("{:?}", l), format!("{:?}", r)),
+                }
+            }
         }
 
         Ok(())
diff --git a/benchmarks/src/tpch.rs b/benchmarks/src/tpch.rs
index 46c53edf1..ad61de8a3 100644
--- a/benchmarks/src/tpch.rs
+++ b/benchmarks/src/tpch.rs
@@ -15,7 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::array::ArrayRef;
+use arrow::array::{
+    Array, ArrayRef, Date32Array, Decimal128Array, Float64Array, Int32Array, Int64Array,
+    StringArray,
+};
 use arrow::record_batch::RecordBatch;
 use std::fs;
 use std::ops::{Div, Mul};
@@ -23,7 +26,7 @@ use std::path::Path;
 use std::sync::Arc;
 use std::time::Instant;
 
-use datafusion::arrow::util::display::array_value_to_string;
+use datafusion::common::ScalarValue;
 use datafusion::logical_expr::Cast;
 use datafusion::prelude::*;
 use datafusion::{
@@ -229,11 +232,7 @@ pub fn get_answer_schema(n: usize) -> Schema {
             Field::new("custdist", DataType::Int64, true),
         ]),
 
-        14 => Schema::new(vec![Field::new(
-            "promo_revenue",
-            DataType::Decimal128(38, 2),
-            true,
-        )]),
+        14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]),
 
         15 => Schema::new(vec![
             Field::new("s_suppkey", DataType::Int64, true),
@@ -250,11 +249,7 @@ pub fn get_answer_schema(n: usize) -> Schema {
             Field::new("supplier_cnt", DataType::Int64, true),
         ]),
 
-        17 => Schema::new(vec![Field::new(
-            "avg_yearly",
-            DataType::Decimal128(38, 2),
-            true,
-        )]),
+        17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]),
 
         18 => Schema::new(vec![
             Field::new("c_name", DataType::Utf8, true),
@@ -389,14 +384,14 @@ pub async fn convert_tbl(
 
 /// Converts the results into a 2d array of strings, `result[row][column]`
 /// Special cases nulls to NULL for testing
-pub fn result_vec(results: &[RecordBatch]) -> Vec<Vec<String>> {
+pub fn result_vec(results: &[RecordBatch]) -> Vec<Vec<ScalarValue>> {
     let mut result = vec![];
     for batch in results {
         for row_index in 0..batch.num_rows() {
             let row_vec = batch
                 .columns()
                 .iter()
-                .map(|column| col_str(column, row_index))
+                .map(|column| col_to_scalar(column, row_index))
                 .collect();
             result.push(row_vec);
         }
@@ -422,13 +417,37 @@ pub fn string_schema(schema: Schema) -> Schema {
     )
 }
 
-/// Specialised String representation
-fn col_str(column: &ArrayRef, row_index: usize) -> String {
+fn col_to_scalar(column: &ArrayRef, row_index: usize) -> ScalarValue {
     if column.is_null(row_index) {
-        return "NULL".to_string();
+        return ScalarValue::Null;
+    }
+    match column.data_type() {
+        DataType::Int32 => {
+            let array = column.as_any().downcast_ref::<Int32Array>().unwrap();
+            ScalarValue::Int32(Some(array.value(row_index)))
+        }
+        DataType::Int64 => {
+            let array = column.as_any().downcast_ref::<Int64Array>().unwrap();
+            ScalarValue::Int64(Some(array.value(row_index)))
+        }
+        DataType::Float64 => {
+            let array = column.as_any().downcast_ref::<Float64Array>().unwrap();
+            ScalarValue::Float64(Some(array.value(row_index)))
+        }
+        DataType::Decimal128(p, s) => {
+            let array = column.as_any().downcast_ref::<Decimal128Array>().unwrap();
+            ScalarValue::Decimal128(Some(array.value(row_index)), *p, *s)
+        }
+        DataType::Date32 => {
+            let array = column.as_any().downcast_ref::<Date32Array>().unwrap();
+            ScalarValue::Date32(Some(array.value(row_index)))
+        }
+        DataType::Utf8 => {
+            let array = column.as_any().downcast_ref::<StringArray>().unwrap();
+            ScalarValue::Utf8(Some(array.value(row_index).to_string()))
+        }
+        other => panic!("unexpected data type in benchmark: {}", other),
     }
-
-    array_value_to_string(column, row_index).unwrap()
 }
 
 pub async fn transform_actual_result(
@@ -460,7 +479,7 @@ pub async fn transform_actual_result(
                             Expr::Alias(
                                 Box::new(Expr::Cast(Cast::new(
                                     round,
-                                    DataType::Decimal128(38, 2),
+                                    DataType::Decimal128(15, 2),
                                 ))),
                                 Field::name(field).to_string(),
                             )

Reply via email to