This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 861cca3da4c4 [SPARK-46832][SQL] Introducing Collate and Collation expressions 861cca3da4c4 is described below commit 861cca3da4c446761ccff007c89b214a691b0a72 Author: Aleksandar Tomic <aleksandar.to...@databricks.com> AuthorDate: Wed Feb 14 19:14:50 2024 +0300 [SPARK-46832][SQL] Introducing Collate and Collation expressions ### What changes were proposed in this pull request? This PR adds E2E support for `collate` and `collation` expressions. The following changes were made to get us there: 1) Set the right ordering for `PhysicalStringType` based on `collationId`. 2) UTF8String is now just a data holder class - it should no longer be compared through the `Comparable` interface (`compareTo` is kept only for backwards compatibility and throws when running under tests). All comparisons must be done through `CollationFactory`. 3) `collate` and `collation` expressions are added. Special syntax for `collate` is enabled - `'hello world' COLLATE 'target_collation'`. 4) A first set of tests is added that covers both core expression and E2E collation tests. ### Why are the changes needed? This PR is part of a larger collation track. For more details please refer to the design doc attached to the parent JIRA ticket. ### Does this PR introduce _any_ user-facing change? This PR adds two new expressions and opens up new syntax. ### How was this patch tested? Basic tests are added. In follow-up PRs we will add support for more advanced operators and keep adding tests alongside new feature support. ### Was this patch authored or co-authored using generative AI tooling? Yes. Closes #45064 from dbatomic/stringtype_compare. 
Lead-authored-by: Aleksandar Tomic <aleksandar.to...@databricks.com> Co-authored-by: Stefan Kandic <stefan.kan...@databricks.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- .../spark/sql/catalyst/util/CollationFactory.java | 5 +- .../org/apache/spark/unsafe/types/UTF8String.java | 59 ++++++- .../apache/spark/unsafe/types/UTF8StringSuite.java | 24 +-- .../types/UTF8StringPropertyCheckSuite.scala | 2 +- .../spark/sql/catalyst/parser/SqlBaseParser.g4 | 1 + .../spark/sql/catalyst/encoders/RowEncoder.scala | 2 +- .../org/apache/spark/sql/types/StringType.scala | 23 ++- .../sql/catalyst/CatalystTypeConverters.scala | 2 +- .../sql/catalyst/analysis/FunctionRegistry.scala | 2 + .../spark/sql/catalyst/encoders/EncoderUtils.scala | 2 +- .../sql/catalyst/expressions/ToStringBase.scala | 4 +- .../aggregate/BloomFilterAggregate.scala | 4 +- .../expressions/codegen/CodeGenerator.scala | 13 +- .../expressions/collationExpressions.scala | 100 ++++++++++++ .../spark/sql/catalyst/parser/AstBuilder.scala | 8 + .../sql/catalyst/types/PhysicalDataType.scala | 4 +- .../catalyst/expressions/CodeGenerationSuite.scala | 9 +- .../expressions/CollationExpressionSuite.scala | 77 +++++++++ .../apache/spark/sql/execution/HiveResult.scala | 2 +- .../spark/sql/execution/columnar/ColumnStats.scala | 4 +- .../sql-functions/sql-expression-schema.md | 2 + .../org/apache/spark/sql/CollationSuite.scala | 177 +++++++++++++++++++++ .../sql/expressions/ExpressionInfoSuite.scala | 5 +- 23 files changed, 484 insertions(+), 47 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 018fb6cbeb9f..83cac849e848 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -112,7 +112,7 @@ public final class CollationFactory { 
collationTable[0] = new Collation( "UCS_BASIC", null, - UTF8String::compareTo, + UTF8String::binaryCompare, "1.0", s -> (long)s.hashCode(), true); @@ -122,7 +122,7 @@ public final class CollationFactory { collationTable[1] = new Collation( "UCS_BASIC_LCASE", null, - Comparator.comparing(UTF8String::toLowerCase), + (s1, s2) -> s1.toLowerCase().binaryCompare(s2.toLowerCase()), "1.0", (s) -> (long)s.toLowerCase().hashCode(), false); @@ -132,7 +132,6 @@ public final class CollationFactory { "UNICODE", Collator.getInstance(ULocale.ROOT), "153.120.0.0", true); collationTable[2].collator.setStrength(Collator.TERTIARY); - // UNICODE case-insensitive comparison (ROOT locale, in ICU + Secondary strength). collationTable[3] = new Collation( "UNICODE_CI", Collator.getInstance(ULocale.ROOT), "153.120.0.0", false); diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 21c31c954ba3..bb794446472f 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -30,12 +30,14 @@ import com.esotericsoftware.kryo.KryoSerializable; import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import org.apache.spark.sql.catalyst.util.CollationFactory; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.UTF8StringBuilder; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.hash.Murmur3_x86_32; import static org.apache.spark.unsafe.Platform.*; +import org.apache.spark.util.SparkEnvUtils$; /** @@ -1388,28 +1390,71 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable, return fromBytes(bytes); } + /** + * Implementation of Comparable interface. This method is kept for backwards compatibility. 
+ * It should not be used in spark code base, given that string comparison requires passing + * collation id. Either explicitly use `binaryCompare` or use `semanticCompare`. + */ @Override public int compareTo(@Nonnull final UTF8String other) { + if (SparkEnvUtils$.MODULE$.isTesting()) { + throw new UnsupportedOperationException( + "compareTo should not be used in spark code base. Use binaryCompare or semanticCompare."); + } else { + return binaryCompare(other); + } + } + + /** + * Binary comparison of two UTF8String. Can only be used for default UCS_BASIC collation. + */ + public int binaryCompare(final UTF8String other) { return ByteArray.compareBinary( - base, offset, numBytes, other.base, other.offset, other.numBytes); + base, offset, numBytes, other.base, other.offset, other.numBytes); } - public int compare(final UTF8String other) { - return compareTo(other); + /** + * Collation-aware comparison of two UTF8String. The collation to use is specified by the + * `collationId` parameter. + */ + public int semanticCompare(final UTF8String other, int collationId) { + return CollationFactory.fetchCollation(collationId).comparator.compare(this, other); } + /** + * Binary equality check of two UTF8String. Note that binary equality is not the same as + * equality under given collation. E.g. if string is collated in case-insensitive two strings + * are considered equal even if they are different in binary comparison. + */ @Override public boolean equals(final Object other) { if (other instanceof UTF8String o) { - if (numBytes != o.numBytes) { - return false; - } - return ByteArrayMethods.arrayEquals(base, offset, o.base, o.offset, numBytes); + return binaryEquals(o); } else { return false; } } + /** + * Binary equality check of two UTF8String. Note that binary equality is not the same as + * equality under given collation. E.g. if string is collated in case-insensitive two strings + * are considered equal even if they are different in binary comparison. 
+ */ + public boolean binaryEquals(final UTF8String other) { + if (numBytes != other.numBytes) { + return false; + } + + return ByteArrayMethods.arrayEquals(base, offset, other.base, other.offset, numBytes); + } + + /** + * Collation-aware equality comparison of two UTF8String. + */ + public boolean semanticEquals(final UTF8String other, int collationId) { + return CollationFactory.fetchCollation(collationId).equalsFunction.apply(this, other); + } + /** * Levenshtein distance is a metric for measuring the distance of two strings. The distance is * defined by the minimum number of single-character edits (i.e. insertions, deletions or diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index bca45b0764c9..594b96944934 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -47,7 +47,7 @@ public class UTF8StringSuite { assertEquals(s1.hashCode(), s2.hashCode()); - assertEquals(0, s1.compareTo(s2)); + assertEquals(0, s1.binaryCompare(s2)); assertTrue(s1.contains(s2)); assertTrue(s2.contains(s1)); @@ -93,18 +93,18 @@ public class UTF8StringSuite { } @Test - public void compareTo() { - assertTrue(fromString("").compareTo(fromString("a")) < 0); - assertTrue(fromString("abc").compareTo(fromString("ABC")) > 0); - assertTrue(fromString("abc0").compareTo(fromString("abc")) > 0); - assertTrue(fromString("abcabcabc").compareTo(fromString("abcabcabc")) == 0); - assertTrue(fromString("aBcabcabc").compareTo(fromString("Abcabcabc")) > 0); - assertTrue(fromString("Abcabcabc").compareTo(fromString("abcabcabC")) < 0); - assertTrue(fromString("abcabcabc").compareTo(fromString("abcabcabC")) > 0); + public void binaryCompareTo() { + assertTrue(fromString("").binaryCompare(fromString("a")) < 0); + 
assertTrue(fromString("abc").binaryCompare(fromString("ABC")) > 0); + assertTrue(fromString("abc0").binaryCompare(fromString("abc")) > 0); + assertTrue(fromString("abcabcabc").binaryCompare(fromString("abcabcabc")) == 0); + assertTrue(fromString("aBcabcabc").binaryCompare(fromString("Abcabcabc")) > 0); + assertTrue(fromString("Abcabcabc").binaryCompare(fromString("abcabcabC")) < 0); + assertTrue(fromString("abcabcabc").binaryCompare(fromString("abcabcabC")) > 0); - assertTrue(fromString("abc").compareTo(fromString("世界")) < 0); - assertTrue(fromString("你好").compareTo(fromString("世界")) > 0); - assertTrue(fromString("你好123").compareTo(fromString("你好122")) > 0); + assertTrue(fromString("abc").binaryCompare(fromString("世界")) < 0); + assertTrue(fromString("你好").binaryCompare(fromString("世界")) > 0); + assertTrue(fromString("你好123").binaryCompare(fromString("你好122")) > 0); } protected static void testUpperandLower(String upper, String lower) { diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala index 75c56451592e..3f02d7261112 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala @@ -81,7 +81,7 @@ class UTF8StringPropertyCheckSuite extends AnyFunSuite with ScalaCheckDrivenProp test("compare") { forAll { (s1: String, s2: String) => assert(Math.signum { - toUTF8(s1).compareTo(toUTF8(s2)).toFloat + toUTF8(s1).binaryCompare(toUTF8(s2)).toFloat } === Math.signum(s1.compareTo(s2).toFloat)) } } diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index 737d5196e7c4..1109e4a7bdfc 100644 --- 
a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -989,6 +989,7 @@ primaryExpression | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase | CASE value=expression whenClause+ (ELSE elseExpression=expression)? END #simpleCase | name=(CAST | TRY_CAST) LEFT_PAREN expression AS dataType RIGHT_PAREN #cast + | primaryExpression COLLATE stringLit #collate | primaryExpression DOUBLE_COLON dataType #castByColon | STRUCT LEFT_PAREN (argument+=namedExpression (COMMA argument+=namedExpression)*)? RIGHT_PAREN #struct | FIRST LEFT_PAREN expression (IGNORE NULLS)? RIGHT_PAREN #first diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index a201da9c95c9..16ac283eccb1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -81,7 +81,7 @@ object RowEncoder { case DoubleType => BoxedDoubleEncoder case dt: DecimalType => JavaDecimalEncoder(dt, lenientSerialization = true) case BinaryType => BinaryEncoder - case StringType => StringEncoder + case _: StringType => StringEncoder case TimestampType if SqlApiConf.get.datetimeJava8ApiEnabled => InstantEncoder(lenient) case TimestampType => TimestampEncoder(lenient) case TimestampNTZType => LocalDateTimeEncoder diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index bd2ff8475741..501d86433847 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.types import org.apache.spark.annotation.Stable +import 
org.apache.spark.sql.catalyst.util.CollationFactory /** * The data type representing `String` values. Please use the singleton `DataTypes.StringType`. @@ -26,7 +27,27 @@ import org.apache.spark.annotation.Stable * @param collationId The id of collation for this StringType. */ @Stable -class StringType private(val collationId: Int) extends AtomicType { +class StringType private(val collationId: Int) extends AtomicType with Serializable { + /** + * Returns whether assigned collation is the default spark collation (UCS_BASIC). + */ + def isDefaultCollation: Boolean = collationId == StringType.DEFAULT_COLLATION_ID + + /** + * Type name that is shown to the customer. + * If this is an UCS_BASIC collation output is `string` due to backwards compatibility. + */ + override def typeName: String = + if (isDefaultCollation) "string" + else s"string(${CollationFactory.fetchCollation(collationId).collationName})" + + override def equals(obj: Any): Boolean = + obj.isInstanceOf[StringType] && obj.asInstanceOf[StringType].collationId == collationId + + override def hashCode(): Int = collationId.hashCode() + + override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[StringType] + /** * The default size of a value of the StringType is 20 bytes. 
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 32e0c5884ebe..2b2a186f76d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -66,7 +66,7 @@ object CatalystTypeConverters { case arrayType: ArrayType => ArrayConverter(arrayType.elementType) case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType) case structType: StructType => StructConverter(structType) - case StringType => StringConverter + case _: StringType => StringConverter case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateConverter case DateType => DateConverter case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantConverter diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 391ff2cd34f2..b165d20d0b4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -520,6 +520,8 @@ object FunctionRegistry { expression[Ascii]("ascii"), expression[Chr]("char", true), expression[Chr]("chr"), + expressionBuilder("collate", CollateExpressionBuilder), + expression[Collation]("collation"), expressionBuilder("contains", ContainsExpressionBuilder), expressionBuilder("startswith", StartsWithExpressionBuilder), expressionBuilder("endswith", EndsWithExpressionBuilder), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala index 793dd373d689..45598b6a66f2 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/EncoderUtils.scala @@ -91,7 +91,7 @@ object EncoderUtils { case _: DayTimeIntervalType => classOf[java.lang.Long] case _: YearMonthIntervalType => classOf[java.lang.Integer] case BinaryType => classOf[Array[Byte]] - case StringType => classOf[UTF8String] + case _: StringType => classOf[UTF8String] case CalendarIntervalType => classOf[CalendarInterval] case _: StructType => classOf[InternalRow] case _: ArrayType => classOf[ArrayData] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala index 66a017578c35..18b64fd21338 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala @@ -162,7 +162,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => IntervalUtils.toDayTimeIntervalString(i, ANSI_STYLE, startField, endField))) case _: DecimalType if useDecimalPlainString => acceptAny[Decimal](d => UTF8String.fromString(d.toPlainString)) - case StringType => identity + case _: StringType => identity case _ => o => UTF8String.fromString(o.toString) } @@ -257,7 +257,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => // notation if an exponent is needed. 
case _: DecimalType if useDecimalPlainString => (c, evPrim) => code"$evPrim = UTF8String.fromString($c.toPlainString());" - case StringType => + case _: StringType => (c, evPrim) => code"$evPrim = $c;" case _ => (c, evPrim) => code"$evPrim = UTF8String.fromString(String.valueOf($c));" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala index 22ed8817ce3d..ea4307d572fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/BloomFilterAggregate.scala @@ -78,7 +78,7 @@ case class BloomFilterAggregate( "exprName" -> "estimatedNumItems or numBits" ) ) - case (LongType | IntegerType | ShortType | ByteType | StringType, LongType, LongType) => + case (LongType | IntegerType | ShortType | ByteType | _: StringType, LongType, LongType) => if (!estimatedNumItemsExpression.foldable) { DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", @@ -156,7 +156,7 @@ case class BloomFilterAggregate( case IntegerType => IntUpdater case ShortType => ShortUpdater case ByteType => ByteUpdater - case StringType => BinaryUpdater + case _: StringType => BinaryUpdater } override def first: Expression = child diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 4c1a86292d70..d922a960fcd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -349,7 +349,7 @@ class CodegenContext extends Logging { def addBufferedState(dataType: DataType, variableName: 
String, initCode: String): ExprCode = { val value = addMutableState(javaType(dataType), variableName) val code = UserDefinedType.sqlType(dataType) match { - case StringType => code"$value = $initCode.clone();" + case _: StringType => code"$value = $initCode.clone();" case _: StructType | _: ArrayType | _: MapType => code"$value = $initCode.copy();" case _ => code"$value = $initCode;" } @@ -622,6 +622,8 @@ class CodegenContext extends Logging { s"((java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2)" case DoubleType => s"((java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2)" + case st: StringType if st.isDefaultCollation => s"$c1.binaryEquals($c2)" + case st: StringType => s"$c1.semanticEquals($c2, ${st.collationId})" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)" case array: ArrayType => genComp(array, c1, c2) + " == 0" @@ -650,7 +652,8 @@ class CodegenContext extends Logging { case FloatType => val clsName = SQLOrderingUtil.getClass.getName.stripSuffix("$") s"$clsName.compareFloats($c1, $c2)" - // use c1 - c2 may overflow + case st: StringType if st.isDefaultCollation => s"$c1.binaryCompare($c2)" + case st: StringType => s"$c1.semanticCompare($c2, ${st.collationId})" case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" case BinaryType => s"org.apache.spark.unsafe.types.ByteArray.compareBinary($c1, $c2)" case CalendarIntervalType => s"$c1.compareTo($c2)" @@ -1716,7 +1719,7 @@ object CodeGenerator extends Logging { case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) // The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy // it to avoid keeping a "pointer" to a memory region which may get updated afterwards. 
- case StringType | _: StructType | _: ArrayType | _: MapType => + case _: StringType | _: StructType | _: ArrayType | _: MapType => s"$row.update($ordinal, $value.copy())" case _ => s"$row.update($ordinal, $value)" } @@ -1769,7 +1772,7 @@ object CodeGenerator extends Logging { s"$vector.put${primitiveTypeName(jt)}($rowId, $value);" case t: DecimalType => s"$vector.putDecimal($rowId, $value, ${t.precision});" case CalendarIntervalType => s"$vector.putInterval($rowId, $value);" - case t: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" + case _: StringType => s"$vector.putByteArray($rowId, $value.getBytes());" case _ => throw new SparkIllegalArgumentException( errorClass = "_LEGACY_ERROR_TEMP_3233", @@ -1951,7 +1954,7 @@ object CodeGenerator extends Logging { case DoubleType => java.lang.Double.TYPE case _: DecimalType => classOf[Decimal] case BinaryType => classOf[Array[Byte]] - case StringType => classOf[UTF8String] + case _: StringType => classOf[UTF8String] case CalendarIntervalType => classOf[CalendarInterval] case _: StructType => classOf[InternalRow] case _: ArrayType => classOf[ArrayData] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala new file mode 100644 index 000000000000..a2faca95dfbc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types._ + +@ExpressionDescription( + usage = "_FUNC_(expr, collationName) - Marks a given expression with the specified collation.", + arguments = """ + Arguments: + * expr - String expression to perform collation on. + * collationName - Foldable string expression that specifies the collation name. 
+ """, + examples = """ + Examples: + > SELECT COLLATION('Spark SQL' _FUNC_ 'UCS_BASIC_LCASE'); + UCS_BASIC_LCASE + """, + since = "4.0.0", + group = "string_funcs") +object CollateExpressionBuilder extends ExpressionBuilder { + override def build(funcName: String, expressions: Seq[Expression]): Expression = { + expressions match { + case Seq(e: Expression, collationExpr: Expression) => + (collationExpr.dataType, collationExpr.foldable) match { + case (StringType, true) => + val evalCollation = collationExpr.eval() + if (evalCollation == null) { + throw QueryCompilationErrors.unexpectedNullError("collation", collationExpr) + } else { + Collate(e, evalCollation.toString) + } + case (StringType, false) => throw QueryCompilationErrors.nonFoldableArgumentError( + funcName, "collationName", StringType) + case (_, _) => throw QueryCompilationErrors.unexpectedInputDataTypeError( + funcName, 1, StringType, collationExpr) + } + case s => throw QueryCompilationErrors.wrongNumArgsError(funcName, Seq(2), s.length) + } + } +} + +/** + * An expression that marks a given expression with specified collation. + * This function is pass-through, it will not modify the input data. + * Only type metadata will be updated. 
+ */ +case class Collate(child: Expression, collationName: String) + extends UnaryExpression with ExpectsInputTypes { + private val collationId = CollationFactory.collationNameToId(collationName) + override def dataType: DataType = StringType(collationId) + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override protected def withNewChildInternal( + newChild: Expression): Expression = copy(newChild) + + override def eval(row: InternalRow): Any = child.eval(row) + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + defineCodeGen(ctx, ev, (in) => in) +} + +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the collation name of a given expression.", + examples = """ + Examples: + > SELECT _FUNC_('Spark SQL'); + UCS_BASIC + """, + since = "4.0.0", + group = "string_funcs") +case class Collation(child: Expression) extends UnaryExpression with RuntimeReplaceable { + override def dataType: DataType = StringType + override protected def withNewChildInternal(newChild: Expression): Collation = copy(newChild) + override def replacement: Expression = { + val collationId = child.dataType.asInstanceOf[StringType].collationId + val collationName = CollationFactory.fetchCollation(collationId).collationName + Literal.create(collationName, StringType) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 95bfba191d92..99486ae282a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -2182,6 +2182,14 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { } } + /** + * Create a [[Collate]] expression. 
+ */ + override def visitCollate(ctx: CollateContext): Expression = withOrigin(ctx) { + val collation = string(visitStringLit(ctx.stringLit)) + Collate(expression(ctx.primaryExpression), collation) + } + /** * Create a [[Cast]] expression. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala index 5a3256a7915f..cc8008a9e11c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala @@ -21,7 +21,7 @@ import scala.reflect.runtime.universe.TypeTag import scala.reflect.runtime.universe.typeTag import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, InterpretedOrdering, SortOrder} -import org.apache.spark.sql.catalyst.util.{ArrayData, SQLOrderingUtil} +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, SQLOrderingUtil} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteExactNumeric, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, Decimal, DecimalExactNumeric, DecimalType, DoubleExactNumeric, DoubleType, FloatExactNumeric, FloatType, FractionalType, IntegerExactNumeric, IntegerType, IntegralType, LongExactNumeric, LongType, MapType, NullType, NumericType, ShortExactNumeric, ShortType, StringType, StructField, StructType, TimestampNTZType, Timest [...] import org.apache.spark.unsafe.types.{ByteArray, UTF8String, VariantVal} @@ -263,7 +263,7 @@ case class PhysicalStringType(collationId: Int) extends PhysicalDataType { // this type. Otherwise, the companion object would be of type "StringType$" in byte code. // Defined with a private constructor so the companion object is the only possible instantiation. 
private[sql] type InternalType = UTF8String - private[sql] val ordering = implicitly[Ordering[InternalType]] + private[sql] val ordering = CollationFactory.fetchCollation(collationId).comparator.compare(_, _) @transient private[sql] lazy val tag = typeTag[InternalType] } object PhysicalStringType { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 265b0eeb8bdf..4df8d87074fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.Timestamp -import scala.math.Ordering - import org.apache.logging.log4j.Level import org.apache.spark.SparkFunSuite @@ -558,14 +556,15 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-32624: CodegenContext.addReferenceObj should work for nested Scala class") { // emulate TypeUtils.getInterpretedOrdering(StringType) val ctx = new CodegenContext - val comparator = implicitly[Ordering[UTF8String]] + val comparator = implicitly[Ordering[String]] + val refTerm = ctx.addReferenceObj("comparator", comparator) // Expecting result: - // "((scala.math.LowPriorityOrderingImplicits$$anon$3) references[0] /* comparator */)" + // "((scala.math.Ordering) references[0] /* comparator */)" // Using lenient assertions to be resilient to anonymous class numbering changes assert(!refTerm.contains("null")) - assert(refTerm.contains("scala.math.LowPriorityOrderingImplicits$$anon$")) + assert(refTerm.contains("scala.math.Ordering")) } test("SPARK-35578: final local variable bug in janino") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala new file mode 100644 index 000000000000..35b0f7f5f326 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.types._ + +class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + test("validate default collation") { + val collationId = CollationFactory.collationNameToId("UCS_BASIC") + assert(collationId == 0) + val collateExpr = Collate(Literal("abc"), "UCS_BASIC") + assert(collateExpr.dataType === StringType(collationId)) + collateExpr.dataType.asInstanceOf[StringType].collationId == 0 + checkEvaluation(collateExpr, "abc") + } + + test("collate against literal") { + val collateExpr = Collate(Literal("abc"), "UCS_BASIC_LCASE") + val collationId = CollationFactory.collationNameToId("UCS_BASIC_LCASE") + assert(collateExpr.dataType == StringType(collationId)) + checkEvaluation(collateExpr, "abc") + } + + test("check input types") { + val collateExpr = Collate(Literal("abc"), "UCS_BASIC") + assert(collateExpr.checkInputDataTypes().isSuccess) + + val collateExprExplicitDefault = + Collate(Literal.create("abc", StringType(0)), "UCS_BASIC") + assert(collateExprExplicitDefault.checkInputDataTypes().isSuccess) + + val collateExprExplicitNonDefault = + Collate(Literal.create("abc", StringType(1)), "UCS_BASIC") + assert(collateExprExplicitNonDefault.checkInputDataTypes().isSuccess) + + val collateOnNull = Collate(Literal.create(null, StringType(1)), "UCS_BASIC") + assert(collateOnNull.checkInputDataTypes().isSuccess) + + val collateOnInt = Collate(Literal(1), "UCS_BASIC") + assert(collateOnInt.checkInputDataTypes().isFailure) + } + + test("collate on non existing collation") { + checkError( + exception = intercept[SparkException] { Collate(Literal("abc"), "UCS_BASIS") }, + errorClass = "COLLATION_INVALID_NAME", + sqlState = "42704", + parameters = Map("proposal" -> "UCS_BASIC", "collationName" -> "UCS_BASIS")) + } + + test("collation on non-explicit default 
collation") { + checkEvaluation(Collation(Literal("abc")).replacement, "UCS_BASIC") + } + + test("collation on explicitly collated string") { + checkEvaluation(Collation(Literal.create("abc", StringType(1))).replacement, "UCS_BASIC_LCASE") + checkEvaluation( + Collation(Collate(Literal("abc"), "UCS_BASIC_LCASE")).replacement, "UCS_BASIC_LCASE") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index 95b4b979348d..c59fd77c4bb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -106,7 +106,7 @@ object HiveResult { case (bin: Array[Byte], BinaryType) => new String(bin, StandardCharsets.UTF_8) case (decimal: java.math.BigDecimal, DecimalType()) => decimal.toPlainString case (n, _: NumericType) => n.toString - case (s: String, StringType) => if (nested) "\"" + s + "\"" else s + case (s: String, _: StringType) => if (nested) "\"" + s + "\"" else s case (interval: CalendarInterval, CalendarIntervalType) => interval.toString case (seq: scala.collection.Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(e => toHiveString(e, true, formatters)).mkString("[", ",", "]") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index 1f47673cbc2e..18ef84262aad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -270,8 +270,8 @@ private[columnar] final class StringColumnStats extends ColumnStats { } def gatherValueStats(value: UTF8String, size: Int): Unit = { - if (upper == null || value.compareTo(upper) > 0) upper = value.clone() - if (lower == null || value.compareTo(lower) < 0) lower = 
value.clone() + if (upper == null || value.binaryCompare(upper) > 0) upper = value.clone() + if (lower == null || value.binaryCompare(lower) < 0) lower = value.clone() sizeInBytes += size count += 1 } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index f5bd0c8425d2..e20db3b49589 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -81,6 +81,8 @@ | org.apache.spark.sql.catalyst.expressions.Chr | char | SELECT char(65) | struct<char(65):string> | | org.apache.spark.sql.catalyst.expressions.Chr | chr | SELECT chr(65) | struct<chr(65):string> | | org.apache.spark.sql.catalyst.expressions.Coalesce | coalesce | SELECT coalesce(NULL, 1, NULL) | struct<coalesce(NULL, 1, NULL):int> | +| org.apache.spark.sql.catalyst.expressions.CollateExpressionBuilder | collate | SELECT COLLATION('Spark SQL' collate 'UCS_BASIC_LCASE') | struct<collation(collate(Spark SQL)):string> | +| org.apache.spark.sql.catalyst.expressions.Collation | collation | SELECT collation('Spark SQL') | struct<collation(Spark SQL):string> | | org.apache.spark.sql.catalyst.expressions.Concat | concat | SELECT concat('Spark', 'SQL') | struct<concat(Spark, SQL):string> | | org.apache.spark.sql.catalyst.expressions.ConcatWs | concat_ws | SELECT concat_ws(' ', 'Spark', 'SQL') | struct<concat_ws( , Spark, SQL):string> | | org.apache.spark.sql.catalyst.expressions.ContainsExpressionBuilder | contains | SELECT contains('Spark SQL', 'Spark') | struct<contains(Spark SQL, Spark):boolean> | diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala new file mode 100644 index 000000000000..13888272cad3 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache 
Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StringType + +class CollationSuite extends QueryTest with SharedSparkSession { + test("collate returns proper type") { + Seq("ucs_basic", "ucs_basic_lcase", "unicode", "unicode_ci").foreach { collationName => + checkAnswer(sql(s"select 'aaa' collate '$collationName'"), Row("aaa")) + val collationId = CollationFactory.collationNameToId(collationName) + assert(sql(s"select 'aaa' collate '$collationName'").schema(0).dataType + == StringType(collationId)) + } + } + + test("collation name is case insensitive") { + Seq("uCs_BasIc", "uCs_baSic_Lcase", "uNicOde", "UNICODE_ci").foreach { collationName => + checkAnswer(sql(s"select 'aaa' collate '$collationName'"), Row("aaa")) + val collationId = CollationFactory.collationNameToId(collationName) + assert(sql(s"select 'aaa' collate '$collationName'").schema(0).dataType + == StringType(collationId)) + } + } + + test("collation expression returns name of collation") { + 
Seq("ucs_basic", "ucs_basic_lcase", "unicode", "unicode_ci").foreach { collationName => + checkAnswer( + sql(s"select collation('aaa' collate '$collationName')"), Row(collationName.toUpperCase())) + } + } + + test("collate function syntax") { + assert(sql(s"select collate('aaa', 'ucs_basic')").schema(0).dataType == StringType(0)) + assert(sql(s"select collate('aaa', 'ucs_basic_lcase')").schema(0).dataType == StringType(1)) + } + + test("collate function syntax invalid arg count") { + Seq("'aaa','a','b'", "'aaa'", "", "'aaa'").foreach(args => { + val paramCount = if (args == "") 0 else args.split(',').length.toString + checkError( + exception = intercept[AnalysisException] { + sql(s"select collate($args)") + }, + errorClass = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION", + sqlState = "42605", + parameters = Map( + "functionName" -> "`collate`", + "expectedNum" -> "2", + "actualNum" -> paramCount.toString, + "docroot" -> "https://spark.apache.org/docs/latest"), + context = ExpectedContext(fragment = s"collate($args)", start = 7, stop = 15 + args.length) + ) + }) + } + + test("collate function invalid collation data type") { + checkError( + exception = intercept[AnalysisException](sql("select collate('abc', 123)")), + errorClass = "UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + Map( + "functionName" -> "`collate`", + "paramIndex" -> "1", + "inputSql" -> "\"123\"", + "inputType" -> "\"INT\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext(fragment = s"collate('abc', 123)", start = 7, stop = 25) + ) + } + + test("NULL as collation name") { + checkError( + exception = intercept[AnalysisException] { + sql("select collate('abc', cast(null as string))") }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_NULL", + sqlState = "42K09", + Map("exprName" -> "`collation`", "sqlExpr" -> "\"CAST(NULL AS STRING)\""), + context = ExpectedContext( + fragment = s"collate('abc', cast(null as string))", start = 7, stop = 42) + ) + } + + test("collate function invalid input data 
type") { + checkError( + exception = intercept[ExtendedAnalysisException] { sql(s"select collate(1, 'UCS_BASIC')") }, + errorClass = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE", + sqlState = "42K09", + parameters = Map( + "sqlExpr" -> "\"collate(1)\"", + "paramIndex" -> "1", + "inputSql" -> "\"1\"", + "inputType" -> "\"INT\"", + "requiredType" -> "\"STRING\""), + context = ExpectedContext( + fragment = s"collate(1, 'UCS_BASIC')", start = 7, stop = 29)) + } + + test("collation expression returns default collation") { + checkAnswer(sql(s"select collation('aaa')"), Row("UCS_BASIC")) + } + + test("invalid collation name throws exception") { + checkError( + exception = intercept[SparkException] { sql("select 'aaa' collate 'UCS_BASIS'") }, + errorClass = "COLLATION_INVALID_NAME", + sqlState = "42704", + parameters = Map("proposal" -> "UCS_BASIC", "collationName" -> "UCS_BASIS")) + } + + test("equality check respects collation") { + Seq( + ("ucs_basic", "aaa", "AAA", false), + ("ucs_basic", "aaa", "aaa", true), + ("ucs_basic_lcase", "aaa", "aaa", true), + ("ucs_basic_lcase", "aaa", "AAA", true), + ("ucs_basic_lcase", "aaa", "bbb", false), + ("unicode", "aaa", "aaa", true), + ("unicode", "aaa", "AAA", false), + ("unicode_CI", "aaa", "aaa", true), + ("unicode_CI", "aaa", "AAA", true), + ("unicode_CI", "aaa", "bbb", false) + ).foreach { + case (collationName, left, right, expected) => + checkAnswer( + sql(s"select '$left' collate '$collationName' = '$right' collate '$collationName'"), + Row(expected)) + checkAnswer( + sql(s"select collate('$left', '$collationName') = collate('$right', '$collationName')"), + Row(expected)) + } + } + + test("comparisons respect collation") { + Seq( + ("ucs_basic", "AAA", "aaa", true), + ("ucs_basic", "aaa", "aaa", false), + ("ucs_basic", "aaa", "BBB", false), + ("ucs_basic_lcase", "aaa", "aaa", false), + ("ucs_basic_lcase", "AAA", "aaa", false), + ("ucs_basic_lcase", "aaa", "bbb", true), + ("unicode", "aaa", "aaa", false), + ("unicode", "aaa", 
"AAA", true), + ("unicode", "aaa", "BBB", true), + ("unicode_CI", "aaa", "aaa", false), + ("unicode_CI", "aaa", "AAA", false), + ("unicode_CI", "aaa", "bbb", true) + ).foreach { + case (collationName, left, right, expected) => + checkAnswer( + sql(s"select '$left' collate '$collationName' < '$right' collate '$collationName'"), + Row(expected)) + checkAnswer( + sql(s"select collate('$left', '$collationName') < collate('$right', '$collationName')"), + Row(expected)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala index 5168480f84b0..19251330cffe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala @@ -139,7 +139,10 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession { // Examples demonstrate alternative names, see SPARK-20749 "org.apache.spark.sql.catalyst.expressions.Length", // Examples demonstrate alternative syntax, see SPARK-45574 - "org.apache.spark.sql.catalyst.expressions.Cast") + "org.apache.spark.sql.catalyst.expressions.Cast", + // Examples demonstrate alternative syntax, see SPARK-47012 + "org.apache.spark.sql.catalyst.expressions.Collate" + ) spark.sessionState.functionRegistry.listFunction().foreach { funcId => val info = spark.sessionState.catalog.lookupFunctionInfo(funcId) val className = info.getClassName --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org