Repository: spark Updated Branches: refs/heads/master 8c67aa7f0 -> 3aa4e464a
[SPARK-20416][SQL] Print UDF names in EXPLAIN ## What changes were proposed in this pull request? This PR adds `withName` to `UserDefinedFunction` so that UDF names are printed in EXPLAIN output. ## How was this patch tested? Added tests in `UDFSuite`. Author: Takeshi Yamamuro <yamam...@apache.org> Closes #17712 from maropu/SPARK-20416. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3aa4e464 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3aa4e464 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3aa4e464 Branch: refs/heads/master Commit: 3aa4e464a8c81994c6b7f76d445640da719af6ed Parents: 8c67aa7 Author: Takeshi Yamamuro <yamam...@apache.org> Authored: Thu May 11 09:49:05 2017 -0700 Committer: Xiao Li <gatorsm...@gmail.com> Committed: Thu May 11 09:49:05 2017 -0700 ---------------------------------------------------------------------- .../apache/spark/ml/feature/Bucketizer.scala | 2 +- .../org/apache/spark/sql/UDFRegistration.scala | 50 ++++++++++---------- .../sql/expressions/UserDefinedFunction.scala | 13 +++++ .../scala/org/apache/spark/sql/UDFSuite.scala | 12 +++-- 4 files changed, 46 insertions(+), 31 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/3aa4e464/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index bb8f2a3..46b512f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -114,7 +114,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String val bucketizer: UserDefinedFunction = udf { (feature: Double) => 
Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) - } + }.withName("bucketizer") val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType)) val newField = prepOutputField(filteredDataset.schema) http://git-wip-us.apache.org/repos/asf/spark/blob/3aa4e464/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 5fd7123..1bceac4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} -import org.apache.spark.sql.types.{DataType, DataTypes} +import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils /** @@ -114,7 +114,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try($inputTypes).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) }""") } @@ -147,7 +147,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, 
dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -160,7 +160,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -173,7 +173,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -186,7 +186,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -199,7 +199,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -212,7 +212,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -225,7 +225,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -238,7 +238,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = 
Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -251,7 +251,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -264,7 +264,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).toOption def builder(e: Seq[Expression]) = 
ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -277,7 +277,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -290,7 +290,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, 
inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -303,7 +303,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -316,7 +316,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, 
inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -329,7 +329,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -342,7 +342,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).toOption def builder(e: Seq[Expression]) 
= ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -355,7 +355,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -368,7 +368,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -381,7 +381,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -394,7 +394,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = 
Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -407,7 +407,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType 
:: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -420,7 +420,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } /** @@ -433,7 +433,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val inputTypes = 
Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).toOption def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable) functionRegistry.registerFunction(name, builder) - UserDefinedFunction(func, dataType, inputTypes).withNullability(nullable) + UserDefinedFunction(func, dataType, inputTypes).withName(name).withNullability(nullable) } ////////////////////////////////////////////////////////////////////////////////////////////// http://git-wip-us.apache.org/repos/asf/spark/blob/3aa4e464/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 5a0f488..0c5f1b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -47,6 +47,7 @@ 
case class UserDefinedFunction protected[sql] ( dataType: DataType, inputTypes: Option[Seq[DataType]]) { + private var _nameOption: Option[String] = None private var _nullable: Boolean = true /** @@ -67,16 +68,28 @@ case class UserDefinedFunction protected[sql] ( dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil), + udfName = _nameOption, nullable = _nullable)) } private def copyAll(): UserDefinedFunction = { val udf = copy() + udf._nameOption = _nameOption udf._nullable = _nullable udf } /** + * Updates UserDefinedFunction with a given name. + * + * @since 2.3.0 + */ + def withName(name: String): this.type = { + this._nameOption = Option(name) + this + } + + /** * Updates UserDefinedFunction with a given nullability. * * @since 2.3.0 http://git-wip-us.apache.org/repos/asf/spark/blob/3aa4e464/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 6f8723a..b4f744b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -263,10 +263,12 @@ class UDFSuite extends QueryTest with SharedSQLContext { val sparkPlan = spark.sessionState.executePlan(explain).executedPlan sparkPlan.executeCollect().map(_.getString(0).trim).headOption.getOrElse("") } - val udf1 = "myUdf1" - val udf2 = "myUdf2" - spark.udf.register(udf1, (n: Int) => { n + 1 }) - spark.udf.register(udf2, (n: Int) => { n * 1 }) - assert(explainStr(sql("SELECT myUdf1(myUdf2(1))")).contains(s"UDF:$udf1(UDF:$udf2(1))")) + val udf1Name = "myUdf1" + val udf2Name = "myUdf2" + val udf1 = spark.udf.register(udf1Name, (n: Int) => n + 1) + val udf2 = spark.udf.register(udf2Name, (n: Int) => n * 1) + assert(explainStr(sql("SELECT myUdf1(myUdf2(1))")).contains(s"UDF:$udf1Name(UDF:$udf2Name(1))")) + 
assert(explainStr(spark.range(1).select(udf1(udf2(functions.lit(1))))) + .contains(s"UDF:$udf1Name(UDF:$udf2Name(1))")) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org