This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6c40214f3d93 [SPARK-47350][SQL] Add collation support for SplitPart 
string expression
6c40214f3d93 is described below

commit 6c40214f3d93907686ed731caa3d572a9fa93d53
Author: Uros Bojanic <157381213+uros...@users.noreply.github.com>
AuthorDate: Fri Apr 26 19:57:11 2024 +0800

    [SPARK-47350][SQL] Add collation support for SplitPart string expression
    
    ### What changes were proposed in this pull request?
    Introduce collation awareness for string expression: split_part.
    
    ### Why are the changes needed?
    Add collation support for built-in string function in Spark.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, users should now be able to use collated strings within arguments for 
built-in string function: split_part.
    
    ### How was this patch tested?
    Unit tests for collation support and end-to-end SQL tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #46158 from uros-db/SPARK-47350.
    
    Authored-by: Uros Bojanic <157381213+uros...@users.noreply.github.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/util/CollationSupport.java  | 58 +++++++++++++++
 .../spark/unsafe/types/CollationSupportSuite.java  | 86 ++++++++++++++++++++++
 .../catalyst/expressions/stringExpressions.scala   | 15 ++--
 .../sql/CollationStringExpressionsSuite.scala      | 17 +++++
 4 files changed, 170 insertions(+), 6 deletions(-)

diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
index 70a3f5bd6136..0c03faa0d23a 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java
@@ -23,6 +23,8 @@ import com.ibm.icu.util.ULocale;
 
 import org.apache.spark.unsafe.types.UTF8String;
 
+import java.util.ArrayList;
+import java.util.List;
 import java.util.regex.Pattern;
 
 /**
@@ -36,6 +38,62 @@ public final class CollationSupport {
    * Collation-aware string expressions.
    */
 
