This is an automated email from the ASF dual-hosted git repository. viirya pushed a commit to branch branch-2.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push: new 7733510 [SPARK-35288][SQL] StaticInvoke should find the method without exact argument classes match 7733510 is described below commit 7733510d0403625c41710d7e79f810117aac2ced Author: Liang-Chi Hsieh <vii...@gmail.com> AuthorDate: Fri May 7 09:07:57 2021 -0700 [SPARK-35288][SQL] StaticInvoke should find the method without exact argument classes match ### What changes were proposed in this pull request? This patch proposes to make StaticInvoke able to find a method with the given method name even when the parameter types do not exactly match the argument classes. ### Why are the changes needed? Unlike `Invoke`, `StaticInvoke` only tries to get the method with the exact argument classes. If the called method's parameter types do not exactly match the argument classes, `StaticInvoke` cannot find the method. `StaticInvoke` should be able to find the method in these cases too. ### Does this PR introduce _any_ user-facing change? Yes. `StaticInvoke` can now find a method even when the argument classes do not match exactly. ### How was this patch tested? Unit test. Closes #32413 from viirya/static-invoke. 
Authored-by: Liang-Chi Hsieh <vii...@gmail.com> Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com> (cherry picked from commit 33fbf5647b4a5587c78ac51339c0cbc9d70547a4) Signed-off-by: Liang-Chi Hsieh <vii...@gmail.com> --- .../sql/catalyst/expressions/objects/objects.scala | 56 ++++++++++++---------- .../expressions/ObjectExpressionsSuite.scala | 34 +++++++++++-- 2 files changed, 60 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 43e8105..fb4132a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -139,6 +139,34 @@ trait InvokeLike extends Expression with NonSQLExpression { } } } + + final def findMethod(cls: Class[_], functionName: String, argClasses: Seq[Class[_]]): Method = { + // Looking with function name + argument classes first. + try { + cls.getMethod(functionName, argClasses: _*) + } catch { + case _: NoSuchMethodException => + // For some cases, e.g. arg class is Object, `getMethod` cannot find the method. + // We look at function name + argument length + val m = cls.getMethods.filter { m => + m.getName == functionName && m.getParameterCount == arguments.length + } + if (m.isEmpty) { + sys.error(s"Couldn't find $functionName on $cls") + } else if (m.length > 1) { + // More than one matched method signature. Exclude synthetic one, e.g. generic one. + val realMethods = m.filter(!_.isSynthetic) + if (realMethods.length > 1) { + // Ambiguous case, we don't know which method to choose, just fail it. 
+ sys.error(s"Found ${realMethods.length} $functionName on $cls") + } else { + realMethods.head + } + } else { + m.head + } + } + } } /** @@ -230,7 +258,7 @@ case class StaticInvoke( override def children: Seq[Expression] = arguments lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments) - @transient lazy val method = cls.getDeclaredMethod(functionName, argClasses : _*) + @transient lazy val method = findMethod(cls, functionName, argClasses) override def eval(input: InternalRow): Any = { invoke(null, method, arguments, input, dataType) @@ -317,31 +345,7 @@ case class Invoke( @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => - // Looking with function name + argument classes first. - try { - Some(cls.getMethod(encodedFunctionName, argClasses: _*)) - } catch { - case _: NoSuchMethodException => - // For some cases, e.g. arg class is Object, `getMethod` cannot find the method. - // We look at function name + argument length - val m = cls.getMethods.filter { m => - m.getName == encodedFunctionName && m.getParameterCount == arguments.length - } - if (m.isEmpty) { - sys.error(s"Couldn't find $encodedFunctionName on $cls") - } else if (m.length > 1) { - // More than one matched method signature. Exclude synthetic one, e.g. generic one. - val realMethods = m.filter(!_.isSynthetic) - if (realMethods.length > 1) { - // Ambiguous case, we don't know which method to choose, just fail it. 
- sys.error(s"Found ${realMethods.length} $encodedFunctionName on $cls") - } else { - Some(realMethods.head) - } - } else { - Some(m.head) - } - } + Some(findMethod(cls, encodedFunctionName, argClasses)) case _ => None } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index a80c659..ac1f35b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -612,8 +612,22 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val clsType = ObjectType(classOf[ConcreteClass]) val obj = new ConcreteClass + val input = (1, 2) checkObjectExprEvaluation( - Invoke(Literal(obj, clsType), "testFunc", IntegerType, Seq(Literal(1))), 0) + Invoke(Literal(obj, clsType), "testFunc", IntegerType, + Seq(Literal(input, ObjectType(input.getClass)))), 2) + } + + test("SPARK-35288: static invoke should find method without exact param type match") { + val input = (1, 2) + + checkObjectExprEvaluation( + StaticInvoke(TestStaticInvoke.getClass, IntegerType, "func", + Seq(Literal(input, ObjectType(input.getClass)))), 3) + + checkObjectExprEvaluation( + StaticInvoke(TestStaticInvoke.getClass, IntegerType, "func", + Seq(Literal(1, IntegerType))), -1) } } @@ -626,10 +640,22 @@ class TestBean extends Serializable { assert(i != null, "this setter should not be called with null.") } +object TestStaticInvoke { + def func(param: Any): Int = param match { + case pair: Tuple2[_, _] => + pair.asInstanceOf[Tuple2[Int, Int]]._1 + pair.asInstanceOf[Tuple2[Int, Int]]._2 + case _ => -1 + } +} + abstract class BaseClass[T] { - def testFunc(param: T): T + def testFunc(param: T): Int } -class ConcreteClass extends BaseClass[Int] with Serializable { - override def 
testFunc(param: Int): Int = param - 1 +class ConcreteClass extends BaseClass[Product] with Serializable { + override def testFunc(param: Product): Int = param match { + case _: Tuple2[_, _] => 2 + case _: Tuple3[_, _, _] => 3 + case _ => 4 + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org