This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 5537572820 fix: `substr_index` not handling negative occurrence correctly (#9475)
5537572820 is described below

commit 5537572820977b38719e2253f601a159deef5bc6
Author: Jonah Gao <[email protected]>
AuthorDate: Sat Mar 9 19:18:26 2024 +0800

    fix: `substr_index` not handling negative occurrence correctly (#9475)
    
    * fix: `substr_index` not handling negative occurrence correctly
    
    * format test
    
    * add test
    
    * add more tests
---
 .../physical-expr/src/unicode_expressions.rs       |  54 +++---
 datafusion/sqllogictest/test_files/functions.slt   | 190 +++++++++++++--------
 2 files changed, 141 insertions(+), 103 deletions(-)

diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs
index aa6a84119c..8ec9e062d9 100644
--- a/datafusion/physical-expr/src/unicode_expressions.rs
+++ b/datafusion/physical-expr/src/unicode_expressions.rs
@@ -481,40 +481,28 @@ pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
         .zip(count_array.iter())
         .map(|((string, delimiter), n)| match (string, delimiter, n) {
             (Some(string), Some(delimiter), Some(n)) => {
-                let mut res = String::new();
-                match n {
-                    0 => {
-                        "".to_string();
-                    }
-                    _other => {
-                        if n > 0 {
-                            let idx = string
-                                .split(delimiter)
-                                .take(n as usize)
-                                .fold(0, |len, x| len + x.len() + delimiter.len())
-                                - delimiter.len();
-                            res.push_str(if idx >= string.len() {
-                                string
-                            } else {
-                                &string[..idx]
-                            });
-                        } else {
-                            let idx = (string.split(delimiter).take((-n) as usize).fold(
-                                string.len() as isize,
-                                |len, x| {
-                                    len - x.len() as isize - delimiter.len() as isize
-                                },
-                            ) + delimiter.len() as isize)
-                                as usize;
-                            res.push_str(if idx >= string.len() {
-                                string
-                            } else {
-                                &string[idx..]
-                            });
-                        }
-                    }
+                // In MySQL, these cases will return an empty string.
+                if n == 0 || string.is_empty() || delimiter.is_empty() {
+                    return Some(String::new());
+                }
+
+                let splitted: Box<dyn Iterator<Item = _>> = if n > 0 {
+                    Box::new(string.split(delimiter))
+                } else {
+                    Box::new(string.rsplit(delimiter))
+                };
+                let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
+                // The length of the substring covered by substr_index.
+                let length = splitted
+                    .take(occurrences) // at least 1 element, since n != 0
+                    .map(|s| s.len() + delimiter.len())
+                    .sum::<usize>()
+                    - delimiter.len();
+                if n > 0 {
+                    Some(string[..length].to_owned())
+                } else {
+                    Some(string[string.len() - length..].to_owned())
                 }
-                Some(res)
             }
             _ => None,
         })
diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt
index 913cfbafb6..96aa3e2752 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -933,80 +933,130 @@ SELECT levenshtein(NULL, NULL)
 ----
 NULL
 
-query T
-SELECT substr_index('www.apache.org', '.', 1)
-----
-www
-
-query T
-SELECT substr_index('www.apache.org', '.', 2)
-----
-www.apache
-
-query T
-SELECT substr_index('www.apache.org', '.', -1)
-----
-org
-
-query T
-SELECT substr_index('www.apache.org', '.', -2)
-----
-apache.org
-
-query T
-SELECT substr_index('www.apache.org', 'ac', 1)
-----
-www.ap
-
-query T
-SELECT substr_index('www.apache.org', 'ac', -1)
-----
-he.org
-
-query T
-SELECT substr_index('www.apache.org', 'ac', 2)
-----
-www.apache.org
-
-query T
-SELECT substr_index('www.apache.org', 'ac', -2)
-----
-www.apache.org
-
-query ?
-SELECT substr_index(NULL, 'ac', 1)
-----
-NULL
-
-query T
-SELECT substr_index('www.apache.org', NULL, 1)
-----
-NULL
-
-query T
-SELECT substr_index('www.apache.org', 'ac', NULL)
-----
-NULL
-
-query T
-SELECT substr_index('', 'ac', 1)
+# Test substring_index using '.' as delimiter
+# This query is compatible with MySQL(8.0.19 or later), convenient for comparing results
+query TIT
+SELECT str, n, substring_index(str, '.', n) AS c FROM
+  (VALUES
+    ROW('arrow.apache.org'),
+    ROW('.'),
+    ROW('...')
+  ) AS strings(str),
+  (VALUES
+    ROW(1),
+    ROW(2),
+    ROW(3),
+    ROW(100),
+    ROW(-1),
+    ROW(-2),
+    ROW(-3),
+    ROW(-100)
+  ) AS occurrences(n)
+ORDER BY str DESC, n;
+----
+arrow.apache.org -100 arrow.apache.org
+arrow.apache.org -3 arrow.apache.org
+arrow.apache.org -2 apache.org
+arrow.apache.org -1 org
+arrow.apache.org 1 arrow
+arrow.apache.org 2 arrow.apache
+arrow.apache.org 3 arrow.apache.org
+arrow.apache.org 100 arrow.apache.org
+... -100 ...
+... -3 ..
+... -2 .
+... -1 (empty)
+... 1 (empty)
+... 2 .
+... 3 ..
+... 100 ...
+. -100 .
+. -3 .
+. -2 .
+. -1 (empty)
+. 1 (empty)
+. 2 .
+. 3 .
+. 100 .
+
+# Test substring_index using 'ac' as delimiter
+query TIT
+SELECT str, n, substring_index(str, 'ac', n) AS c FROM
+  (VALUES
+    -- input string does not contain the delimiter
+    ROW('arrow'),
+    -- input string contains the delimiter
+    ROW('arrow.apache.org')
+  ) AS strings(str),
+  (VALUES
+    ROW(1),
+    ROW(2),
+    ROW(-1),
+    ROW(-2)
+  ) AS occurrences(n)
+ORDER BY str DESC, n;
+----
+arrow.apache.org -2 arrow.apache.org
+arrow.apache.org -1 he.org
+arrow.apache.org 1 arrow.ap
+arrow.apache.org 2 arrow.apache.org
+arrow -2 arrow
+arrow -1 arrow
+arrow 1 arrow
+arrow 2 arrow
+
+# Test substring_index with NULL values
+query ?TT?
+SELECT
+  substring_index(NULL, '.', 1),
+  substring_index('arrow.apache.org', NULL, 1),
+  substring_index('arrow.apache.org', '.', NULL),
+  substring_index(NULL, NULL, NULL)
+----
+NULL NULL NULL NULL
+
+# Test substring_index with empty strings
+query TT
+SELECT
+  -- input string is empty
+  substring_index('', '.', 1),
+  -- delimiter is empty
+  substring_index('arrow.apache.org', '', 1)
+----
+(empty) (empty)
+
+# Test substring_index with 0 occurrence
+query T
+SELECT substring_index('arrow.apache.org', 'ac', 0)
 ----
 (empty)
 
-query T
-SELECT substr_index('www.apache.org', '', 1)
-----
-(empty)
+# Test substring_index with large occurrences
+query TT
+SELECT
+  -- i64::MIN
+  substring_index('arrow.apache.org', '.', -9223372036854775808) as c1,
+  -- i64::MAX
+  substring_index('arrow.apache.org', '.', 9223372036854775807) as c2;
+----
+arrow.apache.org arrow.apache.org
+
+# Test substring_index issue https://github.com/apache/arrow-datafusion/issues/9472
+query TTT
+SELECT
+  url,
+  substring_index(url, '.', 1) AS subdomain,
+  substring_index(url, '.', -1) AS tld
+FROM
+  (VALUES ROW('docs.apache.com'),
+          ROW('community.influxdata.com'),
+          ROW('arrow.apache.org')
+  ) data(url)
+----
+docs.apache.com docs com
+community.influxdata.com community com
+arrow.apache.org arrow org
 
-query T
-SELECT substr_index('www.apache.org', 'ac', 0)
-----
-(empty)
-
-query ?
-SELECT substr_index(NULL, NULL, NULL)
-----
-NULL
 
 query I
 SELECT find_in_set('b', 'a,b,c,d')

Reply via email to