+  /**
+   * Collation-aware implementation of the SQL string split used by the
+   * StringSplitSQL expression: splits string `s` on every occurrence of
+   * delimiter `d`, comparing substrings according to the given collation.
+   */
+  public static class StringSplitSQL {
+    // Dispatches to the cheapest implementation the collation permits:
+    // binary byte equality, lowercase (case-insensitive) equality, or a
+    // full ICU collation-aware search.
+    public static UTF8String[] exec(final UTF8String s, final UTF8String d, 
final int collationId) {
+      CollationFactory.Collation collation = 
CollationFactory.fetchCollation(collationId);
+      if (collation.supportsBinaryEquality) {
+        return execBinary(s, d);
+      } else if (collation.supportsLowercaseEquality) {
+        return execLowercase(s, d);
+      } else {
+        return execICU(s, d, collationId);
+      }
+    }
+    // Codegen twin of exec(): returns the Java source snippet that invokes
+    // the matching exec* method. Here `s` and `d` are generated variable
+    // names (code fragments), not runtime values.
+    public static String genCode(final String s, final String d, final int 
collationId) {
+      CollationFactory.Collation collation = 
CollationFactory.fetchCollation(collationId);
+      String expr = "CollationSupport.StringSplitSQL.exec";
+      if (collation.supportsBinaryEquality) {
+        return String.format(expr + "Binary(%s, %s)", s, d);
+      } else if (collation.supportsLowercaseEquality) {
+        return String.format(expr + "Lowercase(%s, %s)", s, d);
+      } else {
+        return String.format(expr + "ICU(%s, %s, %d)", s, d, collationId);
+      }
+    }
+    // Binary collations: delegate to UTF8String's byte-wise SQL split.
+    public static UTF8String[] execBinary(final UTF8String string, final 
UTF8String delimiter) {
+      return string.splitSQL(delimiter, -1);
+    }
+    // Lowercase collations: split using a case-insensitive regex built from
+    // the delimiter quoted as a literal. A negative limit makes
+    // Pattern.split keep trailing empty strings.
+    public static UTF8String[] execLowercase(final UTF8String string, final 
UTF8String delimiter) {
+      // Empty delimiter: no split; empty string: single empty element.
+      if (delimiter.numBytes() == 0) return new UTF8String[] { string };
+      if (string.numBytes() == 0) return new UTF8String[] { 
UTF8String.EMPTY_UTF8 };
+      Pattern pattern = Pattern.compile(Pattern.quote(delimiter.toString()),
+        CollationSupport.lowercaseRegexFlags);
+      String[] splits = pattern.split(string.toString(), -1);
+      UTF8String[] res = new UTF8String[splits.length];
+      for (int i = 0; i < res.length; i++) {
+        res[i] = UTF8String.fromString(splits[i]);
+      }
+      return res;
+    }
+    // All other collations: scan with an ICU StringSearch and cut the target
+    // at every collation-aware match of the delimiter.
+    public static UTF8String[] execICU(final UTF8String string, final 
UTF8String delimiter,
+        final int collationId) {
+      // Same edge-case contract as execLowercase.
+      if (delimiter.numBytes() == 0) return new UTF8String[] { string };
+      if (string.numBytes() == 0) return new UTF8String[] { 
UTF8String.EMPTY_UTF8 };
+      List<UTF8String> strings = new ArrayList<>();
+      String target = string.toString(), pattern = delimiter.toString();
+      StringSearch stringSearch = CollationFactory.getStringSearch(target, 
pattern, collationId);
+      int start = 0, end;
+      while ((end = stringSearch.next()) != StringSearch.DONE) {
+        strings.add(UTF8String.fromString(target.substring(start, end)));
+        start = end + stringSearch.getMatchLength();
+      }
+      // Trailing segment after the last match (possibly empty, so trailing
+      // matches also produce an empty element).
+      if (start <= target.length()) {
+        strings.add(UTF8String.fromString(target.substring(start)));
+      }
+      return strings.toArray(new UTF8String[0]);
+    }
+  }
+
   public static class Contains {
     public static boolean exec(final UTF8String l, final UTF8String r, final 
int collationId) {
       CollationFactory.Collation collation = 
CollationFactory.fetchCollation(collationId);
diff --git 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
index d59bd5c20e67..72edd3e06f9c 100644
--- 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
+++ 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java
@@ -295,6 +295,92 @@ public class CollationSupportSuite {
     assertEndsWith("The KKelvin", "KKelvin,", "UTF8_BINARY_LCASE", false);
   }
 
+  // Helper: asserts that splitting `str` by `delimiter` under the named
+  // collation yields exactly the `expected` array of segments.
+  private void assertStringSplitSQL(String str, String delimiter, String 
collationName,
+      UTF8String[] expected) throws SparkException {
+    UTF8String s = UTF8String.fromString(str);
+    UTF8String d = UTF8String.fromString(delimiter);
+    int collationId = CollationFactory.collationNameToId(collationName);
+    assertArrayEquals(expected, CollationSupport.StringSplitSQL.exec(s, d, 
collationId));
+  }
+
+  // Exercises StringSplitSQL.exec across UTF8_BINARY, UTF8_BINARY_LCASE,
+  // UNICODE and UNICODE_CI, covering empty inputs, case variation, accent
+  // variation, and multi-byte (variable byte length) characters.
+  @Test
+  public void testStringSplitSQL() throws SparkException {
+    // Possible splits (expected result arrays reused by the cases below)
+    var empty_match = new UTF8String[] { UTF8String.fromString("") };
+    var array_abc = new UTF8String[] { UTF8String.fromString("abc") };
+    var array_1a2 = new UTF8String[] { UTF8String.fromString("1a2") };
+    var array_AaXbB = new UTF8String[] { UTF8String.fromString("AaXbB") };
+    var array_aBcDe = new UTF8String[] { UTF8String.fromString("aBcDe") };
+    var array_special = new UTF8String[] { UTF8String.fromString("äb世De") };
+    var array_abcde = new UTF8String[] { UTF8String.fromString("äbćδe") };
+    var full_match = new UTF8String[] { UTF8String.fromString(""), 
UTF8String.fromString("") };
+    var array_1_2 = new UTF8String[] { UTF8String.fromString("1"), 
UTF8String.fromString("2") };
+    var array_A_B = new UTF8String[] { UTF8String.fromString("A"), 
UTF8String.fromString("B") };
+    var array_a_e = new UTF8String[] { UTF8String.fromString("ä"), 
UTF8String.fromString("e") };
+    var array_Aa_bB = new UTF8String[] { UTF8String.fromString("Aa"), 
UTF8String.fromString("bB") };
+    // Edge cases (empty string and/or empty delimiter)
+    assertStringSplitSQL("", "", "UTF8_BINARY", empty_match);
+    assertStringSplitSQL("abc", "", "UTF8_BINARY", array_abc);
+    assertStringSplitSQL("", "abc", "UTF8_BINARY", empty_match);
+    assertStringSplitSQL("", "", "UNICODE", empty_match);
+    assertStringSplitSQL("abc", "", "UNICODE", array_abc);
+    assertStringSplitSQL("", "abc", "UNICODE", empty_match);
+    assertStringSplitSQL("", "", "UTF8_BINARY_LCASE", empty_match);
+    assertStringSplitSQL("abc", "", "UTF8_BINARY_LCASE", array_abc);
+    assertStringSplitSQL("", "abc", "UTF8_BINARY_LCASE", empty_match);
+    assertStringSplitSQL("", "", "UNICODE_CI", empty_match);
+    assertStringSplitSQL("abc", "", "UNICODE_CI", array_abc);
+    assertStringSplitSQL("", "abc", "UNICODE_CI", empty_match);
+    // Basic tests
+    assertStringSplitSQL("1a2", "a", "UTF8_BINARY", array_1_2);
+    assertStringSplitSQL("1a2", "A", "UTF8_BINARY", array_1a2);
+    assertStringSplitSQL("1a2", "b", "UTF8_BINARY", array_1a2);
+    assertStringSplitSQL("1a2", "1a2", "UNICODE", full_match);
+    assertStringSplitSQL("1a2", "1A2", "UNICODE", array_1a2);
+    assertStringSplitSQL("1a2", "3b4", "UNICODE", array_1a2);
+    assertStringSplitSQL("1a2", "A", "UTF8_BINARY_LCASE", array_1_2);
+    assertStringSplitSQL("1a2", "1A2", "UTF8_BINARY_LCASE", full_match);
+    assertStringSplitSQL("1a2", "X", "UTF8_BINARY_LCASE", array_1a2);
+    assertStringSplitSQL("1a2", "a", "UNICODE_CI", array_1_2);
+    assertStringSplitSQL("1a2", "A", "UNICODE_CI", array_1_2);
+    assertStringSplitSQL("1a2", "1A2", "UNICODE_CI", full_match);
+    assertStringSplitSQL("1a2", "123", "UNICODE_CI", array_1a2);
+    // Case variation (delimiter differs from target only by letter case)
+    assertStringSplitSQL("AaXbB", "x", "UTF8_BINARY", array_AaXbB);
+    assertStringSplitSQL("AaXbB", "X", "UTF8_BINARY", array_Aa_bB);
+    assertStringSplitSQL("AaXbB", "axb", "UNICODE", array_AaXbB);
+    assertStringSplitSQL("AaXbB", "aXb", "UNICODE", array_A_B);
+    assertStringSplitSQL("AaXbB", "axb", "UTF8_BINARY_LCASE", array_A_B);
+    assertStringSplitSQL("AaXbB", "AXB", "UTF8_BINARY_LCASE", array_A_B);
+    assertStringSplitSQL("AaXbB", "axb", "UNICODE_CI", array_A_B);
+    assertStringSplitSQL("AaXbB", "AxB", "UNICODE_CI", array_A_B);
+    // Accent variation (delimiter differs only by diacritics)
+    assertStringSplitSQL("aBcDe", "bćd", "UTF8_BINARY", array_aBcDe);
+    assertStringSplitSQL("aBcDe", "BćD", "UTF8_BINARY", array_aBcDe);
+    assertStringSplitSQL("aBcDe", "abćde", "UNICODE", array_aBcDe);
+    assertStringSplitSQL("aBcDe", "aBćDe", "UNICODE", array_aBcDe);
+    assertStringSplitSQL("aBcDe", "bćd", "UTF8_BINARY_LCASE", array_aBcDe);
+    assertStringSplitSQL("aBcDe", "BĆD", "UTF8_BINARY_LCASE", array_aBcDe);
+    assertStringSplitSQL("aBcDe", "abćde", "UNICODE_CI", array_aBcDe);
+    assertStringSplitSQL("aBcDe", "AbĆdE", "UNICODE_CI", array_aBcDe);
+    // Variable byte length characters (multi-byte UTF-8 code points)
+    assertStringSplitSQL("äb世De", "b世D", "UTF8_BINARY", array_a_e);
+    assertStringSplitSQL("äb世De", "B世d", "UTF8_BINARY", array_special);
+    assertStringSplitSQL("äbćδe", "bćδ", "UTF8_BINARY", array_a_e);
+    assertStringSplitSQL("äbćδe", "BcΔ", "UTF8_BINARY", array_abcde);
+    assertStringSplitSQL("äb世De", "äb世De", "UNICODE", full_match);
+    assertStringSplitSQL("äb世De", "äB世de", "UNICODE", array_special);
+    assertStringSplitSQL("äbćδe", "äbćδe", "UNICODE", full_match);
+    assertStringSplitSQL("äbćδe", "ÄBcΔÉ", "UNICODE", array_abcde);
+    assertStringSplitSQL("äb世De", "b世D", "UTF8_BINARY_LCASE", array_a_e);
+    assertStringSplitSQL("äb世De", "B世d", "UTF8_BINARY_LCASE", array_a_e);
+    assertStringSplitSQL("äbćδe", "bćδ", "UTF8_BINARY_LCASE", array_a_e);
+    assertStringSplitSQL("äbćδe", "BcΔ", "UTF8_BINARY_LCASE", array_abcde);
+    assertStringSplitSQL("äb世De", "ab世De", "UNICODE_CI", array_special);
+    assertStringSplitSQL("äb世De", "AB世dE", "UNICODE_CI", array_special);
+    assertStringSplitSQL("äbćδe", "ÄbćδE", "UNICODE_CI", full_match);
+    assertStringSplitSQL("äbćδe", "ÄBcΔÉ", "UNICODE_CI", array_abcde);
+  }
 
   private void assertUpper(String target, String collationName, String 
expected)
           throws SparkException {
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index fd4fc7a54229..612082c56096 100755
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -3187,13 +3187,14 @@ case class Sentences(
 case class StringSplitSQL(
     str: Expression,
     delimiter: Expression) extends BinaryExpression with NullIntolerant {
-  override def dataType: DataType = ArrayType(StringType, containsNull = false)
+  override def dataType: DataType = ArrayType(str.dataType, containsNull = 
false)
+  final lazy val collationId: Int = 
left.dataType.asInstanceOf[StringType].collationId
   override def left: Expression = str
   override def right: Expression = delimiter
 
   override def nullSafeEval(string: Any, delimiter: Any): Any = {
-    val strings = string.asInstanceOf[UTF8String].splitSQL(
-      delimiter.asInstanceOf[UTF8String], -1);
+    val strings = 
CollationSupport.StringSplitSQL.exec(string.asInstanceOf[UTF8String],
+      delimiter.asInstanceOf[UTF8String], collationId)
     new GenericArrayData(strings.asInstanceOf[Array[Any]])
   }
 
@@ -3201,7 +3202,8 @@ case class StringSplitSQL(
     val arrayClass = classOf[GenericArrayData].getName
     nullSafeCodeGen(ctx, ev, (str, delimiter) => {
       // Array in java is covariant, so we don't need to cast UTF8String[] to 
Object[].
-      s"${ev.value} = new $arrayClass($str.splitSQL($delimiter,-1));"
+      s"${ev.value} = new $arrayClass(" +
+        s"${CollationSupport.StringSplitSQL.genCode(str, delimiter, 
collationId)});"
     })
   }
 
@@ -3239,10 +3241,11 @@ case class SplitPart (
     partNum: Expression)
   extends RuntimeReplaceable with ImplicitCastInputTypes {
   override lazy val replacement: Expression =
-    ElementAt(StringSplitSQL(str, delimiter), partNum, Some(Literal.create("", 
StringType)),
+    ElementAt(StringSplitSQL(str, delimiter), partNum, Some(Literal.create("", 
str.dataType)),
       false)
   override def nodeName: String = "split_part"
-  override def inputTypes: Seq[DataType] = Seq(StringType, StringType, 
IntegerType)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(StringTypeAnyCollation, StringTypeAnyCollation, IntegerType)
   def children: Seq[Expression] = Seq(str, delimiter, partNum)
   protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): 
Expression = {
     copy(str = newChildren.apply(0), delimiter = newChildren.apply(1),
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
index 9c207df95dad..2b6761475a43 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationStringExpressionsSuite.scala
@@ -88,6 +88,23 @@ class CollationStringExpressionsSuite
     assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT")
   }
 
+  // Verifies split_part returns the expected value for each supported
+  // collation and that the result's data type carries the input collation.
+  test("Support SplitPart string expression with collation") {
+    // Supported collations
+    case class SplitPartTestCase[R](s: String, d: String, p: Int, c: String, 
result: R)
+    val testCases = Seq(
+      SplitPartTestCase("1a2", "a", 2, "UTF8_BINARY", "2"),
+      SplitPartTestCase("1a2", "a", 2, "UNICODE", "2"),
+      SplitPartTestCase("1a2", "A", 2, "UTF8_BINARY_LCASE", "2"),
+      SplitPartTestCase("1a2", "A", 2, "UNICODE_CI", "2")
+    )
+    testCases.foreach(t => {
+      val query = s"SELECT 
split_part(collate('${t.s}','${t.c}'),collate('${t.d}','${t.c}'),${t.p})"
+      // Result & data type
+      checkAnswer(sql(query), Row(t.result))
+      assert(sql(query).schema.fields.head.dataType.sameType(StringType(t.c)))
+    })
+  }
+
   test("Support Contains string expression with collation") {
     // Supported collations
     case class ContainsTestCase[R](l: String, r: String, c: String, result: R)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to