This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 5537572820 fix: `substr_index` not handling negative occurrence correctly (#9475)
5537572820 is described below
commit 5537572820977b38719e2253f601a159deef5bc6
Author: Jonah Gao <[email protected]>
AuthorDate: Sat Mar 9 19:18:26 2024 +0800
fix: `substr_index` not handling negative occurrence correctly (#9475)
* fix: `substr_index` not handling negative occurrence correctly
* format test
* add test
* add more tests
---
.../physical-expr/src/unicode_expressions.rs | 54 +++---
datafusion/sqllogictest/test_files/functions.slt | 190 +++++++++++++--------
2 files changed, 141 insertions(+), 103 deletions(-)
diff --git a/datafusion/physical-expr/src/unicode_expressions.rs b/datafusion/physical-expr/src/unicode_expressions.rs
index aa6a84119c..8ec9e062d9 100644
--- a/datafusion/physical-expr/src/unicode_expressions.rs
+++ b/datafusion/physical-expr/src/unicode_expressions.rs
@@ -481,40 +481,28 @@ pub fn substr_index<T: OffsetSizeTrait>(args:
&[ArrayRef]) -> Result<ArrayRef> {
.zip(count_array.iter())
.map(|((string, delimiter), n)| match (string, delimiter, n) {
(Some(string), Some(delimiter), Some(n)) => {
- let mut res = String::new();
- match n {
- 0 => {
- "".to_string();
- }
- _other => {
- if n > 0 {
- let idx = string
- .split(delimiter)
- .take(n as usize)
- .fold(0, |len, x| len + x.len() +
delimiter.len())
- - delimiter.len();
- res.push_str(if idx >= string.len() {
- string
- } else {
- &string[..idx]
- });
- } else {
- let idx = (string.split(delimiter).take((-n) as
usize).fold(
- string.len() as isize,
- |len, x| {
- len - x.len() as isize - delimiter.len()
as isize
- },
- ) + delimiter.len() as isize)
- as usize;
- res.push_str(if idx >= string.len() {
- string
- } else {
- &string[idx..]
- });
- }
- }
+ // In MySQL, these cases will return an empty string.
+ if n == 0 || string.is_empty() || delimiter.is_empty() {
+ return Some(String::new());
+ }
+
+ let splitted: Box<dyn Iterator<Item = _>> = if n > 0 {
+ Box::new(string.split(delimiter))
+ } else {
+ Box::new(string.rsplit(delimiter))
+ };
+ let occurrences =
usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
+ // The length of the substring covered by substr_index.
+ let length = splitted
+ .take(occurrences) // at least 1 element, since n != 0
+ .map(|s| s.len() + delimiter.len())
+ .sum::<usize>()
+ - delimiter.len();
+ if n > 0 {
+ Some(string[..length].to_owned())
+ } else {
+ Some(string[string.len() - length..].to_owned())
}
- Some(res)
}
_ => None,
})
diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt
index 913cfbafb6..96aa3e2752 100644
--- a/datafusion/sqllogictest/test_files/functions.slt
+++ b/datafusion/sqllogictest/test_files/functions.slt
@@ -933,80 +933,130 @@ SELECT levenshtein(NULL, NULL)
----
NULL
-query T
-SELECT substr_index('www.apache.org', '.', 1)
-----
-www
-
-query T
-SELECT substr_index('www.apache.org', '.', 2)
-----
-www.apache
-
-query T
-SELECT substr_index('www.apache.org', '.', -1)
-----
-org
-
-query T
-SELECT substr_index('www.apache.org', '.', -2)
-----
-apache.org
-
-query T
-SELECT substr_index('www.apache.org', 'ac', 1)
-----
-www.ap
-
-query T
-SELECT substr_index('www.apache.org', 'ac', -1)
-----
-he.org
-
-query T
-SELECT substr_index('www.apache.org', 'ac', 2)
-----
-www.apache.org
-
-query T
-SELECT substr_index('www.apache.org', 'ac', -2)
-----
-www.apache.org
-
-query ?
-SELECT substr_index(NULL, 'ac', 1)
-----
-NULL
-
-query T
-SELECT substr_index('www.apache.org', NULL, 1)
-----
-NULL
-
-query T
-SELECT substr_index('www.apache.org', 'ac', NULL)
-----
-NULL
-
-query T
-SELECT substr_index('', 'ac', 1)
+# Test substring_index using '.' as delimiter
+# This query is compatible with MySQL(8.0.19 or later), convenient for comparing results
+query TIT
+SELECT str, n, substring_index(str, '.', n) AS c FROM
+ (VALUES
+ ROW('arrow.apache.org'),
+ ROW('.'),
+ ROW('...')
+ ) AS strings(str),
+ (VALUES
+ ROW(1),
+ ROW(2),
+ ROW(3),
+ ROW(100),
+ ROW(-1),
+ ROW(-2),
+ ROW(-3),
+ ROW(-100)
+ ) AS occurrences(n)
+ORDER BY str DESC, n;
+----
+arrow.apache.org -100 arrow.apache.org
+arrow.apache.org -3 arrow.apache.org
+arrow.apache.org -2 apache.org
+arrow.apache.org -1 org
+arrow.apache.org 1 arrow
+arrow.apache.org 2 arrow.apache
+arrow.apache.org 3 arrow.apache.org
+arrow.apache.org 100 arrow.apache.org
+... -100 ...
+... -3 ..
+... -2 .
+... -1 (empty)
+... 1 (empty)
+... 2 .
+... 3 ..
+... 100 ...
+. -100 .
+. -3 .
+. -2 .
+. -1 (empty)
+. 1 (empty)
+. 2 .
+. 3 .
+. 100 .
+
+# Test substring_index using 'ac' as delimiter
+query TIT
+SELECT str, n, substring_index(str, 'ac', n) AS c FROM
+ (VALUES
+ -- input string does not contain the delimiter
+ ROW('arrow'),
+ -- input string contains the delimiter
+ ROW('arrow.apache.org')
+ ) AS strings(str),
+ (VALUES
+ ROW(1),
+ ROW(2),
+ ROW(-1),
+ ROW(-2)
+ ) AS occurrences(n)
+ORDER BY str DESC, n;
+----
+arrow.apache.org -2 arrow.apache.org
+arrow.apache.org -1 he.org
+arrow.apache.org 1 arrow.ap
+arrow.apache.org 2 arrow.apache.org
+arrow -2 arrow
+arrow -1 arrow
+arrow 1 arrow
+arrow 2 arrow
+
+# Test substring_index with NULL values
+query ?TT?
+SELECT
+ substring_index(NULL, '.', 1),
+ substring_index('arrow.apache.org', NULL, 1),
+ substring_index('arrow.apache.org', '.', NULL),
+ substring_index(NULL, NULL, NULL)
+----
+NULL NULL NULL NULL
+
+# Test substring_index with empty strings
+query TT
+SELECT
+ -- input string is empty
+ substring_index('', '.', 1),
+ -- delimiter is empty
+ substring_index('arrow.apache.org', '', 1)
+----
+(empty) (empty)
+
+# Test substring_index with 0 occurrence
+query T
+SELECT substring_index('arrow.apache.org', 'ac', 0)
----
(empty)
-query T
-SELECT substr_index('www.apache.org', '', 1)
-----
-(empty)
+# Test substring_index with large occurrences
+query TT
+SELECT
+ -- i64::MIN
+ substring_index('arrow.apache.org', '.', -9223372036854775808) as c1,
+ -- i64::MAX
+ substring_index('arrow.apache.org', '.', 9223372036854775807) as c2;
+----
+arrow.apache.org arrow.apache.org
+
+# Test substring_index issue https://github.com/apache/arrow-datafusion/issues/9472
+query TTT
+SELECT
+ url,
+ substring_index(url, '.', 1) AS subdomain,
+ substring_index(url, '.', -1) AS tld
+FROM
+ (VALUES ROW('docs.apache.com'),
+ ROW('community.influxdata.com'),
+ ROW('arrow.apache.org')
+ ) data(url)
+----
+docs.apache.com docs com
+community.influxdata.com community com
+arrow.apache.org arrow org
-query T
-SELECT substr_index('www.apache.org', 'ac', 0)
-----
-(empty)
-
-query ?
-SELECT substr_index(NULL, NULL, NULL)
-----
-NULL
query I
SELECT find_in_set('b', 'a,b,c,d')