Repository: spark
Updated Branches:
  refs/heads/master 89fd9bd06 -> 03ba56d78


[SPARK-11716][SQL] UDFRegistration just drops the input type when re-creating 
the UserDefinedFunction

https://issues.apache.org/jira/browse/SPARK-11716

This is #9739 plus a regression test. When committing it, please make sure
the author is set to jbonofre.

You can find the original PR at https://github.com/apache/spark/pull/9739

closes #9739

Author: Jean-Baptiste Onofré <jbono...@apache.org>
Author: Yin Huai <yh...@databricks.com>

Closes #9868 from yhuai/SPARK-11716.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/03ba56d7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/03ba56d7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/03ba56d7

Branch: refs/heads/master
Commit: 03ba56d78f50747710d01c27d409ba2be42ae557
Parents: 89fd9bd
Author: Jean-Baptiste Onofré <jbono...@apache.org>
Authored: Fri Nov 20 14:45:40 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Fri Nov 20 14:45:40 2015 -0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/UDFRegistration.scala  | 48 ++++++++++----------
 .../scala/org/apache/spark/sql/UDFSuite.scala   | 15 ++++++
 2 files changed, 39 insertions(+), 24 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/03ba56d7/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index fc4d093..051694c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -88,7 +88,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
           val inputTypes = Try($inputTypes).getOrElse(Nil)
           def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, 
inputTypes)
           functionRegistry.registerFunction(name, builder)
-          UserDefinedFunction(func, dataType)
+          UserDefinedFunction(func, dataType, inputTypes)
         }""")
     }
 
@@ -120,7 +120,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -133,7 +133,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -146,7 +146,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -159,7 +159,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -172,7 +172,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -185,7 +185,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -198,7 +198,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -211,7 +211,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -224,7 +224,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -237,7 +237,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -250,7 +250,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -263,7 +263,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -276,7 +276,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -289,7 +289,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -302,7 +302,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -315,7 +315,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -328,7 +328,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: 
ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -341,7 +341,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: 
ScalaReflection.schemaFor[A16].dataType :: 
ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -354,7 +354,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: 
ScalaReflection.schemaFor[A16].dataType :: 
ScalaReflection.schemaFor[A17].dataType :: 
ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -367,7 +367,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: 
ScalaReflection.schemaFor[A16].dataType :: 
ScalaReflection.schemaFor[A17].dataType :: 
ScalaReflection.schemaFor[A18].dataType :: 
ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -380,7 +380,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: 
ScalaReflection.schemaFor[A16].dataType :: 
ScalaReflection.schemaFor[A17].dataType :: 
ScalaReflection.schemaFor[A18].dataType :: 
ScalaReflection.schemaFor[A19].dataType :: 
ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -393,7 +393,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: 
ScalaReflection.schemaFor[A16].dataType :: 
ScalaReflection.schemaFor[A17].dataType :: 
ScalaReflection.schemaFor[A18].dataType :: 
ScalaReflection.schemaFor[A19].dataType :: 
ScalaReflection.schemaFor[A20].dataType :: 
ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   /**
@@ -406,7 +406,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) 
extends Logging {
     val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: 
ScalaReflection.schemaFor[A2].dataType :: 
ScalaReflection.schemaFor[A3].dataType :: 
ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: 
ScalaReflection.schemaFor[A6].dataType :: 
ScalaReflection.schemaFor[A7].dataType :: 
ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: 
ScalaReflection.schemaFor[A11].dataType :: 
ScalaReflection.schemaFor[A12].dataType :: 
ScalaReflection.schemaFor[A13].dataType :: 
ScalaReflection.schemaFor[A14].dataType :: 
ScalaReflection.schemaFor[A15].dataType :: 
ScalaReflection.schemaFor[A16].dataType :: 
ScalaReflection.schemaFor[A17].dataType :: 
ScalaReflection.schemaFor[A18].dataType :: 
ScalaReflection.schemaFor[A19].dataType :: 
ScalaReflection.schemaFor[A20].dataType :: 
ScalaReflection.schemaFor[A21].dataType :: 
ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil)
     def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes)
     functionRegistry.registerFunction(name, builder)
-    UserDefinedFunction(func, dataType)
+    UserDefinedFunction(func, dataType, inputTypes)
   }
 
   
//////////////////////////////////////////////////////////////////////////////////////////////

http://git-wip-us.apache.org/repos/asf/spark/blob/03ba56d7/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 9837fa6..fd73671 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -232,4 +232,19 @@ class UDFSuite extends QueryTest with SharedSQLContext {
            | (SELECT complexDataFunc(m, a, b) AS t FROM complexData) tmp
           """.stripMargin).toDF(), complexData.select("m", "a", "b"))
   }
+
+  test("SPARK-11716 UDFRegistration does not include the input data type in 
returned UDF") {
+    val myUDF = sqlContext.udf.register("testDataFunc", (n: Int, s: String) => 
{ (n, s.toInt) })
+
+    // Without the fix, this will fail because we fail to cast data type of b 
to string
+    // because myUDF does not know its input data type. With the fix, this 
query should not
+    // fail.
+    checkAnswer(
+      testData2.select(myUDF($"a", $"b").as("t")),
+      testData2.selectExpr("struct(a, b)"))
+
+    checkAnswer(
+      sql("SELECT tmp.t.* FROM (SELECT testDataFunc(a, b) AS t from testData2) 
tmp").toDF(),
+      testData2)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to