This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 723354039f1d [SPARK-48162][SQL] Add collation support for MISC expressions
723354039f1d is described below

commit 723354039f1de587cacdf4ba48c076a896fdffd1
Author: Uros Bojanic <157381213+uros...@users.noreply.github.com>
AuthorDate: Wed May 15 14:23:31 2024 +0800

    [SPARK-48162][SQL] Add collation support for MISC expressions

    ### What changes were proposed in this pull request?
    Introduce collation awareness for misc expressions: raise_error, uuid, version, typeof, aes_encrypt, aes_decrypt.

    ### Why are the changes needed?
    Add collation support for misc expressions in Spark.

    ### Does this PR introduce _any_ user-facing change?
    Yes, users should now be able to use collated strings within arguments for misc functions: raise_error, uuid, version, typeof, aes_encrypt, aes_decrypt.

    ### How was this patch tested?
    E2e sql tests.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #46461 from uros-db/misc-expressions.

Authored-by: Uros Bojanic <157381213+uros...@users.noreply.github.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../explain-results/function_aes_decrypt.explain | 2 +- .../function_aes_decrypt_with_mode.explain | 2 +- .../function_aes_decrypt_with_mode_padding.explain | 2 +- ...ction_aes_decrypt_with_mode_padding_aad.explain | 2 +- .../explain-results/function_aes_encrypt.explain | 2 +- .../function_aes_encrypt_with_mode.explain | 2 +- .../function_aes_encrypt_with_mode_padding.explain | 2 +- ...nction_aes_encrypt_with_mode_padding_iv.explain | 2 +- ...on_aes_encrypt_with_mode_padding_iv_aad.explain | 2 +- .../function_try_aes_decrypt.explain | 2 +- .../function_try_aes_decrypt_with_mode.explain | 2 +- ...ction_try_aes_decrypt_with_mode_padding.explain | 2 +- ...n_try_aes_decrypt_with_mode_padding_aad.explain | 2 +- .../spark/sql/catalyst/expressions/misc.scala | 14 ++- .../spark/sql/CollationSQLExpressionsSuite.scala | 136 +++++++++++++++++++++ 15 files changed, 157 insertions(+), 19 deletions(-) diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt.explain index 31e03b79eb98..55f1c314671a 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), GCM, DEFAULT, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true) AS aes_decrypt(g, g, GCM, DEFAULT, )#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), GCM, DEFAULT, 
cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true) AS aes_decrypt(g, g, GCM, DEFAULT, )#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode.explain index fc572e8fe7c6..762a4f47a058 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, DEFAULT, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true) AS aes_decrypt(g, g, g, DEFAULT, )#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, DEFAULT, cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true) AS aes_decrypt(g, g, g, DEFAULT, )#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding.explain index c6c693013dd0..7c31c1754c3b 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding.explain @@ 
-1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true) AS aes_decrypt(g, g, g, g, )#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true) AS aes_decrypt(g, g, g, g, )#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding_aad.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding_aad.explain index 97bb528b84b3..48b640efb376 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding_aad.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_decrypt_with_mode_padding_aad.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast(g#0 as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true) AS aes_decrypt(g, g, g, g, g)#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast(g#0 as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true) AS aes_decrypt(g, g, g, g, g)#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git 
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt.explain index 44084a8e60fb..d88a71848572 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), GCM, DEFAULT, cast( as binary), cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, GCM, DEFAULT, , )#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), GCM, DEFAULT, cast( as binary), cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, GCM, DEFAULT, , )#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode.explain index 29ccf0c1c833..59fb110a8359 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, DEFAULT, cast( as binary), cast( as binary), BinaryType, 
BinaryType, StringType, StringType, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, DEFAULT, , )#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, DEFAULT, cast( as binary), cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, DEFAULT, , )#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode_padding.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode_padding.explain index 5591363426ab..80912e43353c 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode_padding.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode_padding.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast( as binary), cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, g, , )#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast( as binary), cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, g, , )#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode_padding_iv.explain 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode_padding_iv.explain index 54b08d7bdb48..6d61e3c7d097 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode_padding_iv.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode_padding_iv.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, 0x434445, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, g, X'434445', )#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, 0x434445, cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, g, X'434445', )#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode_padding_iv_aad.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode_padding_iv_aad.explain index 024089170bc7..9d0bdb901d7e 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode_padding_iv_aad.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_aes_encrypt_with_mode_padding_iv_aad.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, 0x434445, cast(g#0 as binary), BinaryType, BinaryType, 
StringType, StringType, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, g, X'434445', g)#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesEncrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, 0x434445, cast(g#0 as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, BinaryType, true, true, true) AS aes_encrypt(g, g, g, g, X'434445', g)#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt.explain index b45be2845308..56d4c6eb0e0a 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt.explain @@ -1,2 +1,2 @@ -Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), GCM, DEFAULT, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true)) AS try_aes_decrypt(g, g, GCM, DEFAULT, )#0] +Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), GCM, DEFAULT, cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true)) AS try_aes_decrypt(g, g, GCM, DEFAULT, )#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt_with_mode.explain 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt_with_mode.explain index 82b7ed1ea893..6b46dbd067ad 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt_with_mode.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt_with_mode.explain @@ -1,2 +1,2 @@ -Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, DEFAULT, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true)) AS try_aes_decrypt(g, g, g, DEFAULT, )#0] +Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, DEFAULT, cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true)) AS try_aes_decrypt(g, g, g, DEFAULT, )#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt_with_mode_padding.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt_with_mode_padding.explain index 9087d743d941..9436cc826022 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt_with_mode_padding.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt_with_mode_padding.explain @@ -1,2 +1,2 @@ -Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast( as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true)) AS 
try_aes_decrypt(g, g, g, g, )#0] +Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast( as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true)) AS try_aes_decrypt(g, g, g, g, )#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt_with_mode_padding_aad.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt_with_mode_padding_aad.explain index 8854da9b423d..c8182e3b05dd 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt_with_mode_padding_aad.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_try_aes_decrypt_with_mode_padding_aad.explain @@ -1,2 +1,2 @@ -Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast(g#0 as binary), BinaryType, BinaryType, StringType, StringType, BinaryType, true, true, true)) AS try_aes_decrypt(g, g, g, g, g)#0] +Project [tryeval(staticinvoke(class org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils, BinaryType, aesDecrypt, cast(g#0 as binary), cast(g#0 as binary), g#0, g#0, cast(g#0 as binary), BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType, true, true, true)) AS try_aes_decrypt(g, g, g, g, g)#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index c7281e4e8737..eda65ae48f00 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.util.{MapData, RandomUUIDGenerator} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.errors.QueryExecutionErrors.raiseError import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.types.StringTypeAnyCollation import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -84,7 +85,7 @@ case class RaiseError(errorClass: Expression, errorParms: Expression, dataType: override def foldable: Boolean = false override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = - Seq(StringType, MapType(StringType, StringType)) + Seq(StringTypeAnyCollation, MapType(StringType, StringType)) override def left: Expression = errorClass override def right: Expression = errorParms @@ -251,7 +252,7 @@ case class Uuid(randomSeed: Option[Long] = None) extends LeafExpression with Non override def nullable: Boolean = false - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def stateful: Boolean = true @@ -292,7 +293,7 @@ case class SparkVersion() extends LeafExpression with RuntimeReplaceable { override lazy val replacement: Expression = StaticInvoke( classOf[ExpressionImplUtils], - StringType, + SQLConf.get.defaultStringType, "getSparkVersion", returnNullable = false) } @@ -311,7 +312,7 @@ case class SparkVersion() extends LeafExpression with RuntimeReplaceable { case class TypeOf(child: Expression) extends UnaryExpression { override def nullable: Boolean = false override def foldable: Boolean = true - override def dataType: DataType = StringType + override def dataType: DataType = SQLConf.get.defaultStringType override def eval(input: InternalRow): Any = 
UTF8String.fromString(child.dataType.catalogString) override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -412,7 +413,8 @@ case class AesEncrypt( override def prettyName: String = "aes_encrypt" override def inputTypes: Seq[AbstractDataType] = - Seq(BinaryType, BinaryType, StringType, StringType, BinaryType, BinaryType) + Seq(BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, + BinaryType, BinaryType) override def children: Seq[Expression] = Seq(input, key, mode, padding, iv, aad) @@ -485,7 +487,7 @@ case class AesDecrypt( this(input, key, Literal("GCM")) override def inputTypes: Seq[AbstractDataType] = { - Seq(BinaryType, BinaryType, StringType, StringType, BinaryType) + Seq(BinaryType, BinaryType, StringTypeAnyCollation, StringTypeAnyCollation, BinaryType) } override def prettyName: String = "aes_decrypt" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index f8b3548b956c..48c3853bb5cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -931,6 +931,142 @@ class CollationSQLExpressionsSuite }) } + test("Support RaiseError misc expression with collation") { + // Supported collations + case class RaiseErrorTestCase(errorMessage: String, collationName: String) + val testCases = Seq( + RaiseErrorTestCase("custom error message 1", "UTF8_BINARY"), + RaiseErrorTestCase("custom error message 2", "UTF8_BINARY_LCASE"), + RaiseErrorTestCase("custom error message 3", "UNICODE"), + RaiseErrorTestCase("custom error message 4", "UNICODE_CI") + ) + testCases.foreach(t => { + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val query = s"SELECT raise_error('${t.errorMessage}')" + // Result & data type + val userException = intercept[SparkRuntimeException] { + 
sql(query).collect() + } + assert(userException.getErrorClass === "USER_RAISED_EXCEPTION") + assert(userException.getMessage.contains(t.errorMessage)) + } + }) + } + + test("Support Uuid misc expression with collation") { + // Supported collations + Seq("UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI").foreach(collationName => + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val query = s"SELECT uuid()" + // Result & data type + val testQuery = sql(query) + val queryResult = testQuery.collect().head.getString(0) + val uuidFormat = "^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$" + assert(queryResult.matches(uuidFormat)) + val dataType = StringType(collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + ) + } + + test("Support SparkVersion misc expression with collation") { + // Supported collations + Seq("UTF8_BINARY", "UTF8_BINARY_LCASE", "UNICODE", "UNICODE_CI").foreach(collationName => + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collationName) { + val query = s"SELECT version()" + // Result & data type + val testQuery = sql(query) + val queryResult = testQuery.collect().head.getString(0) + val versionFormat = "^[0-9]\\.[0-9]\\.[0-9] [0-9a-f]{40}$" + assert(queryResult.matches(versionFormat)) + val dataType = StringType(collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + ) + } + + test("Support TypeOf misc expression with collation") { + // Supported collations + case class TypeOfTestCase(input: String, collationName: String, result: String) + val testCases = Seq( + TypeOfTestCase("1", "UTF8_BINARY", "int"), + TypeOfTestCase("\"A\"", "UTF8_BINARY_LCASE", "string collate UTF8_BINARY_LCASE"), + TypeOfTestCase("array(1)", "UNICODE", "array<int>"), + TypeOfTestCase("null", "UNICODE_CI", "void") + ) + testCases.foreach(t => { + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val query = s"SELECT typeof(${t.input})" + // Result & data type + val testQuery 
= sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("Support AesEncrypt misc expression with collation") { + // Supported collations + case class AesEncryptTestCase( + input: String, + collationName: String, + params: String, + result: String + ) + val testCases = Seq( + AesEncryptTestCase("Spark", "UTF8_BINARY", "'1234567890abcdef', 'ECB'", + "8DE7DB79A23F3E8ED530994DDEA98913"), + AesEncryptTestCase("Spark", "UTF8_BINARY_LCASE", "'1234567890abcdef', 'ECB', 'DEFAULT', ''", + "8DE7DB79A23F3E8ED530994DDEA98913"), + AesEncryptTestCase("Spark", "UNICODE", "'1234567890abcdef', 'GCM', 'DEFAULT', " + + "unhex('000000000000000000000000')", + "00000000000000000000000046596B2DE09C729FE48A0F81A00A4E7101DABEB61D"), + AesEncryptTestCase("Spark", "UNICODE_CI", "'1234567890abcdef', 'CBC', 'DEFAULT', " + + "unhex('00000000000000000000000000000000')", + "000000000000000000000000000000008DE7DB79A23F3E8ED530994DDEA98913") + ) + testCases.foreach(t => { + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val query = s"SELECT hex(aes_encrypt('${t.input}', ${t.params}))" + // Result & data type + val testQuery = sql(query) + checkAnswer(testQuery, Row(t.result)) + val dataType = StringType(t.collationName) + assert(testQuery.schema.fields.head.dataType.sameType(dataType)) + } + }) + } + + test("Support AesDecrypt misc expression with collation") { + // Supported collations + case class AesDecryptTestCase( + input: String, + collationName: String, + params: String, + result: String + ) + val testCases = Seq( + AesDecryptTestCase("8DE7DB79A23F3E8ED530994DDEA98913", + "UTF8_BINARY", "'1234567890abcdef', 'ECB'", "Spark"), + AesDecryptTestCase("8DE7DB79A23F3E8ED530994DDEA98913", + "UTF8_BINARY_LCASE", "'1234567890abcdef', 'ECB', 'DEFAULT', ''", "Spark"), + AesDecryptTestCase("00000000000000000000000046596B2DE09C729FE48A0F81A00A4E7101DABEB61D", 
+ "UNICODE", "'1234567890abcdef', 'GCM', 'DEFAULT'", "Spark"), + AesDecryptTestCase("000000000000000000000000000000008DE7DB79A23F3E8ED530994DDEA98913", + "UNICODE_CI", "'1234567890abcdef', 'CBC', 'DEFAULT'", "Spark") + ) + testCases.foreach(t => { + withSQLConf(SqlApiConf.DEFAULT_COLLATION -> t.collationName) { + val query = s"SELECT aes_decrypt(unhex('${t.input}'), ${t.params})" + // Result & data type + val testQuery = sql(query) + checkAnswer(testQuery, sql(s"SELECT to_binary('${t.result}', 'utf-8')")) + assert(testQuery.schema.fields.head.dataType.sameType(BinaryType)) + } + }) + } + test("Support Mask expression with collation") { // Supported collations case class MaskTestCase[R](i: String, u: String, l: String, d: String, o: String, c: String, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org