This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 44d762a  [SPARK-35389][SQL] V2 ScalarFunction should support magic 
method with null arguments
44d762a is described below

commit 44d762abc6395570f1f493a145fd5d1cbdf0b49e
Author: Chao Sun <sunc...@apple.com>
AuthorDate: Tue May 18 08:45:55 2021 +0000

    [SPARK-35389][SQL] V2 ScalarFunction should support magic method with null 
arguments
    
    ### What changes were proposed in this pull request?
    
    When creating `Invoke` and `StaticInvoke` for `ScalarFunction`'s magic 
method, set `propagateNull` to false.
    
    ### Why are the changes needed?
    
    When `propagateNull` is true (which is the default value), `Invoke` and
`StaticInvoke` will return null if any of the arguments is null. For scalar
functions this is incorrect, as null handling should be left to the function
implementation instead.
    
    ### Does this PR introduce _any_ user-facing change?
    
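    Yes. Null arguments are now passed to the magic method, so the function
implementation can handle them. The exception is a nullable argument of
primitive type: if its value is null, Spark still returns null without
calling the magic method.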
    
    ### How was this patch tested?
    
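    Added new tests to `DataSourceV2FunctionSuite` covering null arguments to
magic methods, for both non-primitive and primitive parameter types.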
    
    Closes #32553 from sunchao/SPARK-35389.
    
    Authored-by: Chao Sun <sunc...@apple.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../catalog/functions/ScalarFunction.java          | 19 +++++++++++
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  5 +--
 .../sql/catalyst/expressions/objects/objects.scala | 26 +++++++++++----
 .../connector/catalog/functions/JavaStrLen.java    | 19 +++++++++++
 .../sql/connector/DataSourceV2FunctionSuite.scala  | 37 +++++++++++++++++++++-
 5 files changed, 96 insertions(+), 10 deletions(-)

diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
index 858ab92..d261a24 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
@@ -31,6 +31,7 @@ import org.apache.spark.sql.types.DataType;
  * InternalRow API for the {@link DataType SQL data type} returned by {@link 
#resultType()}.
  * The mapping between {@link DataType} and the corresponding JVM type is 
defined below.
  * <p>
+ * <h2> Magic method </h2>
  * <b>IMPORTANT</b>: the default implementation of {@link #produceResult} 
throws
  * {@link UnsupportedOperationException}. Users must choose to either override 
this method, or
  * implement a magic method with name {@link #MAGIC_METHOD_NAME}, which takes 
individual parameters
@@ -82,6 +83,24 @@ import org.apache.spark.sql.types.DataType;
  * following the mapping defined below, and then checking if there is a 
matching method from all the
  * declared methods in the UDF class, using method name and the Java types.
  * <p>
+ * <h2> Handling of nullable primitive arguments </h2>
+ * The handling of null primitive arguments is different between the magic 
method approach and
+ * the {@link #produceResult} approach. With the former, whenever any of the 
method arguments meet
+ * the following conditions:
+ * <ol>
+ *   <li>the argument is of primitive type</li>
+ *   <li>the argument is nullable</li>
+ *   <li>the value of the argument is null</li>
+ * </ol>
+ * Spark will return null directly instead of calling the magic method. On the 
other hand, Spark
+ * will pass null primitive arguments to {@link #produceResult} and it is 
user's responsibility to
+ * handle them in the function implementation.
+ * <p>
+ * Because of the difference, if Spark users want to implement special 
handling of nulls for
+ * nullable primitive arguments, they should override the {@link 
#produceResult} method instead
+ * of using the magic method approach.
+ * <p>
+ * <h2> Spark data type to Java type mapping </h2>
  * The following are the mapping from {@link DataType SQL data type} to Java 
type which is used
  * by Spark to infer parameter types for the magic methods as well as return 
value type for
  * {@link #produceResult}:
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9954ca0..3f2e93a 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2204,11 +2204,12 @@ class Analyzer(override val catalogManager: 
CatalogManager)
         findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
           case Some(m) if Modifier.isStatic(m.getModifiers) =>
             StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
-              MAGIC_METHOD_NAME, arguments, returnNullable = 
scalarFunc.isResultNullable)
+              MAGIC_METHOD_NAME, arguments, propagateNull = false,
+              returnNullable = scalarFunc.isResultNullable)
           case Some(_) =>
             val caller = Literal.create(scalarFunc, 
ObjectType(scalarFunc.getClass))
             Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
-              arguments, returnNullable = scalarFunc.isResultNullable)
+              arguments, propagateNull = false, returnNullable = 
scalarFunc.isResultNullable)
           case _ =>
             // TODO: handle functions defined in Scala too - in Scala, even if 
a
             //  subclass do not override the default method in parent interface
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index e871c30..c88f785 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -50,7 +50,10 @@ trait InvokeLike extends Expression with NonSQLExpression {
 
   def propagateNull: Boolean
 
-  protected lazy val needNullCheck: Boolean = propagateNull && 
arguments.exists(_.nullable)
+  protected lazy val needNullCheck: Boolean = 
needNullCheckForIndex.contains(true)
+  protected lazy val needNullCheckForIndex: Array[Boolean] =
+    arguments.map(a => a.nullable && (propagateNull ||
+        ScalaReflection.dataTypeJavaClass(a.dataType).isPrimitive)).toArray
   protected lazy val evaluatedArgs: Array[Object] = new 
Array[Object](arguments.length)
   private lazy val boxingFn: Any => Any =
     ScalaReflection.typeBoxedJavaMapping
@@ -89,7 +92,7 @@ trait InvokeLike extends Expression with NonSQLExpression {
       val reset = s"$resultIsNull = false;"
       val argCodes = arguments.zipWithIndex.map { case (e, i) =>
         val expr = e.genCode(ctx)
-        val updateResultIsNull = if (e.nullable) {
+        val updateResultIsNull = if (needNullCheckForIndex(i)) {
           s"$resultIsNull = ${expr.isNull};"
         } else {
           ""
@@ -131,11 +134,14 @@ trait InvokeLike extends Expression with NonSQLExpression 
{
   def invoke(obj: Any, method: Method, input: InternalRow): Any = {
     var i = 0
     val len = arguments.length
+    var resultNull = false
     while (i < len) {
-      evaluatedArgs(i) = arguments(i).eval(input).asInstanceOf[Object]
+      val result = arguments(i).eval(input).asInstanceOf[Object]
+      evaluatedArgs(i) = result
+      resultNull = resultNull || (result == null && needNullCheckForIndex(i))
       i += 1
     }
-    if (needNullCheck && evaluatedArgs.contains(null)) {
+    if (needNullCheck && resultNull) {
       // return null if one of arguments is null
       null
     } else {
@@ -226,7 +232,9 @@ object SerializerSupport {
  * @param functionName The name of the method to call.
  * @param arguments An optional list of expressions to pass as arguments to 
the function.
  * @param propagateNull When true, and any of the arguments is null, null will 
be returned instead
- *                      of calling the function.
+ *                      of calling the function. Also note: when this is false 
but any of the
+ *                      arguments is of primitive type and is null, null also 
will be returned
+ *                      without invoking the function.
  * @param returnNullable When false, indicating the invoked method will always 
return
  *                       non-null value.
  */
@@ -318,7 +326,9 @@ case class StaticInvoke(
  * @param arguments An optional list of expressions, whose evaluation will be 
passed to the
   *                 function.
  * @param propagateNull When true, and any of the arguments is null, null will 
be returned instead
- *                      of calling the function.
+ *                      of calling the function. Also note: when this is false 
but any of the
+ *                      arguments is of primitive type and is null, null also 
will be returned
+ *                      without invoking the function.
  * @param returnNullable When false, indicating the invoked method will always 
return
  *                       non-null value.
  */
@@ -452,7 +462,9 @@ object NewInstance {
  * @param cls The class to construct.
  * @param arguments A list of expression to use as arguments to the 
constructor.
  * @param propagateNull When true, if any of the arguments is null, then null 
will be returned
- *                      instead of trying to construct the object.
+ *                      instead of trying to construct the object. Also note: 
when this is false
+ *                      but any of the arguments is of primitive type and is 
null, null also will
+ *                      be returned without constructing the object.
  * @param dataType The type of object being constructed, as a Spark SQL 
datatype.  This allows you
  *                 to manually specify the type when the object in question is 
a valid internal
  *                 representation (i.e. ArrayData) instead of an object.
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
 
b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
index 7cd010b..1b16896 100644
--- 
a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
+++ 
b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
@@ -120,5 +120,24 @@ public class JavaStrLen implements UnboundFunction {
 
   public static class JavaStrLenNoImpl extends JavaStrLenBase {
   }
+
+  // a null-safe version which returns 0 for null arguments
+  public static class JavaStrLenMagicNullSafe extends JavaStrLenBase {
+    public int invoke(UTF8String str) {
+      if (str == null) {
+        return 0;
+      }
+      return str.toString().length();
+    }
+  }
+
+  public static class JavaStrLenStaticMagicNullSafe extends JavaStrLenBase {
+    public static int invoke(UTF8String str) {
+      if (str == null) {
+        return 0;
+      }
+      return str.toString().length();
+    }
+  }
 }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index bd4dfe4..801aee5 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -20,12 +20,14 @@ package org.apache.spark.sql.connector
 import java.util
 import java.util.Collections
 
-import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, 
JavaStrLen}
+import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, 
JavaLongAdd, JavaStrLen}
+import 
test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.JavaLongAddMagic
 import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.InternalRow
+import 
org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK, 
NO_CODEGEN}
 import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, 
Identifier, InMemoryCatalog, SupportsNamespaces}
 import org.apache.spark.sql.connector.catalog.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -213,6 +215,39 @@ class DataSourceV2FunctionSuite extends 
DatasourceV2SQLBase {
       .getMessage.contains("neither implement magic method nor override 
'produceResult'"))
   }
 
+  test("SPARK-35389: magic function should handle null arguments") {
+    
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"),
 emptyProps)
+    addFunction(Identifier.of(Array("ns"), "strlen"), new JavaStrLen(new 
JavaStrLenMagicNullSafe))
+    addFunction(Identifier.of(Array("ns"), "strlen2"),
+      new JavaStrLen(new JavaStrLenStaticMagicNullSafe))
+    Seq("strlen", "strlen2").foreach { name =>
+      checkAnswer(sql(s"SELECT testcat.ns.$name(CAST(NULL as STRING))"), 
Row(0) :: Nil)
+    }
+  }
+
+  test("SPARK-35389: magic function should handle null primitive arguments") {
+    
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"),
 emptyProps)
+    addFunction(Identifier.of(Array("ns"), "add"), new JavaLongAdd(new 
JavaLongAddMagic(false)))
+    addFunction(Identifier.of(Array("ns"), "static_add"),
+      new JavaLongAdd(new JavaLongAddMagic(false)))
+
+    Seq("add", "static_add").foreach { name =>
+      Seq(true, false).foreach { codegenEnabled =>
+        val codeGenFactoryMode = if (codegenEnabled) FALLBACK else NO_CODEGEN
+
+        withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> 
codegenEnabled.toString,
+          SQLConf.CODEGEN_FACTORY_MODE.key -> codeGenFactoryMode.toString) {
+
+          checkAnswer(sql(s"SELECT testcat.ns.$name(CAST(NULL as BIGINT), 
42L)"), Row(null) :: Nil)
+          checkAnswer(sql(s"SELECT testcat.ns.$name(42L, CAST(NULL as 
BIGINT))"), Row(null) :: Nil)
+          checkAnswer(sql(s"SELECT testcat.ns.$name(42L, 58L)"), Row(100) :: 
Nil)
+          checkAnswer(sql(s"SELECT testcat.ns.$name(CAST(NULL as BIGINT), 
CAST(NULL as BIGINT))"),
+            Row(null) :: Nil)
+        }
+      }
+    }
+  }
+
   test("bad bound function (neither scalar nor aggregate)") {
     
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"),
 emptyProps)
     addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(BadBoundFunction))

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to