Repository: spark
Updated Branches:
  refs/heads/master e92c24d37 -> 4a22bce8f


[SPARK-8572] [SQL] Type coercion for ScalaUDFs

Implemented type coercion for UDF arguments in Scala. The changes include:
* Add `with ExpectsInputTypes` to the `ScalaUDF` class.
* Pass down argument type info from `UDFRegistration` and `functions`.

With this patch, the example query in 
[SPARK-8572](https://issues.apache.org/jira/browse/SPARK-8572) no longer throws 
a type cast error at runtime.

Also added a unit test to `UDFSuite` in which a decimal value is passed to a UDF 
that expects an int.

Author: Cheolsoo Park <cheols...@netflix.com>

Closes #7203 from piaozhexiu/SPARK-8572 and squashes the following commits:

2d0ed15 [Cheolsoo Park] Incorporate comments
dce1efd [Cheolsoo Park] Fix unit tests and update the codegen script
066deed [Cheolsoo Park] Type coercion for udf inputs


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4a22bce8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4a22bce8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4a22bce8

Branch: refs/heads/master
Commit: 4a22bce8fce30f86f364467a8ba51d2e744ff379
Parents: e92c24d
Author: Cheolsoo Park <cheols...@netflix.com>
Authored: Fri Jul 3 22:14:21 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Fri Jul 3 22:14:21 2015 -0700

----------------------------------------------------------------------
 .../catalyst/analysis/HiveTypeCoercion.scala    |  2 +-
 .../sql/catalyst/expressions/ScalaUDF.scala     |  7 +-
 .../org/apache/spark/sql/UDFRegistration.scala  | 75 +++++++++++++-------
 .../apache/spark/sql/UserDefinedFunction.scala  |  7 +-
 .../scala/org/apache/spark/sql/functions.scala  | 38 ++++++----
 .../scala/org/apache/spark/sql/UDFSuite.scala   |  6 ++
 6 files changed, 93 insertions(+), 42 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4a22bce8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 38eb832..84acc0e 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -680,7 +680,7 @@ object HiveTypeCoercion {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
-      case e: ExpectsInputTypes =>
+      case e: ExpectsInputTypes if (e.inputTypes.nonEmpty) =>
         val children: Seq[Expression] = e.children.zip(e.inputTypes).map { 
case (in, expected) =>
           // If we cannot do the implicit cast, just use the original input.
           implicitCast(in, expected).getOrElse(in)

http://git-wip-us.apache.org/repos/asf/spark/blob/4a22bce8/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index caf021b..fc055c9 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -24,8 +24,11 @@ import org.apache.spark.sql.types.DataType
  * User-defined function.
  * @param dataType  Return type of function.
  */
-case class ScalaUDF(function: AnyRef, dataType: DataType, children: 
Seq[Expression])
-  extends Expression {
+case class ScalaUDF(
+    function: AnyRef,
+    dataType: DataType,
+    children: Seq[Expression],
+    inputTypes: Seq[DataType] = Nil) extends Expression with ExpectsInputTypes 
{
 
   override def nullable: Boolean = true
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4a22bce8/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 03dc37a..d35d37d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.util.{List => JList, Map => JMap}
 
 import scala.reflect.runtime.universe.TypeTag
+import scala.util.Try
 
 import org.apache.spark.{Accumulator, Logging}
 import org.apache.spark.api.python.PythonBroadcast
@@ -30,7 +31,6 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, 
ScalaUDF}
 import org.apache.spark.sql.execution.PythonUDF
 import org.apache.spark.sql.types.DataType
 
-
 /**
  * Functions for registering user-defined functions. Use [[SQLContext.udf]] to 
access this.
  *
@@ -87,6 +87,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     (0 to 22).map { x =>
       val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"})
       val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: 
TypeTag")(_ + ", " + _)
+      val inputTypes = (1 to x).foldRight("Nil")((i, s) => 
{s"ScalaReflection.schemaFor[A$i].dataType :: $s"})
       println(s"""
         /**
          * Register a Scala closure of ${x} arguments as user-defined function 
(UDF).
@@ -95,7 +96,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
          */
         def register[$typeTags](name: String, func: Function$x[$types]): 
UserDefinedFunction = {
           val dataType = ScalaReflection.schemaFor[RT].dataType
-          def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+          val inputTypes = Try($inputTypes).getOrElse(Nil)
+          def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, 
inputTypes)
           functionRegistry.registerFunction(name, builder)
           UserDefinedFunction(func, dataType)
         }""")
@@ -126,7 +128,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag](name: String, func: Function0[RT]): 
UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -138,7 +141,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, 
RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -150,7 +154,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: 
Function2[A1, A2, RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -162,7 +167,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: 
String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -174,7 +180,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): 
UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -186,7 +193,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): 
UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -198,7 +206,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, 
A4, A5, A6, RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -210,7 +219,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: 
Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -222,7 +232,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, 
func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -234,7 +245,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: 
String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): 
UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -246,7 +258,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: 
TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, 
A10, RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -258,7 +271,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: 
TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, 
A7, A8, A9, A10, A11, RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -270,7 +284,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: 
TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, 
A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -282,7 +297,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: 
TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: 
Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): 
UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -294,7 +310,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: 
TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, 
func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, 
RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -306,7 +323,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: 
TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: 
TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, 
A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -318,7 +336,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: 
TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, 
A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, 
A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: 
ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -330,7 +349,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: 
TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, 
A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, 
A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): 
UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: 
ScalaReflection.schemaFor[A16].dataType :: 
ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -342,7 +362,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: 
TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, 
A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, 
A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, 
RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: 
ScalaReflection.schemaFor[A16].dataType :: 
ScalaReflection.schemaFor[A17].dataType :: 
ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -354,7 +375,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: 
TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, 
A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: 
Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, 
A16, A17, A18, A19, RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: 
ScalaReflection.schemaFor[A16].dataType :: 
ScalaReflection.schemaFor[A17].dataType :: 
ScalaReflection.schemaFor[A18].dataType :: 
ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -366,7 +388,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: 
TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, 
A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: 
String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, 
A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: 
ScalaReflection.schemaFor[A16].dataType :: 
ScalaReflection.schemaFor[A17].dataType :: 
ScalaReflection.schemaFor[A18].dataType :: 
ScalaReflection.schemaFor[A19].dataType :: 
ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -378,7 +401,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: 
TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: 
TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, 
A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: 
TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, 
A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): 
UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: 
ScalaReflection.schemaFor[A16].dataType :: 
ScalaReflection.schemaFor[A17].dataType :: 
ScalaReflection.schemaFor[A18].dataType :: 
ScalaReflection.schemaFor[A19].dataType :: 
ScalaReflection.schemaFor[A20].dataType :: 
ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }
@@ -390,7 +414,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
    */
   def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = {
     val dataType = ScalaReflection.schemaFor[RT].dataType
-    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e)
+    val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil)
+    def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
     UserDefinedFunction(func, dataType)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/4a22bce8/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
index 831eb7e..b14e00a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
@@ -41,10 +41,13 @@ import org.apache.spark.sql.types.DataType
  * @since 1.3.0
  */
 @Experimental
-case class UserDefinedFunction protected[sql] (f: AnyRef, dataType: DataType) {
+case class UserDefinedFunction protected[sql] (
+    f: AnyRef,
+    dataType: DataType,
+    inputTypes: Seq[DataType] = Nil) {
 
   def apply(exprs: Column*): Column = {
-    Column(ScalaUDF(f, dataType, exprs.map(_.expr)))
+    Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes))
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/4a22bce8/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 4b70dc5..d261baf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.{TypeTag, typeTag}
+import scala.util.Try
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.catalyst.ScalaReflection
@@ -1584,6 +1585,7 @@ object functions {
   (0 to 10).map { x =>
     val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"})
     val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _)
+    val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"})
     println(s"""
     /**
      * Defines a user-defined function of ${x} arguments as user-defined function (UDF).
@@ -1593,7 +1595,8 @@ object functions {
      * @since 1.3.0
      */
     def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = {
-      UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
+      val inputTypes = Try($inputTypes).getOrElse(Nil)
+      UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
     }""")
   }
 
@@ -1625,7 +1628,8 @@ object functions {
    * @since 1.3.0
    */
   def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = {
-    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
+    val inputTypes = Try(Nil).getOrElse(Nil)
+    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
   }
 
   /**
@@ -1636,7 +1640,8 @@ object functions {
    * @since 1.3.0
    */
   def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
-    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
+    val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).getOrElse(Nil)
+    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
   }
 
   /**
@@ -1647,7 +1652,8 @@ object functions {
    * @since 1.3.0
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
-    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
+    val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).getOrElse(Nil)
+    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
   }
 
   /**
@@ -1658,7 +1664,8 @@ object functions {
    * @since 1.3.0
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
-    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
+    val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).getOrElse(Nil)
+    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
   }
 
   /**
@@ -1669,7 +1676,8 @@ object functions {
    * @since 1.3.0
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
-    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
+    val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).getOrElse(Nil)
+    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
   }
 
   /**
@@ -1680,7 +1688,8 @@ object functions {
    * @since 1.3.0
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
-    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
+    val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).getOrElse(Nil)
+    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
   }
 
   /**
@@ -1691,7 +1700,8 @@ object functions {
    * @since 1.3.0
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
-    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
+    val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).getOrElse(Nil)
+    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
   }
 
   /**
@@ -1702,7 +1712,8 @@ object functions {
    * @since 1.3.0
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
-    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
+    val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).getOrElse(Nil)
+    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
   }
 
   /**
@@ -1713,7 +1724,8 @@ object functions {
    * @since 1.3.0
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
-    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
+    val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).getOrElse(Nil)
+    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
   }
 
   /**
@@ -1724,7 +1736,8 @@ object functions {
    * @since 1.3.0
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
-    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
+    val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).getOrElse(Nil)
+    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
   }
 
   /**
@@ -1735,7 +1748,8 @@ object functions {
    * @since 1.3.0
    */
   def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
-    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
+    val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).getOrElse(Nil)
+    UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes)
   }
 
   //////////////////////////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/4a22bce8/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 8e5da3a..c1516b4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -166,4 +166,10 @@ class UDFSuite extends QueryTest {
     // 1 + 1 is constant folded causing a transformation.
     assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
   }
+
+  test("type coercion for udf inputs") {
+    ctx.udf.register("intExpected", (x: Int) => x)
+    // pass a decimal to intExpected.
+    assert(ctx.sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to