This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 7517430676 fix: `generate_series` and `range` panic on edge cases (#9503)
7517430676 is described below

commit 751743067642b6ed8d2050b529881f950dbf9c87
Author: Jonah Gao <[email protected]>
AuthorDate: Sat Mar 9 05:18:06 2024 +0800

    fix: `generate_series` and `range` panic on edge cases (#9503)
    
    * fix: `generate_series` and `range` panic on edge cases
    
    * fix comment
    
    * avoid casting to i128
    
    * add tests
    
    * use target_pointer_width
    
    * remove target_pointer_width
---
 datafusion/functions-array/src/kernels.rs    | 66 +++++++++++++++++++++-------
 datafusion/functions-array/src/udf.rs        |  8 ++--
 datafusion/sqllogictest/test_files/array.slt | 64 +++++++++++++++++++++++++--
 3 files changed, 114 insertions(+), 24 deletions(-)

diff --git a/datafusion/functions-array/src/kernels.rs b/datafusion/functions-array/src/kernels.rs
index 70c778f340..c22ddeb43a 100644
--- a/datafusion/functions-array/src/kernels.rs
+++ b/datafusion/functions-array/src/kernels.rs
@@ -31,7 +31,7 @@ use datafusion_common::cast::{
     as_date32_array, as_int64_array, as_interval_mdn_array, as_large_list_array,
     as_list_array, as_string_array,
 };
-use datafusion_common::{exec_err, DataFusionError, Result};
+use datafusion_common::{exec_err, not_impl_datafusion_err, DataFusionError, Result};
 use std::any::type_name;
 use std::sync::Arc;
 macro_rules! downcast_arg {
@@ -273,7 +273,7 @@ pub(super) fn array_to_string(args: &[ArrayRef]) -> Result<ArrayRef> {
 /// gen_range(3) => [0, 1, 2]
 /// gen_range(1, 4) => [1, 2, 3]
 /// gen_range(1, 7, 2) => [1, 3, 5]
-pub fn gen_range(args: &[ArrayRef], include_upper: i64) -> Result<ArrayRef> {
+pub(super) fn gen_range(args: &[ArrayRef], include_upper: bool) -> Result<ArrayRef> {
     let (start_array, stop_array, step_array) = match args.len() {
         1 => (None, as_int64_array(&args[0])?, None),
         2 => (
@@ -292,22 +292,27 @@ pub fn gen_range(args: &[ArrayRef], include_upper: i64) -> Result<ArrayRef> {
     let mut values = vec![];
     let mut offsets = vec![0];
     for (idx, stop) in stop_array.iter().enumerate() {
-        let mut stop = stop.unwrap_or(0);
+        let stop = stop.unwrap_or(0);
         let start = start_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(0);
         let step = step_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(1);
         if step == 0 {
-            return exec_err!("step can't be 0 for function range(start [, stop, step]");
-        }
-        if step < 0 {
-            // Decreasing range
-            stop -= include_upper;
-            values.extend((stop + 1..start + 1).rev().step_by((-step) as usize));
-        } else {
-            // Increasing range
-            stop += include_upper;
-            values.extend((start..stop).step_by(step as usize));
+            return exec_err!(
+                "step can't be 0 for function {}(start [, stop, step])",
+                if include_upper {
+                    "generate_series"
+                } else {
+                    "range"
+                }
+            );
         }
-
+        // Below, we utilize `usize` to represent steps.
+        // On 32-bit targets, the absolute value of `i64` may fail to fit into `usize`.
+        let step_abs = usize::try_from(step.unsigned_abs()).map_err(|_| {
+            not_impl_datafusion_err!("step {} can't fit into usize", step)
+        })?;
+        values.extend(
+            gen_range_iter(start, stop, step < 0, include_upper).step_by(step_abs),
+        );
         offsets.push(values.len() as i32);
     }
     let arr = Arc::new(ListArray::try_new(
@@ -319,6 +324,35 @@ pub fn gen_range(args: &[ArrayRef], include_upper: i64) -> Result<ArrayRef> {
     Ok(arr)
 }
 
+/// Returns an iterator of i64 values from start to stop
+fn gen_range_iter(
+    start: i64,
+    stop: i64,
+    decreasing: bool,
+    include_upper: bool,
+) -> Box<dyn Iterator<Item = i64>> {
+    match (decreasing, include_upper) {
+        // Decreasing range, stop is inclusive
+        (true, true) => Box::new((stop..=start).rev()),
+        // Decreasing range, stop is exclusive
+        (true, false) => {
+            if stop == i64::MAX {
+                // start is never greater than stop, and stop is exclusive,
+                // so the decreasing range must be empty.
+                Box::new(std::iter::empty())
+            } else {
+                // Increase the stop value by one to exclude it.
+                // Since stop is not i64::MAX, `stop + 1` will not overflow.
+                Box::new((stop + 1..=start).rev())
+            }
+        }
+        // Increasing range, stop is inclusive
+        (false, true) => Box::new(start..=stop),
+        // Increasing range, stop is exclusive
+        (false, false) => Box::new(start..stop),
+    }
+}
+
 /// Returns the length of each array dimension
 fn compute_array_dims(arr: Option<ArrayRef>) -> Result<Option<Vec<Option<u64>>>> {
     let mut value = match arr {
@@ -442,7 +476,7 @@ pub fn array_ndims(args: &[ArrayRef]) -> Result<ArrayRef> {
 }
 pub fn gen_range_date(
     args: &[ArrayRef],
-    include_upper: i32,
+    include_upper: bool,
 ) -> datafusion_common::Result<ArrayRef> {
     if args.len() != 3 {
         return exec_err!("arguments length does not match");
@@ -461,7 +495,7 @@ pub fn gen_range_date(
         let step = step_array.as_ref().map(|arr| arr.value(idx)).unwrap_or(1);
         let (months, days, _) = IntervalMonthDayNanoType::to_parts(step);
         let neg = months < 0 || days < 0;
-        if include_upper == 0 {
+        if !include_upper {
             stop = Date32Type::subtract_month_day_nano(stop, step);
         }
         let mut new_date = start;
diff --git a/datafusion/functions-array/src/udf.rs b/datafusion/functions-array/src/udf.rs
index 709a33cc45..6c69553962 100644
--- a/datafusion/functions-array/src/udf.rs
+++ b/datafusion/functions-array/src/udf.rs
@@ -142,10 +142,10 @@ impl ScalarUDFImpl for Range {
         let args = ColumnarValue::values_to_arrays(args)?;
         match args[0].data_type() {
             arrow::datatypes::DataType::Int64 => {
-                crate::kernels::gen_range(&args, 0).map(ColumnarValue::Array)
+                crate::kernels::gen_range(&args, false).map(ColumnarValue::Array)
             }
             arrow::datatypes::DataType::Date32 => {
-                crate::kernels::gen_range_date(&args, 0).map(ColumnarValue::Array)
+                crate::kernels::gen_range_date(&args, false).map(ColumnarValue::Array)
             }
             _ => {
                 exec_err!("unsupported type for range")
@@ -212,10 +212,10 @@ impl ScalarUDFImpl for GenSeries {
         let args = ColumnarValue::values_to_arrays(args)?;
         match args[0].data_type() {
             arrow::datatypes::DataType::Int64 => {
-                crate::kernels::gen_range(&args, 1).map(ColumnarValue::Array)
+                crate::kernels::gen_range(&args, true).map(ColumnarValue::Array)
             }
             arrow::datatypes::DataType::Date32 => {
-                crate::kernels::gen_range_date(&args, 1).map(ColumnarValue::Array)
+                crate::kernels::gen_range_date(&args, true).map(ColumnarValue::Array)
             }
             _ => {
                 exec_err!("unsupported type for range")
diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt
index 68a7a34746..434fe8c959 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -5554,10 +5554,11 @@ from arrays_range;
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] [3, 4, 5, 6, 7, 8, 9] [3, 5, 7, 9]
 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] [4, 5, 6, 7, 8, 9, 10, 11, 12] [4, 7, 10]
 
-query ??????????
+query ???????????
 select range(5),
        range(2, 5),
        range(2, 10, 3),
+       range(10, 2, -3),
        range(1, 5, -1),
        range(1, -5, 1),
        range(1, -5, -1),
@@ -5567,7 +5568,35 @@ select range(5),
        range(DATE '1993-03-01', DATE '1989-04-01', INTERVAL '1' YEAR)
 ;
 ----
-[0, 1, 2, 3, 4] [2, 3, 4] [2, 5, 8] [] [] [1, 0, -1, -2, -3, -4] [1992-09-01, 
1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01] [1993-02-01, 
1993-01-31, 1993-01-30, 1993-01-29, 1993-01-28, 1993-01-27, 1993-01-26, 
1993-01-25, 1993-01-24, 1993-01-23, 1993-01-22, 1993-01-21, 1993-01-20, 
1993-01-19, 1993-01-18, 1993-01-17, 1993-01-16, 1993-01-15, 1993-01-14, 
1993-01-13, 1993-01-12, 1993-01-11, 1993-01-10, 1993-01-09, 1993-01-08, 
1993-01-07, 1993-01-06, 1993-01-05, 1993-01-04, 1993- [...]
+[0, 1, 2, 3, 4] [2, 3, 4] [2, 5, 8] [10, 7, 4] [] [] [1, 0, -1, -2, -3, -4] 
[1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01] 
[1993-02-01, 1993-01-31, 1993-01-30, 1993-01-29, 1993-01-28, 1993-01-27, 
1993-01-26, 1993-01-25, 1993-01-24, 1993-01-23, 1993-01-22, 1993-01-21, 
1993-01-20, 1993-01-19, 1993-01-18, 1993-01-17, 1993-01-16, 1993-01-15, 
1993-01-14, 1993-01-13, 1993-01-12, 1993-01-11, 1993-01-10, 1993-01-09, 
1993-01-08, 1993-01-07, 1993-01-06, 1993-01-05, 1993-0 [...]
+
+# Test range with zero step
+query error DataFusion error: Execution error: step can't be 0 for function range\(start \[, stop, step\]\)
+select range(1, 1, 0);
+
+# Test range with big steps
+query ????
+select
+  range(-9223372036854775808, -9223372036854775808, -9223372036854775808) as c1,
+  range(9223372036854775807, 9223372036854775807, 9223372036854775807) as c2,
+  range(0, -9223372036854775808, -9223372036854775808) as c3,
+  range(0, 9223372036854775807, 9223372036854775807) as c4;
+----
+[] [] [0] [0]
+
+# Test range for other edge cases
+query ????????
+select 
+  range(9223372036854775807, 9223372036854775807, -1) as c1,
+  range(9223372036854775807, 9223372036854775806, -1) as c2,
+  range(9223372036854775807, 9223372036854775807, 1) as c3,
+  range(9223372036854775806, 9223372036854775807, 1) as c4,
+  range(-9223372036854775808, -9223372036854775808, -1) as c5,
+  range(-9223372036854775807, -9223372036854775808, -1) as c6,
+  range(-9223372036854775808, -9223372036854775808, 1) as c7,
+  range(-9223372036854775808, -9223372036854775807, 1) as c8;
+----
+[] [9223372036854775807] [] [9223372036854775806] [] [-9223372036854775807] [] [-9223372036854775808]
 
 ## should throw error
 query error
@@ -5589,18 +5618,19 @@ select range(DATE '1993-03-01', DATE '1989-04-01', INTERVAL '1' YEAR)
 ----
 []
 
-query ????????
+query ?????????
 select generate_series(5),
        generate_series(2, 5),
        generate_series(2, 10, 3),
        generate_series(1, 5, 1),
        generate_series(5, 1, -1),
+       generate_series(10, 2, -3),
        generate_series(DATE '1992-09-01', DATE '1993-03-01', INTERVAL '1' MONTH),
        generate_series(DATE '1993-02-01', DATE '1993-01-01', INTERVAL '-1' DAY),
        generate_series(DATE '1989-04-01', DATE '1993-03-01', INTERVAL '1' YEAR)
 ;
 ----
-[0, 1, 2, 3, 4, 5] [2, 3, 4, 5] [2, 5, 8] [1, 2, 3, 4, 5] [5, 4, 3, 2, 1] 
[1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01, 
1993-03-01] [1993-02-01, 1993-01-31, 1993-01-30, 1993-01-29, 1993-01-28, 
1993-01-27, 1993-01-26, 1993-01-25, 1993-01-24, 1993-01-23, 1993-01-22, 
1993-01-21, 1993-01-20, 1993-01-19, 1993-01-18, 1993-01-17, 1993-01-16, 
1993-01-15, 1993-01-14, 1993-01-13, 1993-01-12, 1993-01-11, 1993-01-10, 
1993-01-09, 1993-01-08, 1993-01-07, 1993-01-06, 1993-01- [...]
+[0, 1, 2, 3, 4, 5] [2, 3, 4, 5] [2, 5, 8] [1, 2, 3, 4, 5] [5, 4, 3, 2, 1] [10, 
7, 4] [1992-09-01, 1992-10-01, 1992-11-01, 1992-12-01, 1993-01-01, 1993-02-01, 
1993-03-01] [1993-02-01, 1993-01-31, 1993-01-30, 1993-01-29, 1993-01-28, 
1993-01-27, 1993-01-26, 1993-01-25, 1993-01-24, 1993-01-23, 1993-01-22, 
1993-01-21, 1993-01-20, 1993-01-19, 1993-01-18, 1993-01-17, 1993-01-16, 
1993-01-15, 1993-01-14, 1993-01-13, 1993-01-12, 1993-01-11, 1993-01-10, 
1993-01-09, 1993-01-08, 1993-01-07, 1993-01-0 [...]
 
 ## should throw error
 query error
@@ -5623,6 +5653,32 @@ select generate_series(DATE '1993-03-01', DATE '1989-04-01', INTERVAL '1' YEAR)
 ----
 []
 
+# Test generate_series with zero step
+query error DataFusion error: Execution error: step can't be 0 for function generate_series\(start \[, stop, step\]\)
+select generate_series(1, 1, 0);
+
+# Test generate_series with big steps
+query ????
+select
+  generate_series(-9223372036854775808, -9223372036854775808, -9223372036854775808) as c1,
+  generate_series(9223372036854775807, 9223372036854775807, 9223372036854775807) as c2,
+  generate_series(0, -9223372036854775808, -9223372036854775808) as c3,
+  generate_series(0, 9223372036854775807, 9223372036854775807) as c4;
+----
+[-9223372036854775808] [9223372036854775807] [0, -9223372036854775808] [0, 9223372036854775807]
+
+
+# Test generate_series for other edge cases
+query ????
+select 
+  generate_series(9223372036854775807, 9223372036854775807, -1) as c1,
+  generate_series(9223372036854775807, 9223372036854775807, 1) as c2,
+  generate_series(-9223372036854775808, -9223372036854775808, -1) as c3,
+  generate_series(-9223372036854775808, -9223372036854775808, 1) as c4;
+----
+[9223372036854775807] [9223372036854775807] [-9223372036854775808] [-9223372036854775808]
+
+
 ## array_except
 
 statement ok

Reply via email to