Repository: spark Updated Branches: refs/heads/branch-2.4 9ed2e4204 -> df60d9f34
http://git-wip-us.apache.org/repos/asf/spark/blob/df60d9f3/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 697757f..eb956c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -73,19 +73,24 @@ case class UserDefinedFunction protected[sql] ( */ @scala.annotation.varargs def apply(exprs: Column*): Column = { - if (inputTypes.isDefined && nullableTypes.isDefined) { - require(inputTypes.get.length == nullableTypes.get.length) + // TODO: make sure this class is only instantiated through `SparkUserDefinedFunction.create()` + // and `nullableTypes` is always set. + if (nullableTypes.isEmpty) { + nullableTypes = Some(ScalaReflection.getParameterTypeNullability(f)) + } + if (inputTypes.isDefined) { + assert(inputTypes.get.length == nullableTypes.get.length) } Column(ScalaUDF( f, dataType, exprs.map(_.expr), + nullableTypes.get, inputTypes.getOrElse(Nil), udfName = _nameOption, nullable = _nullable, - udfDeterministic = _deterministic, - nullableTypes = nullableTypes.getOrElse(Nil))) + udfDeterministic = _deterministic)) } private def copyAll(): UserDefinedFunction = { @@ -146,9 +151,14 @@ private[sql] object SparkUserDefinedFunction { def create( f: AnyRef, dataType: DataType, - inputSchemas: Option[Seq[ScalaReflection.Schema]]): UserDefinedFunction = { - val udf = new UserDefinedFunction(f, dataType, inputSchemas.map(_.map(_.dataType))) - udf.nullableTypes = inputSchemas.map(_.map(_.nullable)) + inputSchemas: Seq[Option[ScalaReflection.Schema]]): UserDefinedFunction = { + val inputTypes = if (inputSchemas.contains(None)) { + None + } else { + Some(inputSchemas.map(_.get.dataType)) + } + val udf = new UserDefinedFunction(f, dataType, inputTypes) + udf.nullableTypes = Some(inputSchemas.map(_.map(_.nullable).getOrElse(true))) udf } } http://git-wip-us.apache.org/repos/asf/spark/blob/df60d9f3/sql/core/src/main/scala/org/apache/spark/sql/functions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 10b67d7..6a43ce1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3819,7 +3819,7 @@ object functions { (0 to 10).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) - val inputSchemas = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]) :: $s"}) + val inputSchemas = (1 to x).foldRight("Nil")((i, s) => {s"Try(ScalaReflection.schemaFor(typeTag[A$i])).toOption :: $s"}) println(s""" |/** | * Defines a Scala closure of $x arguments as user-defined function (UDF). @@ -3832,7 +3832,7 @@ object functions { | */ |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - | val inputSchemas = Try($inputTypes).toOption + | val inputSchemas = $inputSchemas | val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) | if (nullable) udf else udf.asNonNullable() |}""".stripMargin) @@ -3856,7 +3856,7 @@ object functions { | */ |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { | val func = f$anyCast.call($anyParams) - | SparkUserDefinedFunction.create($funcCall, returnType, inputSchemas = None) + | SparkUserDefinedFunction.create($funcCall, returnType, inputSchemas = Seq.fill($i)(None)) |}""".stripMargin) } @@ -3877,7 +3877,7 @@ object functions { */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(Nil).toOption + val inputSchemas = Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3893,7 +3893,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3909,7 +3909,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3925,7 +3925,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3941,7 +3941,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3957,7 +3957,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3973,7 +3973,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -3989,7 +3989,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4005,7 +4005,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4021,7 +4021,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4037,7 +4037,7 @@ object functions { */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1]) :: ScalaReflection.schemaFor(typeTag[A2]) :: ScalaReflection.schemaFor(typeTag[A3]) :: ScalaReflection.schemaFor(typeTag[A4]) :: ScalaReflection.schemaFor(typeTag[A5]) :: ScalaReflection.schemaFor(typeTag[A6]) :: ScalaReflection.schemaFor(typeTag[A7]) :: ScalaReflection.schemaFor(typeTag[A8]) :: ScalaReflection.schemaFor(typeTag[A9]) :: ScalaReflection.schemaFor(typeTag[A10]) :: Nil).toOption + val inputSchemas = Try(ScalaReflection.schemaFor(typeTag[A1])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A2])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A3])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A4])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A5])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A6])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A7])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A8])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A9])).toOption :: Try(ScalaReflection.schemaFor(typeTag[A10])).toOption :: Nil val udf = SparkUserDefinedFunction.create(f, dataType, inputSchemas) if (nullable) udf else udf.asNonNullable() } @@ -4057,7 +4057,7 @@ object functions { */ def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF0[Any]].call() - SparkUserDefinedFunction.create(() => func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(() => func, returnType, inputSchemas = Seq.fill(0)(None)) } /** @@ -4071,7 +4071,7 @@ object functions { */ def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(1)(None)) } /** @@ -4085,7 +4085,7 @@ object functions { */ def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(2)(None)) } /** @@ -4099,7 +4099,7 @@ object functions { */ def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(3)(None)) } /** @@ -4113,7 +4113,7 @@ object functions { */ def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(4)(None)) } /** @@ -4127,7 +4127,7 @@ object functions { */ def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(5)(None)) } /** @@ -4141,7 +4141,7 @@ object functions { */ def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(6)(None)) } /** @@ -4155,7 +4155,7 @@ object functions { */ def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(7)(None)) } /** @@ -4169,7 +4169,7 @@ object functions { */ def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(8)(None)) } /** @@ -4183,7 +4183,7 @@ object functions { */ def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(9)(None)) } /** @@ -4197,7 +4197,7 @@ object functions { */ def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) - SparkUserDefinedFunction.create(func, returnType, inputSchemas = None) + SparkUserDefinedFunction.create(func, returnType, inputSchemas = Seq.fill(10)(None)) } // scalastyle:on parameter.number @@ -4216,7 +4216,9 @@ object functions { * @since 2.0.0 */ def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { - SparkUserDefinedFunction.create(f, dataType, inputSchemas = None) + // TODO: should call SparkUserDefinedFunction.create() instead but inputSchemas is currently + // unavailable. We may need to create type-safe overloaded versions of udf() methods. + new UserDefinedFunction(f, dataType, inputTypes = None) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/df60d9f3/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 30dca94..f8ed21b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -393,4 +393,28 @@ class UDFSuite extends QueryTest with SharedSQLContext { checkAnswer(df, Seq(Row("12"), Row("24"), Row("3null"), Row(null))) } } + + test("SPARK-25044 Verify null input handling for primitive types - with udf()") { + val udf1 = udf((x: Long, y: Any) => x * 2 + (if (y == null) 1 else 0)) + val df = spark.range(0, 3).toDF("a") + .withColumn("b", udf1($"a", lit(null))) + .withColumn("c", udf1(lit(null), $"a")) + + checkAnswer( + df, + Seq( + Row(0, 1, null), + Row(1, 3, null), + Row(2, 5, null))) + } + + test("SPARK-25044 Verify null input handling for primitive types - with udf.register") { + withTable("t") { + Seq((null, new Integer(1), "x"), ("M", null, "y"), ("N", new Integer(3), null)) + .toDF("a", "b", "c").write.format("json").saveAsTable("t") + spark.udf.register("f", (a: String, b: Int, c: Any) => a + b + c) + val df = spark.sql("SELECT f(a, b, c) FROM t") + checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null"))) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org