This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new afe310d617e5 [SPARK-47351][SQL] Add collation support for StringToMap & Mask string expressions afe310d617e5 is described below commit afe310d617e5d5e1fd79e7d42e2bbafe93c6d3a8 Author: Uros Bojanic <157381213+uros...@users.noreply.github.com> AuthorDate: Fri Apr 26 20:33:29 2024 +0800 [SPARK-47351][SQL] Add collation support for StringToMap & Mask string expressions ### What changes were proposed in this pull request? Introduce collation awareness for string expressions: str_to_map & mask. ### Why are the changes needed? Add collation support for built-in string functions in Spark. ### Does this PR introduce _any_ user-facing change? Yes, users should now be able to use collated strings within arguments for built-in string functions: str_to_map & mask. ### How was this patch tested? E2e sql tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46165 from uros-db/SPARK-47351. Authored-by: Uros Bojanic <157381213+uros...@users.noreply.github.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/catalyst/analysis/CollationTypeCasts.scala | 2 +- .../catalyst/expressions/complexTypeCreator.scala | 8 +- .../sql/catalyst/expressions/maskExpressions.scala | 44 +++++----- .../spark/sql/CollationSQLExpressionsSuite.scala | 98 ++++++++++++++++++++++ 4 files changed, 129 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala index 473d552b3d94..c7ca5607481d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCasts.scala @@ -64,7 +64,7 @@ object CollationTypeCasts extends TypeCoercionRule { case otherExpr @ ( _: In | _: InSubquery | _: CreateArray | _: ArrayJoin | _: Concat | _: Greatest | _: Least | - _: Coalesce | _: BinaryExpression | _: ConcatWs) => + _: Coalesce | _: BinaryExpression | _: ConcatWs | _: Mask) => val newChildren = collateToSingleType(otherExpr.children) otherExpr.withNewChildren(newChildren) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 3eb6225b5426..c38b6cea9a0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.types.StringTypeAnyCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ @@ -570,11 +571,12 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E override def second: Expression = pairDelim override def third: Expression = keyValueDelim - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation) - override def dataType: DataType = MapType(StringType, StringType) + override def dataType: DataType = MapType(first.dataType, first.dataType) - private lazy val mapBuilder = new ArrayBasedMapBuilder(StringType, StringType) + private lazy val mapBuilder = new ArrayBasedMapBuilder(first.dataType, first.dataType) override def nullSafeEval( inputString: Any, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala index e5157685a9a6..c11357352c79 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/maskExpressions.scala @@ -24,7 +24,9 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.{FunctionSignature, InputParameter} import org.apache.spark.sql.errors.QueryErrorsBase -import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.types.StringTypeAnyCollation +import org.apache.spark.sql.types.{AbstractDataType, DataType} import org.apache.spark.unsafe.types.UTF8String // scalastyle:off line.size.limit @@ -79,12 +81,14 @@ import org.apache.spark.unsafe.types.UTF8String object MaskExpressionBuilder extends ExpressionBuilder { override def functionSignature: Option[FunctionSignature] = { val strArg = InputParameter("str") - val upperCharArg = InputParameter("upperChar", Some(Literal(Mask.MASKED_UPPERCASE))) - val lowerCharArg = InputParameter("lowerChar", Some(Literal(Mask.MASKED_LOWERCASE))) - val digitCharArg = InputParameter("digitChar", Some(Literal(Mask.MASKED_DIGIT))) - val otherCharArg = InputParameter( - "otherChar", - Some(Literal(Mask.MASKED_IGNORE, StringType))) + val upperCharArg = InputParameter("upperChar", + Some(Literal.create(Mask.MASKED_UPPERCASE, SQLConf.get.defaultStringType))) + val lowerCharArg = InputParameter("lowerChar", + Some(Literal.create(Mask.MASKED_LOWERCASE, SQLConf.get.defaultStringType))) + val digitCharArg = InputParameter("digitChar", + Some(Literal.create(Mask.MASKED_DIGIT, SQLConf.get.defaultStringType))) + val otherCharArg = InputParameter("otherChar", + Some(Literal.create(Mask.MASKED_IGNORE, SQLConf.get.defaultStringType))) val functionSignature: FunctionSignature = FunctionSignature(Seq( strArg, upperCharArg, lowerCharArg, digitCharArg, otherCharArg)) Some(functionSignature) @@ -109,33 +113,34 @@ case class Mask( def this(input: Expression) = this( input, - Literal(Mask.MASKED_UPPERCASE), - Literal(Mask.MASKED_LOWERCASE), - Literal(Mask.MASKED_DIGIT), - Literal(Mask.MASKED_IGNORE, StringType)) + Literal.create(Mask.MASKED_UPPERCASE, SQLConf.get.defaultStringType), + Literal.create(Mask.MASKED_LOWERCASE, SQLConf.get.defaultStringType), + Literal.create(Mask.MASKED_DIGIT, SQLConf.get.defaultStringType), + Literal.create(Mask.MASKED_IGNORE, input.dataType)) def this(input: Expression, upperChar: Expression) = this( input, upperChar, - Literal(Mask.MASKED_LOWERCASE), - Literal(Mask.MASKED_DIGIT), - Literal(Mask.MASKED_IGNORE, StringType)) + Literal.create(Mask.MASKED_LOWERCASE, SQLConf.get.defaultStringType), + Literal.create(Mask.MASKED_DIGIT, SQLConf.get.defaultStringType), + Literal.create(Mask.MASKED_IGNORE, input.dataType)) def this(input: Expression, upperChar: Expression, lowerChar: Expression) = this( input, upperChar, lowerChar, - Literal(Mask.MASKED_DIGIT), - Literal(Mask.MASKED_IGNORE, StringType)) + Literal.create(Mask.MASKED_DIGIT, SQLConf.get.defaultStringType), + Literal.create(Mask.MASKED_IGNORE, input.dataType)) def this( input: Expression, upperChar: Expression, lowerChar: Expression, digitChar: Expression) = - this(input, upperChar, lowerChar, digitChar, Literal(Mask.MASKED_IGNORE, StringType)) + this(input, upperChar, lowerChar, digitChar, + Literal.create(Mask.MASKED_IGNORE, input.dataType)) override def checkInputDataTypes(): TypeCheckResult = { @@ -187,7 +192,8 @@ case class Mask( * NumericType, IntegralType, FractionalType. */ override def inputTypes: Seq[AbstractDataType] = - Seq(StringType, StringType, StringType, StringType, StringType) + Seq(StringTypeAnyCollation, StringTypeAnyCollation, StringTypeAnyCollation, + StringTypeAnyCollation, StringTypeAnyCollation) override def nullable: Boolean = true @@ -276,7 +282,7 @@ case class Mask( * Returns the [[DataType]] of the result of evaluating this expression. It is invalid to query * the dataType of an unresolved expression (i.e., when `resolved` == false). */ - override def dataType: DataType = StringType + override def dataType: DataType = input.dataType /** * Returns a Seq of the children of this node. Children should not change. Immutability required diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala new file mode 100644 index 000000000000..5cc0f568db77 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.immutable.Seq + +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{MapType, StringType} + +// scalastyle:off nonascii +class CollationSQLExpressionsSuite + extends QueryTest + with SharedSparkSession { + + test("Support StringToMap expression with collation") { + // Supported collations + case class StringToMapTestCase[R](t: String, p: String, k: String, c: String, result: R) + val testCases = Seq( + StringToMapTestCase("a:1,b:2,c:3", ",", ":", "UTF8_BINARY", + Map("a" -> "1", "b" -> "2", "c" -> "3")), + StringToMapTestCase("A-1;B-2;C-3", ";", "-", "UTF8_BINARY_LCASE", + Map("A" -> "1", "B" -> "2", "C" -> "3")), + StringToMapTestCase("1:a,2:b,3:c", ",", ":", "UNICODE", + Map("1" -> "a", "2" -> "b", "3" -> "c")), + StringToMapTestCase("1/A!2/B!3/C", "!", "/", "UNICODE_CI", + Map("1" -> "A", "2" -> "B", "3" -> "C")) + ) + testCases.foreach(t => { + val query = s"SELECT str_to_map(collate('${t.t}', '${t.c}'), '${t.p}', '${t.k}');" + // Result & data type + checkAnswer(sql(query), Row(t.result)) + val dataType = MapType(StringType(t.c), StringType(t.c), true) + assert(sql(query).schema.fields.head.dataType.sameType(dataType)) + }) + } + + test("Support Mask expression with collation") { + // Supported collations + case class MaskTestCase[R](i: String, u: String, l: String, d: String, o: String, c: String, + result: R) + val testCases = Seq( + MaskTestCase("ab-CD-12-@$", null, null, null, null, "UTF8_BINARY", "ab-CD-12-@$"), + MaskTestCase("ab-CD-12-@$", "X", null, null, null, "UTF8_BINARY_LCASE", "ab-XX-12-@$"), + MaskTestCase("ab-CD-12-@$", "X", "x", null, null, "UNICODE", "xx-XX-12-@$"), + MaskTestCase("ab-CD-12-@$", "X", "x", "0", "#", "UNICODE_CI", "xx#XX#00###") + ) + testCases.foreach(t => { + def col(s: String): String = if (s == null) "null" else s"collate('$s', '${t.c}')" + val query = s"SELECT mask(${col(t.i)}, ${col(t.u)}, ${col(t.l)}, ${col(t.d)}, ${col(t.o)})" + // Result & data type + var result = sql(query) + checkAnswer(result, Row(t.result)) + assert(result.schema.fields.head.dataType.sameType(StringType(t.c))) + }) + // Implicit casting + val testCasting = Seq( + MaskTestCase("ab-CD-12-@$", "X", "x", "0", "#", "UNICODE_CI", "xx#XX#00###") + ) + testCasting.foreach(t => { + def col(s: String): String = if (s == null) "null" else s"collate('$s', '${t.c}')" + def str(s: String): String = if (s == null) "null" else s"'$s'" + val query1 = s"SELECT mask(${col(t.i)}, ${str(t.u)}, ${str(t.l)}, ${str(t.d)}, ${str(t.o)})" + val query2 = s"SELECT mask(${str(t.i)}, ${col(t.u)}, ${str(t.l)}, ${str(t.d)}, ${str(t.o)})" + val query3 = s"SELECT mask(${str(t.i)}, ${str(t.u)}, ${col(t.l)}, ${str(t.d)}, ${str(t.o)})" + val query4 = s"SELECT mask(${str(t.i)}, ${str(t.u)}, ${str(t.l)}, ${col(t.d)}, ${str(t.o)})" + val query5 = s"SELECT mask(${str(t.i)}, ${str(t.u)}, ${str(t.l)}, ${str(t.d)}, ${col(t.o)})" + for (q <- Seq(query1, query2, query3, query4, query5)) { + val result = sql(q) + checkAnswer(result, Row(t.result)) + assert(result.schema.fields.head.dataType.sameType(StringType(t.c))) + } + }) + // Collation mismatch + val collationMismatch = intercept[AnalysisException] { + sql("SELECT mask(collate('ab-CD-12-@$','UNICODE'),collate('X','UNICODE_CI'),'x','0','#')") + } + assert(collationMismatch.getErrorClass === "COLLATION_MISMATCH.EXPLICIT") + } + + // TODO: Add more tests for other SQL expressions + +} +// scalastyle:on nonascii --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org