This is an automated email from the ASF dual-hosted git repository. maxgekk pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 3e82ac6ea3d [SPARK-44391][SQL] Check the number of argument types in `InvokeLike` 3e82ac6ea3d is described below commit 3e82ac6ea3d9f87c8ac09e481235beefaa1bf758 Author: Max Gekk <max.g...@gmail.com> AuthorDate: Thu Jul 13 12:17:20 2023 +0300 [SPARK-44391][SQL] Check the number of argument types in `InvokeLike` ### What changes were proposed in this pull request? In this PR, I propose to check the number of argument types in the `InvokeLike` expressions. If the input types are provided, the number of types must be exactly the same as the number of argument expressions. ### Why are the changes needed? 1. This PR explicitly checks the contract described in the comment: https://github.com/apache/spark/blob/d9248e83bbb3af49333608bebe7149b1aaeca738/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala#L247 This can prevent errors in expression implementations and improve code maintainability. 2. It also fixes an issue in `UrlEncode` and `UrlDecode`, which passed two arguments to the invoked method but declared only one input type. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? By running the related tests: ``` $ build/sbt "test:testOnly *UrlFunctionsSuite" $ build/sbt "test:testOnly *DataSourceV2FunctionSuite" ``` Closes #41954 from MaxGekk/fix-url_decode. 
Authored-by: Max Gekk <max.g...@gmail.com> Signed-off-by: Max Gekk <max.g...@gmail.com> --- common/utils/src/main/resources/error/error-classes.json | 5 +++++ .../explain-results/function_url_decode.explain | 2 +- .../explain-results/function_url_encode.explain | 2 +- .../sql-error-conditions-datatype-mismatch-error-class.md | 4 ++++ .../spark/sql/catalyst/analysis/CheckAnalysis.scala | 5 +++-- .../spark/sql/catalyst/expressions/objects/objects.scala | 15 +++++++++++++++ .../spark/sql/catalyst/expressions/urlExpressions.scala | 4 ++-- 7 files changed, 31 insertions(+), 6 deletions(-) diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 347ce026476..2c4d2b533a6 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -657,6 +657,11 @@ "The <exprName> must be between <valueRange> (current value = <currentValue>)." ] }, + "WRONG_NUM_ARG_TYPES" : { + "message" : [ + "The expression requires <expectedNum> argument types but the actual number is <actualNum>." + ] + }, "WRONG_NUM_ENDPOINTS" : { "message" : [ "The number of endpoints must be >= 2 to construct intervals but the actual number is <actualNumber>." 
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_decode.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_decode.explain index 36b21e27c10..d612190396d 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_decode.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_decode.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UrlCodec$, StringType, decode, g#0, UTF-8, StringType, true, true, true) AS url_decode(g)#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UrlCodec$, StringType, decode, g#0, UTF-8, StringType, StringType, true, true, true) AS url_decode(g)#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_encode.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_encode.explain index 70a0f628fc9..bd2c63e19c6 100644 --- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_encode.explain +++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_encode.explain @@ -1,2 +1,2 @@ -Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UrlCodec$, StringType, encode, g#0, UTF-8, StringType, true, true, true) AS url_encode(g)#0] +Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UrlCodec$, StringType, encode, g#0, UTF-8, StringType, StringType, true, true, true) AS url_encode(g)#0] +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/docs/sql-error-conditions-datatype-mismatch-error-class.md b/docs/sql-error-conditions-datatype-mismatch-error-class.md index 3bd63925323..ddc3e0c2b1b 100644 --- 
a/docs/sql-error-conditions-datatype-mismatch-error-class.md +++ b/docs/sql-error-conditions-datatype-mismatch-error-class.md @@ -234,6 +234,10 @@ The input of `<functionName>` can't be `<dataType>` type data. The `<exprName>` must be between `<valueRange>` (current value = `<currentValue>`). +## WRONG_NUM_ARG_TYPES + +The expression requires `<expectedNum>` argument types but the actual number is `<actualNum>`. + ## WRONG_NUM_ENDPOINTS The number of endpoints must be >= 2 to construct intervals but the actual number is `<actualNumber>`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 852055c4df1..7085a040d66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -295,8 +295,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB context = c.origin.getQueryContext, summary = c.origin.context.summary) case e: RuntimeReplaceable if !e.replacement.resolved => - throw new IllegalStateException("Illegal RuntimeReplaceable: " + e + - "\nReplacement is unresolved: " + e.replacement) + throw SparkException.internalError( + s"Cannot resolve the runtime replaceable expression ${toSQLExpr(e)}. 
" + + s"The replacement is unresolved: ${toSQLExpr(e.replacement)}.") case g: Grouping => g.failAnalysis( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index d4c5428af4d..fec60aef1bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -31,6 +31,7 @@ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.serializer._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.encoders.EncoderUtils import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -48,6 +49,19 @@ import org.apache.spark.util.Utils trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInputTypes { def arguments: Seq[Expression] + protected def argumentTypes: Seq[AbstractDataType] = inputTypes + + override def checkInputDataTypes(): TypeCheckResult = { + if (!argumentTypes.isEmpty && argumentTypes.length != arguments.length) { + TypeCheckResult.DataTypeMismatch( + errorSubClass = "WRONG_NUM_ARG_TYPES", + messageParameters = Map( + "expectedNum" -> arguments.length.toString, + "actualNum" -> argumentTypes.length.toString)) + } else { + super.checkInputDataTypes() + } + } def propagateNull: Boolean @@ -384,6 +398,7 @@ case class Invoke( } else { Nil } + override protected def argumentTypes: Seq[AbstractDataType] = methodInputTypes private lazy val encodedFunctionName = ScalaReflection.encodeFieldNameToIdentifier(functionName) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala index b3ba5656d44..47b37a5edeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala @@ -57,7 +57,7 @@ case class UrlEncode(child: Expression) StringType, "encode", Seq(child, Literal("UTF-8")), - Seq(StringType)) + Seq(StringType, StringType)) override protected def withNewChildInternal(newChild: Expression): Expression = { copy(child = newChild) @@ -94,7 +94,7 @@ case class UrlDecode(child: Expression) StringType, "decode", Seq(child, Literal("UTF-8")), - Seq(StringType)) + Seq(StringType, StringType)) override protected def withNewChildInternal(newChild: Expression): Expression = { copy(child = newChild) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org