This is an automated email from the ASF dual-hosted git repository. yamamuro pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new f502e20 [SPARK-26798][SQL] HandleNullInputsForUDF should trust nullability f502e20 is described below commit f502e209f49d3d76f947d1a8ba38c8c8a86e0bef Author: Wenchen Fan <wenc...@databricks.com> AuthorDate: Wed Feb 13 14:07:03 2019 +0900 [SPARK-26798][SQL] HandleNullInputsForUDF should trust nullability ## What changes were proposed in this pull request? There is a very old TODO in `HandleNullInputsForUDF`, saying that we can skip the null check if input is not nullable. We leverage the nullability info at many places, we can trust it here too. ## How was this patch tested? re-enable an ignored test Closes #23712 from cloud-fan/minor. Lead-authored-by: Wenchen Fan <wenc...@databricks.com> Co-authored-by: Xiao Li <gatorsm...@gmail.com> Signed-off-by: Takeshi Yamamuro <yamam...@apache.org> --- .../spark/sql/catalyst/analysis/Analyzer.scala | 47 ++++++++++++--------- .../spark/sql/catalyst/expressions/ScalaUDF.scala | 9 ++-- .../sql/catalyst/analysis/AnalysisSuite.scala | 36 +++++++++------- .../sql/catalyst/expressions/ScalaUDFSuite.scala | 6 +-- .../spark/sql/catalyst/trees/TreeNodeSuite.scala | 2 +- .../org/apache/spark/sql/UDFRegistration.scala | 48 +++++++++++----------- .../datasources/FileFormatDataWriter.scala | 2 +- .../sql/expressions/UserDefinedFunction.scala | 8 ++-- .../test/scala/org/apache/spark/sql/UDFSuite.scala | 37 ++++++++++++----- 9 files changed, 114 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a84bb76..793c337 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2147,27 +2147,36 @@ class Analyzer( case p => p transformExpressionsUp { - case udf @ ScalaUDF(_, _, inputs, inputsNullSafe, _, _, _, _) - if inputsNullSafe.contains(false) => + case udf @ ScalaUDF(_, _, inputs, inputPrimitives, _, _, _, _) + if inputPrimitives.contains(true) => // Otherwise, add special handling of null for fields that can't accept null. // The result of operations like this, when passed null, is generally to return null. - assert(inputsNullSafe.length == inputs.length) - - // TODO: skip null handling for not-nullable primitive inputs after we can completely - // trust the `nullable` information. - val inputsNullCheck = inputsNullSafe.zip(inputs) - .filter { case (nullSafe, _) => !nullSafe } - .map { case (_, expr) => IsNull(expr) } - .reduceLeftOption[Expression]((e1, e2) => Or(e1, e2)) - // Once we add an `If` check above the udf, it is safe to mark those checked inputs - // as null-safe (i.e., set `inputsNullSafe` all `true`), because the null-returning - // branch of `If` will be called if any of these checked inputs is null. Thus we can - // prevent this rule from being applied repeatedly. - val newInputsNullSafe = inputsNullSafe.map(_ => true) - inputsNullCheck - .map(If(_, Literal.create(null, udf.dataType), - udf.copy(inputsNullSafe = newInputsNullSafe))) - .getOrElse(udf) + assert(inputPrimitives.length == inputs.length) + + val inputPrimitivesPair = inputPrimitives.zip(inputs) + val inputNullCheck = inputPrimitivesPair.collect { + case (isPrimitive, input) if isPrimitive && input.nullable => + IsNull(input) + }.reduceLeftOption[Expression](Or) + + if (inputNullCheck.isDefined) { + // Once we add an `If` check above the udf, it is safe to mark those checked inputs + // as null-safe (i.e., wrap with `KnownNotNull`), because the null-returning + // branch of `If` will be called if any of these checked inputs is null. Thus we can + // prevent this rule from being applied repeatedly. + val newInputs = inputPrimitivesPair.map { + case (isPrimitive, input) => + if (isPrimitive && input.nullable) { + KnownNotNull(input) + } else { + input + } + } + val newUDF = udf.copy(children = newInputs) + If(inputNullCheck.get, Literal.create(null, udf.dataType), newUDF) + } else { + udf + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index c9e0a2e..eb45f08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -31,9 +31,10 @@ import org.apache.spark.sql.types.{AbstractDataType, DataType} * null. Use boxed type or [[Option]] if you wanna do the null-handling yourself. * @param dataType Return type of function. * @param children The input expressions of this UDF. - * @param inputsNullSafe Whether the inputs are of non-primitive types or not nullable. Null values - * of Scala primitive types will be converted to the type's default value and - * lead to wrong results, thus need special handling before calling the UDF. + * @param inputPrimitives The analyzer should be aware of Scala primitive types so as to make the + * UDF return null if there is any null input value of these types. On the + * other hand, Java UDFs can only have boxed types, thus this parameter will + * always be all false. * @param inputTypes The expected input types of this UDF, used to perform type coercion. If we do * not want to perform coercion, simply use "Nil". Note that it would've been * better to use Option of Seq[DataType] so we can use "None" as the case for no @@ -47,7 +48,7 @@ case class ScalaUDF( function: AnyRef, dataType: DataType, children: Seq[Expression], - inputsNullSafe: Seq[Boolean], + inputPrimitives: Seq[Boolean], inputTypes: Seq[AbstractDataType] = Nil, udfName: Option[String] = None, nullable: Boolean = true, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 9829484..8038733 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -303,51 +303,57 @@ class AnalysisSuite extends AnalysisTest with Matchers { } test("SPARK-11725: correctly handle null inputs for ScalaUDF") { - val string = testRelation2.output(0) - val double = testRelation2.output(2) - val short = testRelation2.output(4) + val testRelation = LocalRelation( + AttributeReference("a", StringType)(), + AttributeReference("b", DoubleType)(), + AttributeReference("c", ShortType)(), + AttributeReference("d", DoubleType, nullable = false)()) + + val string = testRelation.output(0) + val double = testRelation.output(1) + val short = testRelation.output(2) + val nonNullableDouble = testRelation.output(3) val nullResult = Literal.create(null, StringType) def checkUDF(udf: Expression, transformed: Expression): Unit = { checkAnalysis( - Project(Alias(udf, "")() :: Nil, testRelation2), - Project(Alias(transformed, "")() :: Nil, testRelation2) + Project(Alias(udf, "")() :: Nil, testRelation), + Project(Alias(transformed, "")() :: Nil, testRelation) ) } // non-primitive parameters do not need special null handling - val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil, true :: Nil) + val udf1 = ScalaUDF((s: String) => "x", StringType, string :: Nil, false :: Nil) val expected1 = udf1 checkUDF(udf1, expected1) // only primitive parameter needs special null handling val udf2 = ScalaUDF((s: String, d: Double) => "x", StringType, string :: double :: Nil, - true :: false :: Nil) + false :: true :: Nil) val expected2 = - If(IsNull(double), nullResult, udf2.copy(inputsNullSafe = true :: true :: Nil)) + If(IsNull(double), nullResult, udf2.copy(children = string :: KnownNotNull(double) :: Nil)) checkUDF(udf2, expected2) // special null handling should apply to all primitive parameters val udf3 = ScalaUDF((s: Short, d: Double) => "x", StringType, short :: double :: Nil, - false :: false :: Nil) + true :: true :: Nil) val expected3 = If( IsNull(short) || IsNull(double), nullResult, - udf3.copy(inputsNullSafe = true :: true :: Nil)) + udf3.copy(children = KnownNotNull(short) :: KnownNotNull(double) :: Nil)) checkUDF(udf3, expected3) // we can skip special null handling for primitive parameters that are not nullable - // TODO: this is disabled for now as we can not completely trust `nullable`. val udf4 = ScalaUDF( (s: Short, d: Double) => "x", StringType, - short :: double.withNullability(false) :: Nil, - false :: false :: Nil) + short :: nonNullableDouble :: Nil, + true :: true :: Nil) val expected4 = If( IsNull(short), nullResult, - udf4.copy(inputsNullSafe = true :: true :: Nil)) - // checkUDF(udf4, expected4) + udf4.copy(children = KnownNotNull(short) :: nonNullableDouble :: Nil)) + checkUDF(udf4, expected4) } test("SPARK-24891 Fix HandleNullInputsForUDF rule") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 467cfd5..df92fa3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -29,7 +29,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil, true :: Nil) checkEvaluation(intUdf, 2) - val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil) + val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil) checkEvaluation(stringUdf, "ax") } @@ -38,7 +38,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { (s: String) => s.toLowerCase(Locale.ROOT), StringType, Literal.create(null, StringType) :: Nil, - true :: Nil) + false :: Nil) val e1 = intercept[SparkException](udf.eval()) assert(e1.getMessage.contains("Failed to execute user defined function")) @@ -51,7 +51,7 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22695: ScalaUDF should not use global variables") { val ctx = new CodegenContext - ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, true :: Nil).genCode(ctx) + ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil, false :: Nil).genCode(ctx) assert(ctx.inlinedMutableStates.isEmpty) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 64aa1ee..cb911d7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -564,7 +564,7 @@ class TreeNodeSuite extends SparkFunSuite { } test("toJSON should not throws java.lang.StackOverflowError") { - val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr), true :: Nil) + val udf = ScalaUDF(SelfReferenceUDF(), BooleanType, Seq("col1".attr), false :: Nil) // Should not throw java.lang.StackOverflowError udf.toJSON } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index fe5d1af..83425de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -151,7 +151,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { | val func = f$anyCast.call($anyParams) | def builder(e: Seq[Expression]) = if (e.length == $i) { - | ScalaUDF($funcCall, returnType, e, e.map(_ => true), udfName = Some(name)) + | ScalaUDF($funcCall, returnType, e, e.map(_ => false), udfName = Some(name)) | } else { | throw new AnalysisException("Invalid number of arguments for function " + name + | ". Expected: $i; Found: " + e.length) @@ -719,7 +719,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF0[_], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF0[Any]].call() def builder(e: Seq[Expression]) = if (e.length == 0) { - ScalaUDF(() => func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(() => func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 0; Found: " + e.length) @@ -734,7 +734,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) def builder(e: Seq[Expression]) = if (e.length == 1) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 1; Found: " + e.length) @@ -749,7 +749,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 2) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 2; Found: " + e.length) @@ -764,7 +764,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 3) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 3; Found: " + e.length) @@ -779,7 +779,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 4) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 4; Found: " + e.length) @@ -794,7 +794,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 5) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 5; Found: " + e.length) @@ -809,7 +809,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 6) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 6; Found: " + e.length) @@ -824,7 +824,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 7) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 7; Found: " + e.length) @@ -839,7 +839,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 8) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 8; Found: " + e.length) @@ -854,7 +854,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 9) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 9; Found: " + e.length) @@ -869,7 +869,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 10) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 10; Found: " + e.length) @@ -884,7 +884,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 11) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 11; Found: " + e.length) @@ -899,7 +899,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 12) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 12; Found: " + e.length) @@ -914,7 +914,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 13) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 13; Found: " + e.length) @@ -929,7 +929,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 14) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 14; Found: " + e.length) @@ -944,7 +944,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 15) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 15; Found: " + e.length) @@ -959,7 +959,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 16) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 16; Found: " + e.length) @@ -974,7 +974,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 17) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 17; Found: " + e.length) @@ -989,7 +989,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 18) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 18; Found: " + e.length) @@ -1004,7 +1004,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 19) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 19; Found: " + e.length) @@ -1019,7 +1019,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 20) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 20; Found: " + e.length) @@ -1034,7 +1034,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 21) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 21; Found: " + e.length) @@ -1049,7 +1049,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { val func = f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) def builder(e: Seq[Expression]) = if (e.length == 22) { - ScalaUDF(func, returnType, e, e.map(_ => true), udfName = Some(name)) + ScalaUDF(func, returnType, e, e.map(_ => false), udfName = Some(name)) } else { throw new AnalysisException("Invalid number of arguments for function " + name + ". Expected: 22; Found: " + e.length) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index b6b0d7a..2595cc6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -181,7 +181,7 @@ class DynamicPartitionDataWriter( ExternalCatalogUtils.getPartitionPathString _, StringType, Seq(Literal(c.name), Cast(c, StringType, Option(description.timeZoneId))), - Seq(true, true)) + Seq(false, false)) if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 4d8e1c5..0c956ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -100,14 +100,16 @@ private[sql] case class SparkUserDefinedFunction( private[sql] def createScalaUDF(exprs: Seq[Expression]): ScalaUDF = { // It's possible that some of the inputs don't have a specific type(e.g. `Any`), skip type - // check and null check for them. + // check. val inputTypes = inputSchemas.map(_.map(_.dataType).getOrElse(AnyDataType)) - val inputsNullSafe = inputSchemas.map(_.map(_.nullable).getOrElse(true)) + // `ScalaReflection.Schema.nullable` is false iff the type is primitive. Also `Any` is not + // primitive. + val inputsPrimitive = inputSchemas.map(_.map(!_.nullable).getOrElse(false)) ScalaUDF( f, dataType, exprs, - inputsNullSafe, + inputsPrimitive, inputTypes, udfName = name, nullable = nullable, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 5ac2093..e515800 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -400,17 +400,23 @@ class UDFSuite extends QueryTest with SharedSQLContext { } test("SPARK-25044 Verify null input handling for primitive types - with udf()") { - val udf1 = udf((x: Long, y: Any) => x * 2 + (if (y == null) 1 else 0)) - val df = spark.range(0, 3).toDF("a") - .withColumn("b", udf1($"a", lit(null))) - .withColumn("c", udf1(lit(null), $"a")) - - checkAnswer( - df, - Seq( - Row(0, 1, null), - Row(1, 3, null), - Row(2, 5, null))) + val input = Seq( + (null, Integer.valueOf(1), "x"), + ("M", null, "y"), + ("N", Integer.valueOf(3), null)).toDF("a", "b", "c") + + val udf1 = udf((a: String, b: Int, c: Any) => a + b + c) + val df = input.select(udf1('a, 'b, 'c)) + checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null"))) + + // test Java UDF. Java UDF can't have primitive inputs, as it's generic typed. + val udf2 = udf(new UDF3[String, Integer, Object, String] { + override def call(t1: String, t2: Integer, t3: Object): String = { + t1 + t2 + t3 + } + }, StringType) + val df2 = input.select(udf2('a, 'b, 'c)) + checkAnswer(df2, Seq(Row("null1x"), Row("Mnully"), Row("N3null"))) } test("SPARK-25044 Verify null input handling for primitive types - with udf.register") { @@ -420,6 +426,15 @@ class UDFSuite extends QueryTest with SharedSQLContext { spark.udf.register("f", (a: String, b: Int, c: Any) => a + b + c) val df = spark.sql("SELECT f(a, b, c) FROM t") checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null"))) + + // test Java UDF. Java UDF can't have primitive inputs, as it's generic typed. + spark.udf.register("f2", new UDF3[String, Integer, Object, String] { + override def call(t1: String, t2: Integer, t3: Object): String = { + t1 + t2 + t3 + } + }, StringType) + val df2 = spark.sql("SELECT f2(a, b, c) FROM t") + checkAnswer(df2, Seq(Row("null1x"), Row("Mnully"), Row("N3null"))) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org