This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3e82ac6ea3d [SPARK-44391][SQL] Check the number of argument types in `InvokeLike`
3e82ac6ea3d is described below

commit 3e82ac6ea3d9f87c8ac09e481235beefaa1bf758
Author: Max Gekk <max.g...@gmail.com>
AuthorDate: Thu Jul 13 12:17:20 2023 +0300

    [SPARK-44391][SQL] Check the number of argument types in `InvokeLike`
    
    ### What changes were proposed in this pull request?
    In the PR, I propose to check the number of argument types in the `InvokeLike` expressions. If the input types are provided, the number of types should be exactly the same as the number of argument expressions.
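    
    For illustration, the enforced invariant can be restated as a standalone helper (a condensed sketch; `checkArgTypeCount` is a hypothetical name, the actual change lands in the `InvokeLike` trait in the diff below):
    ```
    import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.types.AbstractDataType
    
    // If argument types are declared at all, there must be exactly one per argument.
    def checkArgTypeCount(
        argumentTypes: Seq[AbstractDataType],
        arguments: Seq[Expression]): Option[TypeCheckResult.DataTypeMismatch] = {
      if (argumentTypes.nonEmpty && argumentTypes.length != arguments.length) {
        Some(TypeCheckResult.DataTypeMismatch(
          errorSubClass = "WRONG_NUM_ARG_TYPES",
          messageParameters = Map(
            "expectedNum" -> arguments.length.toString,
            "actualNum" -> argumentTypes.length.toString)))
      } else {
        None // counts match (or no types declared); defer to the other checks
      }
    }
    ```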
    
    ### Why are the changes needed?
    1. This PR explicitly checks the contract described in the comment:
    
    https://github.com/apache/spark/blob/d9248e83bbb3af49333608bebe7149b1aaeca738/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala#L247
    
    which can prevent errors in expression implementations and improve code maintainability.
    
    2. It also fixes an issue in the `UrlEncode` and `UrlDecode` expressions, which declared fewer argument types than argument expressions.
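    
    For example, `UrlEncode` builds a `StaticInvoke` over two argument expressions (the input value and the `UTF-8` charset literal) but declared only one argument type (a condensed sketch of the relevant fragment; the full change is in `urlExpressions.scala` below):
    ```
    // Before: 2 argument expressions, 1 declared type -- the mismatch went unnoticed.
    StaticInvoke(
      UrlCodec.getClass,
      StringType,
      "encode",
      Seq(child, Literal("UTF-8")),
      Seq(StringType))
    
    // After: exactly one declared type per argument.
    StaticInvoke(
      UrlCodec.getClass,
      StringType,
      "encode",
      Seq(child, Literal("UTF-8")),
      Seq(StringType, StringType))
    ```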
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    By running the related tests:
    ```
    $ build/sbt "test:testOnly *UrlFunctionsSuite"
    $ build/sbt "test:testOnly *DataSourceV2FunctionSuite"
    ```
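    
    As an extra manual sanity check (a hypothetical spark-shell snippet, not part of the patch), `url_decode` still resolves and evaluates after the fix:
    ```
    // Should print https://spark.apache.org in the result column.
    spark.sql("SELECT url_decode('https%3A%2F%2Fspark.apache.org')").show(truncate = false)
    ```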
    
    Closes #41954 from MaxGekk/fix-url_decode.
    
    Authored-by: Max Gekk <max.g...@gmail.com>
    Signed-off-by: Max Gekk <max.g...@gmail.com>
---
 common/utils/src/main/resources/error/error-classes.json  |  5 +++++
 .../explain-results/function_url_decode.explain           |  2 +-
 .../explain-results/function_url_encode.explain           |  2 +-
 .../sql-error-conditions-datatype-mismatch-error-class.md |  4 ++++
 .../spark/sql/catalyst/analysis/CheckAnalysis.scala       |  5 +++--
 .../spark/sql/catalyst/expressions/objects/objects.scala  | 15 +++++++++++++++
 .../spark/sql/catalyst/expressions/urlExpressions.scala   |  4 ++--
 7 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 347ce026476..2c4d2b533a6 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -657,6 +657,11 @@
           "The <exprName> must be between <valueRange> (current value = 
<currentValue>)."
         ]
       },
+      "WRONG_NUM_ARG_TYPES" : {
+        "message" : [
+          "The expression requires <expectedNum> argument types but the actual 
number is <actualNum>."
+        ]
+      },
       "WRONG_NUM_ENDPOINTS" : {
         "message" : [
           "The number of endpoints must be >= 2 to construct intervals but the 
actual number is <actualNumber>."
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_decode.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_decode.explain
index 36b21e27c10..d612190396d 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_decode.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_decode.explain
@@ -1,2 +1,2 @@
-Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UrlCodec$, StringType, decode, g#0, UTF-8, StringType, true, true, true) AS url_decode(g)#0]
+Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UrlCodec$, StringType, decode, g#0, UTF-8, StringType, StringType, true, true, true) AS url_decode(g)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_encode.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_encode.explain
index 70a0f628fc9..bd2c63e19c6 100644
--- a/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_encode.explain
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_url_encode.explain
@@ -1,2 +1,2 @@
-Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UrlCodec$, StringType, encode, g#0, UTF-8, StringType, true, true, true) AS url_encode(g)#0]
+Project [staticinvoke(class org.apache.spark.sql.catalyst.expressions.UrlCodec$, StringType, encode, g#0, UTF-8, StringType, StringType, true, true, true) AS url_encode(g)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/docs/sql-error-conditions-datatype-mismatch-error-class.md b/docs/sql-error-conditions-datatype-mismatch-error-class.md
index 3bd63925323..ddc3e0c2b1b 100644
--- a/docs/sql-error-conditions-datatype-mismatch-error-class.md
+++ b/docs/sql-error-conditions-datatype-mismatch-error-class.md
@@ -234,6 +234,10 @@ The input of `<functionName>` can't be `<dataType>` type data.
 
 The `<exprName>` must be between `<valueRange>` (current value = `<currentValue>`).
 
+## WRONG_NUM_ARG_TYPES
+
+The expression requires `<expectedNum>` argument types but the actual number is `<actualNum>`.
+
 ## WRONG_NUM_ENDPOINTS
 
 The number of endpoints must be >= 2 to construct intervals but the actual number is `<actualNumber>`.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 852055c4df1..7085a040d66 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -295,8 +295,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
               context = c.origin.getQueryContext,
               summary = c.origin.context.summary)
           case e: RuntimeReplaceable if !e.replacement.resolved =>
-            throw new IllegalStateException("Illegal RuntimeReplaceable: " + e +
-              "\nReplacement is unresolved: " + e.replacement)
+            throw SparkException.internalError(
+              s"Cannot resolve the runtime replaceable expression 
${toSQLExpr(e)}. " +
+              s"The replacement is unresolved: ${toSQLExpr(e.replacement)}.")
 
           case g: Grouping =>
             g.failAnalysis(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index d4c5428af4d..fec60aef1bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -31,6 +31,7 @@ import org.apache.spark.{SparkConf, SparkEnv}
 import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{InternalRow, ScalaReflection}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.encoders.EncoderUtils
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -48,6 +49,19 @@ import org.apache.spark.util.Utils
 trait InvokeLike extends Expression with NonSQLExpression with ImplicitCastInputTypes {
 
   def arguments: Seq[Expression]
+  protected def argumentTypes: Seq[AbstractDataType] = inputTypes
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (!argumentTypes.isEmpty && argumentTypes.length != arguments.length) {
+      TypeCheckResult.DataTypeMismatch(
+        errorSubClass = "WRONG_NUM_ARG_TYPES",
+        messageParameters = Map(
+          "expectedNum" -> arguments.length.toString,
+          "actualNum" -> argumentTypes.length.toString))
+    } else {
+      super.checkInputDataTypes()
+    }
+  }
 
   def propagateNull: Boolean
 
@@ -384,6 +398,7 @@ case class Invoke(
     } else {
       Nil
     }
+  override protected def argumentTypes: Seq[AbstractDataType] = methodInputTypes
 
   private lazy val encodedFunctionName = ScalaReflection.encodeFieldNameToIdentifier(functionName)
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala
index b3ba5656d44..47b37a5edeb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/urlExpressions.scala
@@ -57,7 +57,7 @@ case class UrlEncode(child: Expression)
       StringType,
       "encode",
       Seq(child, Literal("UTF-8")),
-      Seq(StringType))
+      Seq(StringType, StringType))
 
   override protected def withNewChildInternal(newChild: Expression): Expression = {
     copy(child = newChild)
@@ -94,7 +94,7 @@ case class UrlDecode(child: Expression)
       StringType,
       "decode",
       Seq(child, Literal("UTF-8")),
-      Seq(StringType))
+      Seq(StringType, StringType))
 
   override protected def withNewChildInternal(newChild: Expression): Expression = {
     copy(child = newChild)

