Repository: spark
Updated Branches:
  refs/heads/branch-2.4 9ed2e4204 -> df60d9f34


http://git-wip-us.apache.org/repos/asf/spark/blob/df60d9f3/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 697757f..eb956c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -73,19 +73,24 @@ case class UserDefinedFunction protected[sql] (
    */
   @scala.annotation.varargs
   def apply(exprs: Column*): Column = {
-    if (inputTypes.isDefined && nullableTypes.isDefined) {
-      require(inputTypes.get.length == nullableTypes.get.length)
+    // TODO: make sure this class is only instantiated through `SparkUserDefinedFunction.create()`
+    // and `nullableTypes` is always set.
+    if (nullableTypes.isEmpty) {
+      nullableTypes = Some(ScalaReflection.getParameterTypeNullability(f))
+    }
+    if (inputTypes.isDefined) {
+      assert(inputTypes.get.length == nullableTypes.get.length)
     }
 
     Column(ScalaUDF(
       f,
       dataType,
       exprs.map(_.expr),
+      nullableTypes.get,
       inputTypes.getOrElse(Nil),
       udfName = _nameOption,
       nullable = _nullable,
-      udfDeterministic = _deterministic,
-      nullableTypes = nullableTypes.getOrElse(Nil)))
+      udfDeterministic = _deterministic))
   }
 
   private def copyAll(): UserDefinedFunction = {
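
For context, the new fallback in apply() derives per-argument null-safety from the closure itself. A minimal sketch (not part of this commit; it assumes the branch-2.4 catalyst API) of what ScalaReflection.getParameterTypeNullability reports:

    // Sketch only: a primitive parameter (Long) can never hold null, while a
    // reference-typed parameter (String) can.
    import org.apache.spark.sql.catalyst.ScalaReflection

    val f = (x: Long, y: String) => s"$x:$y"
    ScalaReflection.getParameterTypeNullability(f)  // Seq(false, true)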
@@ -146,9 +151,14 @@ private[sql] object SparkUserDefinedFunction {
   def create(
       f: AnyRef,
       dataType: DataType,
-      inputSchemas: Option[Seq[ScalaReflection.Schema]]): UserDefinedFunction = {
-    val udf = new UserDefinedFunction(f, dataType, inputSchemas.map(_.map(_.dataType)))
-    udf.nullableTypes = inputSchemas.map(_.map(_.nullable))
+      inputSchemas: Seq[Option[ScalaReflection.Schema]]): UserDefinedFunction = {
+    val inputTypes = if (inputSchemas.contains(None)) {
+      None
+    } else {
+      Some(inputSchemas.map(_.get.dataType))
+    }
+    val udf = new UserDefinedFunction(f, dataType, inputTypes)
+    udf.nullableTypes = Some(inputSchemas.map(_.map(_.nullable).getOrElse(true)))
     udf
   }
 }
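
The reshaped create() makes partial type information explicit: inputTypes survives only when every argument's schema was inferred, while unknown arguments default to nullable. A standalone sketch of that Seq[Option[...]] handling (hypothetical helper names; same logic as the hunk above):

    import org.apache.spark.sql.catalyst.ScalaReflection
    import org.apache.spark.sql.types.DataType

    // One missing schema drops the input types entirely (they must line up
    // positionally), but nullability falls back to true per argument.
    def toInputTypes(ss: Seq[Option[ScalaReflection.Schema]]): Option[Seq[DataType]] =
      if (ss.contains(None)) None else Some(ss.map(_.get.dataType))

    def toNullableTypes(ss: Seq[Option[ScalaReflection.Schema]]): Seq[Boolean] =
      ss.map(_.map(_.nullable).getOrElse(true))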

http://git-wip-us.apache.org/repos/asf/spark/blob/df60d9f3/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 10b67d7..6a43ce1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3819,7 +3819,7 @@ object functions {
   (0 to 10).foreach { x =>
     val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"})
     val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _)
-    val inputSchemas = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]) :: $s"})
+    val inputSchemas = (1 to x).foldRight("Nil")((i, s) => {s"Try(ScalaReflection.schemaFor(typeTag[A$i])).toOption :: $s"})
     println(s"""
       |/**
      | * Defines a Scala closure of $x arguments as user-defined function (UDF).
@@ -3832,7 +3832,7 @@ object functions {
       | */
       |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = {
      |  val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
-      |  val inputSchemas = Try($inputTypes).toOption
+      |  val inputSchemas = $inputSchemas
       |  val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
       |  if (nullable) udf else udf.asNonNullable()
       |}""".stripMargin)
@@ -3856,7 +3856,7 @@ object functions {
       | */
      |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = {
      |  val func = f$anyCast.call($anyParams)
-      |  SparkUserDefinedFunction.create($funcCall, returnType, inputSchemas = None)
+      |  SparkUserDefinedFunction.create($funcCall, returnType, inputSchemas = Seq.fill($i)(None))
       |}""".stripMargin)
   }
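
For the Java UDF wrappers no element types are known, so the generator now emits one None per argument instead of a single None for the whole list; keeping the arity is what lets create() line nullability up with the closure's parameters. Roughly (sketch, not committed code):

    import org.apache.spark.sql.catalyst.ScalaReflection

    // Seq.fill(2)(None) == None :: None :: Nil -- two arguments of unknown
    // schema, as opposed to "no schema information at all".
    val schemas: Seq[Option[ScalaReflection.Schema]] = Seq.fill(2)(None)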
 
@@ -3877,7 +3877,7 @@ object functions {
    */
   def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
-    val inputSchemas = Try(Nil).toOption
+    val inputSchemas = Nil
     val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -3893,7 +3893,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
-    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: Nil).toOption
+    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Nil
     val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
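
Note the moved Try: previously a single unencodable type tag discarded the whole schema list; wrapping each schemaFor call separately keeps whatever can still be inferred. A sketch of the difference, assuming schemaFor throws for unsupported types such as Any:

    import scala.util.Try
    import scala.reflect.runtime.universe._
    import org.apache.spark.sql.catalyst.ScalaReflection

    // Old shape: all-or-nothing.
    val before = Try(ScalaReflection.schemaFor(typeTag[Long]) ::
      ScalaReflection.schemaFor(typeTag[Any]) :: Nil).toOption
    // before == None

    // New shape: per argument.
    val after = Try(ScalaReflection.schemaFor(typeTag[Long])).toOption ::
      Try(ScalaReflection.schemaFor(typeTag[Any])).toOption :: Nil
    // after == List(Some(Schema(LongType, false)), None)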
@@ -3909,7 +3909,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
-    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: Nil).toOption
+    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Nil
     val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -3925,7 +3925,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
-    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: Nil).toOption
+    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Nil
     val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -3941,7 +3941,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
-    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: Nil).toOption
+    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Nil
     val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -3957,7 +3957,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
-    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: Nil).toOption
+    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Nil
     val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -3973,7 +3973,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
-    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: Nil).toOption
+    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Nil
     val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -3989,7 +3989,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
-    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: Nil).toOption
+    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Nil
     val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -4005,7 +4005,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
-    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: Nil).toOption
+    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Nil
     val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -4021,7 +4021,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
-    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: Nil).toOption
+    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Nil
     val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -4037,7 +4037,7 @@ object functions {
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
     val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT]
-    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: ScalaReflection.schemaFor(typeTag[A10]) :: Nil).toOption
+    val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A10])).toOption :: Nil
     val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas)
     if (nullable) udf else udf.asNonNullable()
   }
@@ -4057,7 +4057,7 @@ object functions {
    */
   def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF0[Any]].call()
-    SparkUserDefinedFunction.create(() => func, returnType, inputSchemas = None)
+    SparkUserDefinedFunction.create(() => func, returnType, inputSchemas = Seq.fill(0)(None))
   }
 
   /**
@@ -4071,7 +4071,7 @@ object functions {
    */
   def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = None)
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(1)(None))
   }
 
   /**
@@ -4085,7 +4085,7 @@ object functions {
    */
   def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = None)
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(2)(None))
   }
 
   /**
@@ -4099,7 +4099,7 @@ object functions {
    */
   def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = None)
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(3)(None))
   }
 
   /**
@@ -4113,7 +4113,7 @@ object functions {
    */
   def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = None)
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(4)(None))
   }
 
   /**
@@ -4127,7 +4127,7 @@ object functions {
    */
   def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = None)
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(5)(None))
   }
 
   /**
@@ -4141,7 +4141,7 @@ object functions {
    */
   def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = None)
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(6)(None))
   }
 
   /**
@@ -4155,7 +4155,7 @@ object functions {
    */
   def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = None)
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(7)(None))
   }
 
   /**
@@ -4169,7 +4169,7 @@ object functions {
    */
   def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = None)
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(8)(None))
   }
 
   /**
@@ -4183,7 +4183,7 @@ object functions {
    */
   def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = None)
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(9)(None))
   }
 
   /**
@@ -4197,7 +4197,7 @@ object functions {
    */
   def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = {
     val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any)
-    SparkUserDefinedFunction.create(func, returnType, inputSchemas = None)
+    SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(10)(None))
   }
 
   // scalastyle:on parameter.number
@@ -4216,7 +4216,9 @@ object functions {
    * @since 2.0.0
    */
   def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = {
-    SparkUserDefinedFunction.create(f, dataType, inputSchemas = None)
+    // TODO: should call SparkUserDefinedFunction.create() instead but inputSchemas is currently
+    // unavailable. We may need to create type-safe overloaded versions of udf() methods.
+    new UserDefinedFunction(f, dataType, inputTypes = None)
   }
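
The untyped udf(f: AnyRef, dataType) above is now the one entry point that bypasses SparkUserDefinedFunction.create(), since nothing is known about the argument types (the TODO tracks that gap). Callers are unaffected; a hedged usage sketch:

    import org.apache.spark.sql.functions.udf
    import org.apache.spark.sql.types.LongType

    // No inputSchemas here, so a null reaching the primitive Long parameter
    // cannot be guarded against, unlike with the typed udf() overloads.
    val doubled = udf((x: Long) => x * 2, LongType)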
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/df60d9f3/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 30dca94..f8ed21b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -393,4 +393,28 @@ class UDFSuite extends QueryTest with SharedSQLContext {
       checkAnswer(df, Seq(Row("12"), Row("24"), Row("3null"), Row(null)))
     }
   }
+
+  test("SPARK-25044 Verify null input handling for primitive types - with 
udf()") {
+    val udf1 = udf((x: Long, y: Any) => x * 2 + (if (y == null) 1 else 0))
+    val df = spark.range(0, 3).toDF("a")
+      .withColumn("b", udf1($"a", lit(null)))
+      .withColumn("c", udf1(lit(null), $"a"))
+
+    checkAnswer(
+      df,
+      Seq(
+        Row(0, 1, null),
+        Row(1, 3, null),
+        Row(2, 5, null)))
+  }
+
+  test("SPARK-25044 Verify null input handling for primitive types - with 
udf.register") {
+    withTable("t") {
+      Seq((null, new Integer(1), "x"), ("M", null, "y"), ("N", new Integer(3), null))
+        .toDF("a", "b", "c").write.format("json").saveAsTable("t")
+      spark.udf.register("f", (a: String, b: Int, c: Any) => a + b + c)
+      val df = spark.sql("SELECT f(a, b, c) FROM t")
+      checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null")))
+    }
+  }
 }
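
Reading the expected answers in both tests: a null routed to a primitive parameter (x: Long, b: Int) short-circuits the UDF to null, while a null routed to a reference-typed parameter (y: Any, a: String, c: Any) reaches the closure. A spark-shell sketch of the first test's behavior (assumes spark.implicits._ for the $"a" syntax):

    import org.apache.spark.sql.functions.{lit, udf}
    import spark.implicits._

    val udf1 = udf((x: Long, y: Any) => x * 2 + (if (y == null) 1 else 0))
    spark.range(0, 3).toDF("a")
      .select(udf1($"a", lit(null)) as "b", udf1(lit(null), $"a") as "c")
      .show()
    // b == a * 2 + 1 on every row; c == null on every row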

