This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 861cca3da4c4 [SPARK-46832][SQL] Introducing Collate and Collation expressions
861cca3da4c4 is described below

commit 861cca3da4c446761ccff007c89b214a691b0a72
Author: Aleksandar Tomic <aleksandar.to...@databricks.com>
AuthorDate: Wed Feb 14 19:14:50 2024 +0300

    [SPARK-46832][SQL] Introducing Collate and Collation expressions
    
    ### What changes were proposed in this pull request?
    
    This PR adds end-to-end (E2E) support for the `collate` and `collation` expressions.
    The following changes were made to get there:
    1) Set the right ordering for `PhysicalStringType` based on `collationId`.
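    2) UTF8String is now just a data holder class. It still implements the `Comparable`
    interface for backwards compatibility, but `compareTo` must not be used in the Spark
    code base (it throws in testing mode). All comparisons must go through `binaryCompare`,
    `semanticCompare`, or `CollationFactory`.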
    3) `collate` and `collation` expressions are added. Special syntax for `collate` is
    enabled: `'hello world' COLLATE 'target_collation'`.
    4) A first set of tests is added, covering both the core expressions and E2E collation
    behavior.
    
    ### Why are the changes needed?
    
    This PR is part of the larger collation track. For more details, please refer to the
    design doc attached to the parent JIRA ticket.
    
    ### Does this PR introduce _any_ user-facing change?
    
    This PR adds two new expressions (`collate` and `collation`) and opens up new syntax
    (`expr COLLATE 'collation_name'`).
    
    ### How was this patch tested?
    
    Basic tests are added. In follow-up PRs we will add support for more advanced operators
    and keep adding tests alongside new feature support.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Yes.
    
    Closes #45064 from dbatomic/stringtype_compare.
    
    Lead-authored-by: Aleksandar Tomic <aleksandar.to...@databricks.com>
    Co-authored-by: Stefan Kandic <stefan.kan...@databricks.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 .../spark/sql/catalyst/util/CollationFactory.java  |   5 +-
 .../org/apache/spark/unsafe/types/UTF8String.java  |  59 ++++++-
 .../apache/spark/unsafe/types/UTF8StringSuite.java |  24 +--
 .../types/UTF8StringPropertyCheckSuite.scala       |   2 +-
 .../spark/sql/catalyst/parser/SqlBaseParser.g4     |   1 +
 .../spark/sql/catalyst/encoders/RowEncoder.scala   |   2 +-
 .../org/apache/spark/sql/types/StringType.scala    |  23 ++-
 .../sql/catalyst/CatalystTypeConverters.scala      |   2 +-
 .../sql/catalyst/analysis/FunctionRegistry.scala   |   2 +
 .../spark/sql/catalyst/encoders/EncoderUtils.scala |   2 +-
 .../sql/catalyst/expressions/ToStringBase.scala    |   4 +-
 .../aggregate/BloomFilterAggregate.scala           |   4 +-
 .../expressions/codegen/CodeGenerator.scala        |  13 +-
 .../expressions/collationExpressions.scala         | 100 ++++++++++++
 .../spark/sql/catalyst/parser/AstBuilder.scala     |   8 +
 .../sql/catalyst/types/PhysicalDataType.scala      |   4 +-
 .../catalyst/expressions/CodeGenerationSuite.scala |   9 +-
 .../expressions/CollationExpressionSuite.scala     |  77 +++++++++
 .../apache/spark/sql/execution/HiveResult.scala    |   2 +-
 .../spark/sql/execution/columnar/ColumnStats.scala |   4 +-
 .../sql-functions/sql-expression-schema.md         |   2 +
 .../org/apache/spark/sql/CollationSuite.scala      | 177 +++++++++++++++++++++
 .../sql/expressions/ExpressionInfoSuite.scala      |   5 +-
 23 files changed, 484 insertions(+), 47 deletions(-)

diff --git 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
index 018fb6cbeb9f..83cac849e848 100644
--- 
a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
+++ 
b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
@@ -112,7 +112,7 @@ public final class CollationFactory {
     collationTable[0] = new Collation(
       "UCS_BASIC",
       null,
-      UTF8String::compareTo,
+      UTF8String::binaryCompare,
       "1.0",
       s -> (long)s.hashCode(),
       true);
@@ -122,7 +122,7 @@ public final class CollationFactory {
     collationTable[1] = new Collation(
       "UCS_BASIC_LCASE",
       null,
-      Comparator.comparing(UTF8String::toLowerCase),
+      (s1, s2) -> s1.toLowerCase().binaryCompare(s2.toLowerCase()),
       "1.0",
       (s) -> (long)s.toLowerCase().hashCode(),
       false);
@@ -132,7 +132,6 @@ public final class CollationFactory {
       "UNICODE", Collator.getInstance(ULocale.ROOT), "153.120.0.0", true);
     collationTable[2].collator.setStrength(Collator.TERTIARY);
 
-
     // UNICODE case-insensitive comparison (ROOT locale, in ICU + Secondary 
strength).
     collationTable[3] = new Collation(
       "UNICODE_CI", Collator.getInstance(ULocale.ROOT), "153.120.0.0", false);
diff --git 
a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java 
b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index 21c31c954ba3..bb794446472f 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -30,12 +30,14 @@ import com.esotericsoftware.kryo.KryoSerializable;
 import com.esotericsoftware.kryo.io.Input;
 import com.esotericsoftware.kryo.io.Output;
 
+import org.apache.spark.sql.catalyst.util.CollationFactory;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.UTF8StringBuilder;
 import org.apache.spark.unsafe.array.ByteArrayMethods;
 import org.apache.spark.unsafe.hash.Murmur3_x86_32;
 
 import static org.apache.spark.unsafe.Platform.*;
+import org.apache.spark.util.SparkEnvUtils$;
 
 
 /**
@@ -1388,28 +1390,71 @@ public final class UTF8String implements 
Comparable<UTF8String>, Externalizable,
     return fromBytes(bytes);
   }
 
+  /**
+   * Implementation of Comparable interface. This method is kept for backwards 
compatibility.
+   * It should not be used in spark code base, given that string comparison 
requires passing
+   * collation id. Either explicitly use `binaryCompare` or use 
`semanticCompare`.
+   */
   @Override
   public int compareTo(@Nonnull final UTF8String other) {
+    if (SparkEnvUtils$.MODULE$.isTesting()) {
+      throw new UnsupportedOperationException(
+        "compareTo should not be used in spark code base. Use binaryCompare or 
semanticCompare.");
+    } else {
+      return binaryCompare(other);
+    }
+  }
+
+  /**
+   * Binary comparison of two UTF8String. Can only be used for default 
UCS_BASIC collation.
+   */
+  public int binaryCompare(final UTF8String other) {
     return ByteArray.compareBinary(
-        base, offset, numBytes, other.base, other.offset, other.numBytes);
+      base, offset, numBytes, other.base, other.offset, other.numBytes);
   }
 
-  public int compare(final UTF8String other) {
-    return compareTo(other);
+  /**
+   * Collation-aware comparison of two UTF8String. The collation to use is 
specified by the
+   * `collationId` parameter.
+   */
+  public int semanticCompare(final UTF8String other, int collationId) {
+    return 
CollationFactory.fetchCollation(collationId).comparator.compare(this, other);
   }
 
+  /**
+   * Binary equality check of two UTF8String. Note that binary equality is not 
the same as
+   * equality under given collation. E.g. if string is collated in 
case-insensitive two strings
+   * are considered equal even if they are different in binary comparison.
+   */
   @Override
   public boolean equals(final Object other) {
     if (other instanceof UTF8String o) {
-      if (numBytes != o.numBytes) {
-        return false;
-      }
-      return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, 
numBytes);
+      return binaryEquals(o);
     } else {
       return false;
     }
   }
 
+  /**
+   * Binary equality check of two UTF8String. Note that binary equality is not 
the same as
+   * equality under given collation. E.g. if string is collated in 
case-insensitive two strings
+   * are considered equal even if they are different in binary comparison.
+   */
+  public boolean binaryEquals(final UTF8String other) {
+    if (numBytes != other.numBytes) {
+      return false;
+    }
+
+    return ByteArrayMethods.arrayEquals(base, offset, other.base, 
other.offset, numBytes);
+  }
+
+  /**
+   * Collation-aware equality comparison of two UTF8String.
+   */
+  public boolean semanticEquals(final UTF8String other, int collationId) {
+    return 
CollationFactory.fetchCollation(collationId).equalsFunction.apply(this, other);
+  }
+
   /**
    * Levenshtein distance is a metric for measuring the distance of two 
strings. The distance is
    * defined by the minimum number of single-character edits (i.e. insertions, 
deletions or
diff --git 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index bca45b0764c9..594b96944934 100644
--- 
a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ 
b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -47,7 +47,7 @@ public class UTF8StringSuite {
 
     assertEquals(s1.hashCode(), s2.hashCode());
 
-    assertEquals(0, s1.compareTo(s2));
+    assertEquals(0, s1.binaryCompare(s2));
 
     assertTrue(s1.contains(s2));
     assertTrue(s2.contains(s1));
@@ -93,18 +93,18 @@ public class UTF8StringSuite {
   }
 
   @Test
-  public void compareTo() {
-    assertTrue(fromString("").compareTo(fromString("a")) < 0);
-    assertTrue(fromString("abc").compareTo(fromString("ABC")) > 0);
-    assertTrue(fromString("abc0").compareTo(fromString("abc")) > 0);
-    assertTrue(fromString("abcabcabc").compareTo(fromString("abcabcabc")) == 
0);
-    assertTrue(fromString("aBcabcabc").compareTo(fromString("Abcabcabc")) > 0);
-    assertTrue(fromString("Abcabcabc").compareTo(fromString("abcabcabC")) < 0);
-    assertTrue(fromString("abcabcabc").compareTo(fromString("abcabcabC")) > 0);
+  public void binaryCompareTo() {
+    assertTrue(fromString("").binaryCompare(fromString("a")) < 0);
+    assertTrue(fromString("abc").binaryCompare(fromString("ABC")) > 0);
+    assertTrue(fromString("abc0").binaryCompare(fromString("abc")) > 0);
+    assertTrue(fromString("abcabcabc").binaryCompare(fromString("abcabcabc")) 
== 0);
+    assertTrue(fromString("aBcabcabc").binaryCompare(fromString("Abcabcabc")) 
> 0);
+    assertTrue(fromString("Abcabcabc").binaryCompare(fromString("abcabcabC")) 
< 0);
+    assertTrue(fromString("abcabcabc").binaryCompare(fromString("abcabcabC")) 
> 0);
 
-    assertTrue(fromString("abc").compareTo(fromString("世界")) < 0);
-    assertTrue(fromString("你好").compareTo(fromString("世界")) > 0);
-    assertTrue(fromString("你好123").compareTo(fromString("你好122")) > 0);
+    assertTrue(fromString("abc").binaryCompare(fromString("世界")) < 0);
+    assertTrue(fromString("你好").binaryCompare(fromString("世界")) > 0);
+    assertTrue(fromString("你好123").binaryCompare(fromString("你好122")) > 0);
   }
 
   protected static void testUpperandLower(String upper, String lower) {
diff --git 
a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
 
b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
index 75c56451592e..3f02d7261112 100644
--- 
a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
+++ 
b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
@@ -81,7 +81,7 @@ class UTF8StringPropertyCheckSuite extends AnyFunSuite with 
ScalaCheckDrivenProp
   test("compare") {
     forAll { (s1: String, s2: String) =>
       assert(Math.signum {
-        toUTF8(s1).compareTo(toUTF8(s2)).toFloat
+        toUTF8(s1).binaryCompare(toUTF8(s2)).toFloat
       } === Math.signum(s1.compareTo(s2).toFloat))
     }
   }
diff --git 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index 737d5196e7c4..1109e4a7bdfc 100644
--- 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ 
b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -989,6 +989,7 @@ primaryExpression
     | CASE whenClause+ (ELSE elseExpression=expression)? END                   
                #searchedCase
     | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END  
                #simpleCase
     | name=(CAST | TRY_CAST) LEFT_PAREN expression AS dataType RIGHT_PAREN     
                #cast
+    | primaryExpression COLLATE stringLit                                      
                #collate
     | primaryExpression DOUBLE_COLON dataType                                  
                #castByColon
     | STRUCT LEFT_PAREN (argument+=namedExpression (COMMA 
argument+=namedExpression)*)? RIGHT_PAREN #struct
     | FIRST LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN                  
                #first
diff --git 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index a201da9c95c9..16ac283eccb1 100644
--- 
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ 
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -81,7 +81,7 @@ object RowEncoder {
     case DoubleType => BoxedDoubleEncoder
     case dt: DecimalType => JavaDecimalEncoder(dt, lenientSerialization = true)
     case BinaryType => BinaryEncoder
-    case StringType => StringEncoder
+    case _: StringType => StringEncoder
     case TimestampType if SqlApiConf.get.datetimeJava8ApiEnabled => 
InstantEncoder(lenient)
     case TimestampType => TimestampEncoder(lenient)
     case TimestampNTZType => LocalDateTimeEncoder
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala 
b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
index bd2ff8475741..501d86433847 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.types
 
 import org.apache.spark.annotation.Stable
+import org.apache.spark.sql.catalyst.util.CollationFactory
 
 /**
  * The data type representing `String` values. Please use the singleton 
`DataTypes.StringType`.
@@ -26,7 +27,27 @@ import org.apache.spark.annotation.Stable
  * @param collationId The id of collation for this StringType.
  */
 @Stable
-class StringType private(val collationId: Int) extends AtomicType {
+class StringType private(val collationId: Int) extends AtomicType with 
Serializable {
+  /**
+   * Returns whether assigned collation is the default spark collation 
(UCS_BASIC).
+   */
+  def isDefaultCollation: Boolean = collationId == 
StringType.DEFAULT_COLLATION_ID
+
+  /**
+   * Type name that is shown to the customer.
+   * If this is an UCS_BASIC collation output is `string` due to backwards 
compatibility.
+   */
+  override def typeName: String =
+    if (isDefaultCollation) "string"
+    else 
s"string(${CollationFactory.fetchCollation(collationId).collationName})"
+
+  override def equals(obj: Any): Boolean =
+    obj.isInstanceOf[StringType] && obj.asInstanceOf[StringType].collationId 
== collationId
+
+  override def hashCode(): Int = collationId.hashCode()
+
+  override private[sql] def acceptsType(other: DataType): Boolean = 
other.isInstanceOf[StringType]
+
   /**
    * The default size of a value of the StringType is 20 bytes.
    */
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 32e0c5884ebe..2b2a186f76d9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -66,7 +66,7 @@ object CatalystTypeConverters {
       case arrayType: ArrayType => ArrayConverter(arrayType.elementType)
       case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType)
       case structType: StructType => StructConverter(structType)
-      case StringType => StringConverter
+      case _: StringType => StringConverter
       case DateType if SQLConf.get.datetimeJava8ApiEnabled => 
LocalDateConverter
       case DateType => DateConverter
       case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => 
InstantConverter
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 391ff2cd34f2..b165d20d0b4f 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -520,6 +520,8 @@ object FunctionRegistry {
     expression[Ascii]("ascii"),
     expression[Chr]("char", true),
     expression[Chr]("chr"),
+    expressionBuilder("collate", CollateExpressionBuilder),
+    expression[Collation]("collation"),
     expressionBuilder("contains", ContainsExpressionBuilder),
     expressionBuilder("startswith", StartsWithExpressionBuilder),
     expressionBuilder("endswith", EndsWithExpressionBuilder),
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
index 793dd373d689..45598b6a66f2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala
@@ -91,7 +91,7 @@ object EncoderUtils {
     case _: DayTimeIntervalType => classOf[java.lang.Long]
     case _: YearMonthIntervalType => classOf[java.lang.Integer]
     case BinaryType => classOf[Array[Byte]]
-    case StringType => classOf[UTF8String]
+    case _: StringType => classOf[UTF8String]
     case CalendarIntervalType => classOf[CalendarInterval]
     case _: StructType => classOf[InternalRow]
     case _: ArrayType => classOf[ArrayData]
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala
index 66a017578c35..18b64fd21338 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala
@@ -162,7 +162,7 @@ trait ToStringBase { self: UnaryExpression with 
TimeZoneAwareExpression =>
         IntervalUtils.toDayTimeIntervalString(i, ANSI_STYLE, startField, 
endField)))
     case _: DecimalType if useDecimalPlainString =>
       acceptAny[Decimal](d => UTF8String.fromString(d.toPlainString))
-    case StringType => identity
+    case _: StringType => identity
     case _ => o => UTF8String.fromString(o.toString)
   }
 
@@ -257,7 +257,7 @@ trait ToStringBase { self: UnaryExpression with 
TimeZoneAwareExpression =>
       // notation if an exponent is needed.
       case _: DecimalType if useDecimalPlainString =>
         (c, evPrim) => code"$evPrim = 
UTF8String.fromString($c.toPlainString());"
-      case StringType =>
+      case _: StringType =>
         (c, evPrim) => code"$evPrim = $c;"
       case _ =>
         (c, evPrim) => code"$evPrim = 
UTF8String.fromString(String.valueOf($c));"
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala
index 22ed8817ce3d..ea4307d572fc 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala
@@ -78,7 +78,7 @@ case class BloomFilterAggregate(
             "exprName" -> "estimatedNumItems or numBits"
           )
         )
-      case (LongType | IntegerType | ShortType | ByteType | StringType, 
LongType, LongType) =>
+      case (LongType | IntegerType | ShortType | ByteType | _: StringType, 
LongType, LongType) =>
         if (!estimatedNumItemsExpression.foldable) {
           DataTypeMismatch(
             errorSubClass = "NON_FOLDABLE_INPUT",
@@ -156,7 +156,7 @@ case class BloomFilterAggregate(
     case IntegerType => IntUpdater
     case ShortType => ShortUpdater
     case ByteType => ByteUpdater
-    case StringType => BinaryUpdater
+    case _: StringType => BinaryUpdater
   }
 
   override def first: Expression = child
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 4c1a86292d70..d922a960fcd8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -349,7 +349,7 @@ class CodegenContext extends Logging {
   def addBufferedState(dataType: DataType, variableName: String, initCode: 
String): ExprCode = {
     val value = addMutableState(javaType(dataType), variableName)
     val code = UserDefinedType.sqlType(dataType) match {
-      case StringType => code"$value = $initCode.clone();"
+      case _: StringType => code"$value = $initCode.clone();"
       case _: StructType | _: ArrayType | _: MapType => code"$value = 
$initCode.copy();"
       case _ => code"$value = $initCode;"
     }
@@ -622,6 +622,8 @@ class CodegenContext extends Logging {
       s"((java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == 
$c2)"
     case DoubleType =>
       s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 
== $c2)"
+    case st: StringType if st.isDefaultCollation => s"$c1.binaryEquals($c2)"
+    case st: StringType => s"$c1.semanticEquals($c2, ${st.collationId})"
     case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
     case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
     case array: ArrayType => genComp(array, c1, c2) + " == 0"
@@ -650,7 +652,8 @@ class CodegenContext extends Logging {
     case FloatType =>
       val clsName = SQLOrderingUtil.getClass.getName.stripSuffix("$")
       s"$clsName.compareFloats($c1, $c2)"
-    // use c1 - c2 may overflow
+    case st: StringType if st.isDefaultCollation => s"$c1.binaryCompare($c2)"
+    case st: StringType => s"$c1.semanticCompare($c2, ${st.collationId})"
     case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? 
-1 : 0)"
     case BinaryType => 
s"org.apache.spark.unsafe.types.ByteArray.compareBinary($c1, $c2)"
     case CalendarIntervalType => s"$c1.compareTo($c2)"
@@ -1716,7 +1719,7 @@ object CodeGenerator extends Logging {
       case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, 
value)
       // The UTF8String, InternalRow, ArrayData and MapData may came from 
UnsafeRow, we should copy
       // it to avoid keeping a "pointer" to a memory region which may get 
updated afterwards.
-      case StringType | _: StructType | _: ArrayType | _: MapType =>
+      case _: StringType | _: StructType | _: ArrayType | _: MapType =>
         s"$row.update($ordinal, $value.copy())"
       case _ => s"$row.update($ordinal, $value)"
     }
@@ -1769,7 +1772,7 @@ object CodeGenerator extends Logging {
         s"$vector.put${primitiveTypeName(jt)}($rowId, $value);"
       case t: DecimalType => s"$vector.putDecimal($rowId, $value, 
${t.precision});"
       case CalendarIntervalType => s"$vector.putInterval($rowId, $value);"
-      case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());"
+      case _: StringType => s"$vector.putByteArray($rowId, $value.getBytes());"
       case _ =>
         throw new SparkIllegalArgumentException(
           errorClass = "_LEGACY_ERROR_TEMP_3233",
@@ -1951,7 +1954,7 @@ object CodeGenerator extends Logging {
     case DoubleType => java.lang.Double.TYPE
     case _: DecimalType => classOf[Decimal]
     case BinaryType => classOf[Array[Byte]]
-    case StringType => classOf[UTF8String]
+    case _: StringType => classOf[UTF8String]
     case CalendarIntervalType => classOf[CalendarInterval]
     case _: StructType => classOf[InternalRow]
     case _: ArrayType => classOf[ArrayData]
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
new file mode 100644
index 000000000000..a2faca95dfbc
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
+import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.types._
+
+@ExpressionDescription(
+  usage = "_FUNC_(expr, collationName) - Marks a given expression with the 
specified collation.",
+  arguments = """
+    Arguments:
+      * expr - String expression to perform collation on.
+      * collationName - Foldable string expression that specifies the 
collation name.
+  """,
+  examples = """
+    Examples:
+      > SELECT COLLATION('Spark SQL' _FUNC_ 'UCS_BASIC_LCASE');
+       UCS_BASIC_LCASE
+  """,
+  since = "4.0.0",
+  group = "string_funcs")
+object CollateExpressionBuilder extends ExpressionBuilder {
+  override def build(funcName: String, expressions: Seq[Expression]): 
Expression = {
+    expressions match {
+      case Seq(e: Expression, collationExpr: Expression) =>
+        (collationExpr.dataType, collationExpr.foldable) match {
+          case (StringType, true) =>
+            val evalCollation = collationExpr.eval()
+            if (evalCollation == null) {
+              throw QueryCompilationErrors.unexpectedNullError("collation", 
collationExpr)
+            } else {
+              Collate(e, evalCollation.toString)
+            }
+          case (StringType, false) => throw 
QueryCompilationErrors.nonFoldableArgumentError(
+            funcName, "collationName", StringType)
+          case (_, _) => throw 
QueryCompilationErrors.unexpectedInputDataTypeError(
+            funcName, 1, StringType, collationExpr)
+        }
+      case s => throw QueryCompilationErrors.wrongNumArgsError(funcName, 
Seq(2), s.length)
+    }
+  }
+}
+
+/**
+ * An expression that marks a given expression with specified collation.
+ * This function is pass-through, it will not modify the input data.
+ * Only type metadata will be updated.
+ */
+case class Collate(child: Expression, collationName: String)
+  extends UnaryExpression with ExpectsInputTypes {
+  private val collationId = CollationFactory.collationNameToId(collationName)
+  override def dataType: DataType = StringType(collationId)
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
+
+  override protected def withNewChildInternal(
+    newChild: Expression): Expression = copy(newChild)
+
+  override def eval(row: InternalRow): Any = child.eval(row)
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode =
+    defineCodeGen(ctx, ev, (in) => in)
+}
+
+@ExpressionDescription(
+  usage = "_FUNC_(expr) - Returns the collation name of a given expression.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('Spark SQL');
+       UCS_BASIC
+  """,
+  since = "4.0.0",
+  group = "string_funcs")
+case class Collation(child: Expression) extends UnaryExpression with 
RuntimeReplaceable {
+  override def dataType: DataType = StringType
+  override protected def withNewChildInternal(newChild: Expression): Collation 
= copy(newChild)
+  override def replacement: Expression = {
+    val collationId = child.dataType.asInstanceOf[StringType].collationId
+    val collationName = 
CollationFactory.fetchCollation(collationId).collationName
+    Literal.create(collationName, StringType)
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 95bfba191d92..99486ae282a8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -2182,6 +2182,14 @@ class AstBuilder extends DataTypeAstBuilder with 
SQLConfHelper with Logging {
     }
   }
 
+  /**
+   * Create a [[Collate]] expression.
+   */
+  override def visitCollate(ctx: CollateContext): Expression = withOrigin(ctx) 
{
+    val collation = string(visitStringLit(ctx.stringLit))
+    Collate(expression(ctx.primaryExpression), collation)
+  }
+
   /**
    * Create a [[Cast]] expression.
    */
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala
index 5a3256a7915f..cc8008a9e11c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala
@@ -21,7 +21,7 @@ import scala.reflect.runtime.universe.TypeTag
 import scala.reflect.runtime.universe.typeTag
 
 import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, 
InterpretedOrdering, SortOrder}
-import org.apache.spark.sql.catalyst.util.{ArrayData, SQLOrderingUtil}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, 
SQLOrderingUtil}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, 
ByteExactNumeric, ByteType, CalendarIntervalType, CharType, DataType, DateType, 
DayTimeIntervalType, Decimal, DecimalExactNumeric, DecimalType, 
DoubleExactNumeric, DoubleType, FloatExactNumeric, FloatType, FractionalType, 
IntegerExactNumeric, IntegerType, IntegralType, LongExactNumeric, LongType, 
MapType, NullType, NumericType, ShortExactNumeric, ShortType, StringType, 
StructField, StructType, TimestampNTZType, Timest [...]
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String, VariantVal}
@@ -263,7 +263,7 @@ case class PhysicalStringType(collationId: Int) extends 
PhysicalDataType {
   // this type. Otherwise, the companion object would be of type "StringType$" 
in byte code.
   // Defined with a private constructor so the companion object is the only 
possible instantiation.
   private[sql] type InternalType = UTF8String
-  private[sql] val ordering = implicitly[Ordering[InternalType]]
+  private[sql] val ordering = 
CollationFactory.fetchCollation(collationId).comparator.compare(_, _)
   @transient private[sql] lazy val tag = typeTag[InternalType]
 }
 object PhysicalStringType {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 265b0eeb8bdf..4df8d87074fc 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.sql.Timestamp
 
-import scala.math.Ordering
-
 import org.apache.logging.log4j.Level
 
 import org.apache.spark.SparkFunSuite
@@ -558,14 +556,15 @@ class CodeGenerationSuite extends SparkFunSuite with 
ExpressionEvalHelper {
   test("SPARK-32624: CodegenContext.addReferenceObj should work for nested Scala class") {
-    // emulate TypeUtils.getInterpretedOrdering(StringType)
+    // Use a Scala-provided Ordering object nested inside scala.math.Ordering as the reference
     val ctx = new CodegenContext
-    val comparator = implicitly[Ordering[UTF8String]]
+    val comparator = implicitly[Ordering[String]]
+
     val refTerm = ctx.addReferenceObj("comparator", comparator)
 
     // Expecting result:
-    //   "((scala.math.LowPriorityOrderingImplicits$$anon$3) references[0] /* comparator */)"
-    // Using lenient assertions to be resilient to anonymous class numbering changes
+    //   "((scala.math.Ordering) references[0] /* comparator */)"
+    // Lenient check: only require that the generated cast names a scala.math.Ordering type
     assert(!refTerm.contains("null"))
-    assert(refTerm.contains("scala.math.LowPriorityOrderingImplicits$$anon$"))
+    assert(refTerm.contains("scala.math.Ordering"))
   }
 
   test("SPARK-35578: final local variable bug in janino") {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
new file mode 100644
index 000000000000..35b0f7f5f326
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.sql.types._
+
+class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
+  test("validate default collation") {
+    val collationId = CollationFactory.collationNameToId("UCS_BASIC")
+    assert(collationId == 0) // UCS_BASIC is expected to map to collation id 0
+    val collateExpr = Collate(Literal("abc"), "UCS_BASIC")
+    assert(collateExpr.dataType === StringType(collationId))
+    assert(collateExpr.dataType.asInstanceOf[StringType].collationId == 0)
+    checkEvaluation(collateExpr, "abc")
+  }
+
+  test("collate against literal") {
+    val collateExpr = Collate(Literal("abc"), "UCS_BASIC_LCASE")
+    val collationId = CollationFactory.collationNameToId("UCS_BASIC_LCASE")
+    assert(collateExpr.dataType == StringType(collationId))
+    checkEvaluation(collateExpr, "abc") // collating keeps the string value unchanged
+  }
+
+  test("check input types") {
+    val collateExpr = Collate(Literal("abc"), "UCS_BASIC")
+    assert(collateExpr.checkInputDataTypes().isSuccess)
+
+    val collateExprExplicitDefault =
+      Collate(Literal.create("abc", StringType(0)), "UCS_BASIC")
+    assert(collateExprExplicitDefault.checkInputDataTypes().isSuccess)
+
+    val collateExprExplicitNonDefault =
+      Collate(Literal.create("abc", StringType(1)), "UCS_BASIC")
+    assert(collateExprExplicitNonDefault.checkInputDataTypes().isSuccess)
+
+    val collateOnNull = Collate(Literal.create(null, StringType(1)), "UCS_BASIC")
+    assert(collateOnNull.checkInputDataTypes().isSuccess)
+
+    val collateOnInt = Collate(Literal(1), "UCS_BASIC") // only string input is accepted
+    assert(collateOnInt.checkInputDataTypes().isFailure)
+  }
+
+  test("collate on non existing collation") {
+    checkError(
+      exception = intercept[SparkException] { Collate(Literal("abc"), "UCS_BASIS") },
+      errorClass = "COLLATION_INVALID_NAME",
+      sqlState = "42704",
+      parameters = Map("proposal" -> "UCS_BASIC", "collationName" -> "UCS_BASIS"))
+  }
+
+  test("collation on non-explicit default collation") {
+    checkEvaluation(Collation(Literal("abc")).replacement, "UCS_BASIC")
+  }
+
+  test("collation on explicitly collated string") {
+    checkEvaluation(Collation(Literal.create("abc", StringType(1))).replacement, "UCS_BASIC_LCASE")
+    checkEvaluation(
+      Collation(Collate(Literal("abc"), "UCS_BASIC_LCASE")).replacement, "UCS_BASIC_LCASE")
+  }
+}
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
index 95b4b979348d..c59fd77c4bb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala
@@ -106,7 +106,7 @@ object HiveResult {
     case (bin: Array[Byte], BinaryType) => new String(bin, 
StandardCharsets.UTF_8)
     case (decimal: java.math.BigDecimal, DecimalType()) => 
decimal.toPlainString
     case (n, _: NumericType) => n.toString
-    case (s: String, StringType) => if (nested) "\"" + s + "\"" else s
+    case (s: String, _: StringType) => if (nested) "\"" + s + "\"" else s
     case (interval: CalendarInterval, CalendarIntervalType) => 
interval.toString
     case (seq: scala.collection.Seq[_], ArrayType(typ, _)) =>
       seq.map(v => (v, typ)).map(e => toHiveString(e, true, 
formatters)).mkString("[", ",", "]")
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
index 1f47673cbc2e..18ef84262aad 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
@@ -270,8 +270,8 @@ private[columnar] final class StringColumnStats extends 
ColumnStats {
   }
 
   def gatherValueStats(value: UTF8String, size: Int): Unit = {
-    if (upper == null || value.compareTo(upper) > 0) upper = value.clone()
-    if (lower == null || value.compareTo(lower) < 0) lower = value.clone()
+    if (upper == null || value.binaryCompare(upper) > 0) { upper = value.clone() } // binary max
+    if (lower == null || value.binaryCompare(lower) < 0) { lower = value.clone() } // binary min
     sizeInBytes += size
     count += 1
   }
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md 
b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index f5bd0c8425d2..e20db3b49589 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -81,6 +81,8 @@
 | org.apache.spark.sql.catalyst.expressions.Chr | char | SELECT char(65) | 
struct<char(65):string> |
 | org.apache.spark.sql.catalyst.expressions.Chr | chr | SELECT chr(65) | 
struct<chr(65):string> |
 | org.apache.spark.sql.catalyst.expressions.Coalesce | coalesce | SELECT 
coalesce(NULL, 1, NULL) | struct<coalesce(NULL, 1, NULL):int> |
+| org.apache.spark.sql.catalyst.expressions.CollateExpressionBuilder | collate 
| SELECT COLLATION('Spark SQL' collate 'UCS_BASIC_LCASE') | 
struct<collation(collate(Spark SQL)):string> |
+| org.apache.spark.sql.catalyst.expressions.Collation | collation | SELECT 
collation('Spark SQL') | struct<collation(Spark SQL):string> |
 | org.apache.spark.sql.catalyst.expressions.Concat | concat | SELECT 
concat('Spark', 'SQL') | struct<concat(Spark, SQL):string> |
 | org.apache.spark.sql.catalyst.expressions.ConcatWs | concat_ws | SELECT 
concat_ws(' ', 'Spark', 'SQL') | struct<concat_ws( , Spark, SQL):string> |
 | org.apache.spark.sql.catalyst.expressions.ContainsExpressionBuilder | 
contains | SELECT contains('Spark SQL', 'Spark') | struct<contains(Spark SQL, 
Spark):boolean> |
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
new file mode 100644
index 000000000000..13888272cad3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.catalyst.ExtendedAnalysisException
+import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.StringType
+
+class CollationSuite extends QueryTest with SharedSparkSession {
+  test("collate returns proper type") {
+    Seq("ucs_basic", "ucs_basic_lcase", "unicode", "unicode_ci").foreach { collationName =>
+      checkAnswer(sql(s"select 'aaa' collate '$collationName'"), Row("aaa"))
+      val collationId = CollationFactory.collationNameToId(collationName)
+      assert(sql(s"select 'aaa' collate '$collationName'").schema(0).dataType
+        == StringType(collationId))
+    }
+  }
+
+  test("collation name is case insensitive") {
+    Seq("uCs_BasIc", "uCs_baSic_Lcase", "uNicOde", "UNICODE_ci").foreach { collationName =>
+      checkAnswer(sql(s"select 'aaa' collate '$collationName'"), Row("aaa"))
+      val collationId = CollationFactory.collationNameToId(collationName)
+      assert(sql(s"select 'aaa' collate '$collationName'").schema(0).dataType
+        == StringType(collationId))
+    }
+  }
+
+  test("collation expression returns name of collation") {
+    Seq("ucs_basic", "ucs_basic_lcase", "unicode", "unicode_ci").foreach { collationName =>
+      val expected = collationName.toUpperCase(java.util.Locale.ROOT) // locale-independent
+      checkAnswer(sql(s"select collation('aaa' collate '$collationName')"), Row(expected))
+    }
+  }
+
+  test("collate function syntax") {
+    assert(sql("select collate('aaa', 'ucs_basic')").schema(0).dataType == StringType(0))
+    assert(sql("select collate('aaa', 'ucs_basic_lcase')").schema(0).dataType == StringType(1))
+  }
+
+  test("collate function syntax invalid arg count") {
+    Seq("'aaa','a','b'", "'aaa'", "").foreach { args => // 3, 1 and 0 arguments
+      val paramCount = if (args.isEmpty) 0 else args.split(',').length // "".split(',') has length 1
+      checkError(
+        exception = intercept[AnalysisException] {
+          sql(s"select collate($args)")
+        },
+        errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION",
+        sqlState = "42605",
+        parameters = Map(
+          "functionName" -> "`collate`",
+          "expectedNum" -> "2",
+          "actualNum" -> paramCount.toString,
+          "docroot" -> "https://spark.apache.org/docs/latest"),
+        context = ExpectedContext(fragment = s"collate($args)", start = 7, stop = 15 + args.length)
+      )
+    }
+  }
+
+  test("collate function invalid collation data type") {
+    checkError(
+      exception = intercept[AnalysisException](sql("select collate('abc', 123)")),
+      errorClass = "UNEXPECTED_INPUT_TYPE",
+      sqlState = "42K09",
+      parameters = Map(
+        "functionName" -> "`collate`",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"123\"",
+        "inputType" -> "\"INT\"",
+        "requiredType" -> "\"STRING\""),
+      context = ExpectedContext(fragment = "collate('abc', 123)", start = 7, stop = 25)
+    )
+  }
+
+  test("NULL as collation name") {
+    checkError(
+      exception = intercept[AnalysisException] {
+        sql("select collate('abc', cast(null as string))") },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL",
+      sqlState = "42K09",
+      parameters = Map("exprName" -> "`collation`", "sqlExpr" -> "\"CAST(NULL AS STRING)\""),
+      context = ExpectedContext(
+        fragment = "collate('abc', cast(null as string))", start = 7, stop = 42)
+    )
+  }
+
+  test("collate function invalid input data type") {
+    checkError(
+      exception = intercept[ExtendedAnalysisException] { sql("select collate(1, 'UCS_BASIC')") },
+      errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+      sqlState = "42K09",
+      parameters = Map(
+        "sqlExpr" -> "\"collate(1)\"",
+        "paramIndex" -> "1",
+        "inputSql" -> "\"1\"",
+        "inputType" -> "\"INT\"",
+        "requiredType" -> "\"STRING\""),
+      context = ExpectedContext(
+        fragment = "collate(1, 'UCS_BASIC')", start = 7, stop = 29))
+  }
+
+  test("collation expression returns default collation") {
+    checkAnswer(sql("select collation('aaa')"), Row("UCS_BASIC"))
+  }
+
+  test("invalid collation name throws exception") {
+    checkError(
+      exception = intercept[SparkException] { sql("select 'aaa' collate 'UCS_BASIS'") },
+      errorClass = "COLLATION_INVALID_NAME",
+      sqlState = "42704",
+      parameters = Map("proposal" -> "UCS_BASIC", "collationName" -> "UCS_BASIS"))
+  }
+
+  test("equality check respects collation") {
+    Seq(
+      ("ucs_basic", "aaa", "AAA", false),
+      ("ucs_basic", "aaa", "aaa", true),
+      ("ucs_basic_lcase", "aaa", "aaa", true),
+      ("ucs_basic_lcase", "aaa", "AAA", true),
+      ("ucs_basic_lcase", "aaa", "bbb", false),
+      ("unicode", "aaa", "aaa", true),
+      ("unicode", "aaa", "AAA", false),
+      ("unicode_CI", "aaa", "aaa", true),
+      ("unicode_CI", "aaa", "AAA", true),
+      ("unicode_CI", "aaa", "bbb", false)
+    ).foreach {
+      case (collationName, left, right, expected) =>
+        checkAnswer(
+          sql(s"select '$left' collate '$collationName' = '$right' collate '$collationName'"),
+          Row(expected))
+        checkAnswer(
+          sql(s"select collate('$left', '$collationName') = collate('$right', '$collationName')"),
+          Row(expected))
+    }
+  }
+
+  test("comparisons respect collation") {
+    Seq(
+      ("ucs_basic", "AAA", "aaa", true),
+      ("ucs_basic", "aaa", "aaa", false),
+      ("ucs_basic", "aaa", "BBB", false),
+      ("ucs_basic_lcase", "aaa", "aaa", false),
+      ("ucs_basic_lcase", "AAA", "aaa", false),
+      ("ucs_basic_lcase", "aaa", "bbb", true),
+      ("unicode", "aaa", "aaa", false),
+      ("unicode", "aaa", "AAA", true),
+      ("unicode", "aaa", "BBB", true),
+      ("unicode_CI", "aaa", "aaa", false),
+      ("unicode_CI", "aaa", "AAA", false),
+      ("unicode_CI", "aaa", "bbb", true)
+    ).foreach {
+      case (collationName, left, right, expected) =>
+        checkAnswer(
+          sql(s"select '$left' collate '$collationName' < '$right' collate '$collationName'"),
+          Row(expected))
+        checkAnswer(
+          sql(s"select collate('$left', '$collationName') < collate('$right', '$collationName')"),
+          Row(expected))
+    }
+  }
+}
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
index 5168480f84b0..19251330cffe 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
@@ -139,7 +139,10 @@ class ExpressionInfoSuite extends SparkFunSuite with 
SharedSparkSession {
       // Examples demonstrate alternative names, see SPARK-20749
       "org.apache.spark.sql.catalyst.expressions.Length",
       // Examples demonstrate alternative syntax, see SPARK-45574
-      "org.apache.spark.sql.catalyst.expressions.Cast")
+      "org.apache.spark.sql.catalyst.expressions.Cast",
+      // Examples demonstrate alternative syntax, see SPARK-47012
+      "org.apache.spark.sql.catalyst.expressions.Collate"
+    )
     spark.sessionState.functionRegistry.listFunction().foreach { funcId =>
       val info = spark.sessionState.catalog.lookupFunctionInfo(funcId)
       val className = info.getClassName


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to